## Introduction

In short, this technique can usually "remove" a dimension from DP state. Before we dive in, let's look at a related problem first:

### Implementation

    int l, r, ans;  // set l and r to the bounds of the penalty search range first
    while(l <= r) {
        int mid = l + (r - l) / 2;
        // calc() is algorithm A, first is M(mid), second is V(mid)
        pair<int, int> res = calc(mid);
        if(res.second >= need) ans = res.first - mid * need, l = mid + 1;
        else r = mid - 1;
    }
    cout << ans << '\n';
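To see the loop above run on something concrete, here is a minimal sketch that wraps it in a function and feeds it a toy "algorithm A". Both `wqs_min` and `calc_table` are hypothetical names introduced only for this illustration. `calc_table` assumes that $f$ is given as an explicit convex table and simply scans it, which is of course not what you would do in a real problem.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for algorithm A on a tabulated convex f:
// returns {min over x of f[x] + p * x, largest x attaining that minimum}.
std::pair<long long, int> calc_table(const std::vector<long long>& f, long long p) {
    assert(!f.empty());
    long long best = f[0];
    int arg = 0;
    for (int x = 1; x < (int)f.size(); ++x) {
        long long cur = f[x] + p * x;
        if (cur <= best) { best = cur; arg = x; }  // "<=" keeps the largest x on ties
    }
    return {best, arg};
}

// The same binary search as above, wrapped in a function: it looks for the
// largest penalty in [lo, hi] whose optimum still uses at least `need` items
// and then removes the penalty from the optimum again.
// If no penalty in the range qualifies, the result stays 0.
long long wqs_min(long long lo, long long hi, int need,
                  const std::function<std::pair<long long, int>(long long)>& calc) {
    long long ans = 0;
    while (lo <= hi) {
        long long mid = lo + (hi - lo) / 2;
        std::pair<long long, int> res = calc(mid);
        if (res.second >= need) { ans = res.first - mid * need; lo = mid + 1; }
        else hi = mid - 1;
    }
    return ans;
}
```

For example, with the convex table $f=(10, 6, 3, 1, 0)$ and the penalty range $[-10, 10]$, `wqs_min` recovers $f(need)$ for every $need$ from $0$ to $4$.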


## Example problems

### BZOJ2654 - Tree

#### Problem description

You're given a weighted, undirected and connected graph $G$ with $N$ nodes, and $M$ edges. Every edge has a weight $w$ and a color $c$, which is either black or white. Please calculate the sum of weights on the minimum spanning tree that contains exactly $K$ white edges.

$1\le N\le 5\times 10^4, N-1\le M\le 10^5, 1\le w\le 100, 0\le c\le 1$

#### Problem analysis

This problem has nothing to do with DP. It's just a practice problem for wqs binary search :)

Let $f(x)$ be the sum of weights on the minimum spanning tree that contains exactly $x$ white edges. Now, let's try to show that $f(x)$ is a convex function by showing that $$f(x)-f(x-1)\le f(x+1)-f(x), x\in \mathbb{N}$$ Every time we increase the number of white edges by $1$, we choose the smallest white edge that is still valid. This implies that the "net benefit" (the decrease in total weight) of adding the $i^{th}$ white edge shrinks as $i$ goes up, which is exactly the inequality above.
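When you are unsure whether some $f$ really has this property, it can help to tabulate $f$ with a brute force on small inputs and test the inequality directly. The helper below is only a sketch for that purpose; `is_discrete_convex` and `require_convex` are hypothetical names, and the table is assumed to hold $f(0), f(1), \dots$ in order.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Returns true when f[x] - f[x-1] <= f[x+1] - f[x] holds for every interior x,
// i.e. when the successive differences never decrease.
bool is_discrete_convex(const std::vector<long long>& f) {
    for (std::size_t x = 1; x + 1 < f.size(); ++x) {
        if (f[x] - f[x - 1] > f[x + 1] - f[x]) return false;
    }
    return true;
}

// Debug guard: aborts via assert when a brute-forced table is not convex.
void require_convex(const std::vector<long long>& f) {
    assert(is_discrete_convex(f));
}
```

A table that passes on a handful of random small inputs is not a proof, but a table that fails tells you right away that wqs binary search is not applicable.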

So now we need an algorithm $\mathcal{A}$ to calculate the minimum value of $f(x)+px$ given some value $p$. This can be done by reweighting the white edges (adding $p$ to the original weight) and running a Minimum Spanning Tree (MST) algorithm on the result. I used Kruskal's algorithm in my implementation.

#### Problem solution

Note that the range of binary search is $[-100, 100]$ (bounded by the maximum edge weight). Every call to Kruskal's algorithm sorts all $M$ edges, so the time complexity is $O(M\log M\log C)$, where $C=200$ is the width of the binary search range.

code
#include <bits/stdc++.h>
#define F first
#define S second
using namespace std;
typedef long long ll;

const int N = (int)5e4 + 5;

struct DSU {
int fa[N], sz[N];

void init(int n) {
for(int i = 0 ; i < n ; ++i) fa[i] = i, sz[i] = 1;
}
int find(int x) {
return x == fa[x] ? x : fa[x] = find(fa[x]);
}
bool merge(int x, int y) {
x = find(x), y = find(y);
if(x == y) return false;
if(sz[x] < sz[y]) swap(x, y);
sz[x] += sz[y];
fa[y] = x;
sz[y] = 0;
return true;
}
} dsu;

struct Edge {
int  u, v, w, c;

bool operator<(const Edge& rhs) const {
return w < rhs.w || (w == rhs.w && c < rhs.c);
}
};
vector<Edge> edges;
int n, m, need;

pair<int, int> calc(int k) {
for(int i = 0 ; i < m ; ++i) if(!edges[i].c) edges[i].w += k;
sort(edges.begin(), edges.end());
dsu.init(n);
int cost = 0, cnt = 0;
for(int i = 0 ; i < m ; ++i) {
if(dsu.merge(edges[i].u, edges[i].v)) {
cost += edges[i].w;
if(!edges[i].c) cnt++;
}
}
for(int i = 0 ; i < m ; ++i) if(!edges[i].c) edges[i].w -= k;
return {cost, cnt};
}

void init() {
cin >> n >> m >> need;
for(int i = 0 ; i < m ; ++i) {
int u, v, w, c; cin >> u >> v >> w >> c;
edges.push_back({u, v, w, c});
}
}
void solve() {
int l = -100, r = 100, ans = 0;
while(l <= r) {
int mid = l + (r - l) / 2;
pair<int, int> res = calc(mid);
if(res.second >= need) ans = res.first - mid * need, l = mid + 1;
else r = mid - 1;
}
cout << ans << '\n';
}

int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}



### Luogu U72600 - Commando EX

#### Problem description

You're given a sequence $a_1, a_2, \dots, a_N$. Please split it into $K$ contiguous parts $[l_1, r_1], [l_2, r_2], \dots, [l_K, r_K]$ such that the following cost is minimized: $$\sum_{i=1}^{K}\left( \sum_{j=l_i}^{r_i} a_j \right)^2$$

$2\le K\le N\le 10^5, a_i \le 50$

#### Problem analysis

Let's first write down the DP equations:

• DP state : $dp_{i, j}$ represents the minimum cost of splitting $a_1, a_2, \dots, a_i$ into $j$ parts
• DP transition : $dp_{i, j}=\min_{0\le k\lt i} \left( dp_{k, j - 1} + cost(k + 1, i) \right)$
• Final answer : $dp_{N, K}$
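Written out literally, the equations above give the following reference implementation. It is only a sketch for cross-checking faster solutions on tiny inputs; the name `naive_split_cost` and the 0-indexed input vector are my own choices for the illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

// Reference O(N^2 K) DP: minimum sum of squared part sums when splitting
// a[0..n-1] into exactly K non-empty contiguous parts.
long long naive_split_cost(const std::vector<long long>& a, int K) {
    const int n = (int)a.size();
    assert(1 <= K && K <= n);
    const long long INF = std::numeric_limits<long long>::max() / 4;
    std::vector<long long> pre(n + 1, 0);
    for (int i = 1; i <= n; ++i) pre[i] = pre[i - 1] + a[i - 1];
    // dp[j][i]: best cost of splitting the first i elements into j parts
    std::vector<std::vector<long long>> dp(K + 1, std::vector<long long>(n + 1, INF));
    dp[0][0] = 0;
    for (int j = 1; j <= K; ++j) {
        for (int i = j; i <= n; ++i) {
            for (int k = j - 1; k < i; ++k) {
                if (dp[j - 1][k] == INF) continue;  // unreachable state
                long long seg = pre[i] - pre[k];    // sum of a[k+1..i] (1-indexed)
                dp[j][i] = std::min(dp[j][i], dp[j - 1][k] + seg * seg);
            }
        }
    }
    return dp[K][n];
}
```

For $a=(1, 2, 3)$ the best split into two parts is $[1, 2], [3]$ with cost $9+9=18$.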

If we calculate it directly, the time complexity is $O(N^2K)$, which is really bad... ($N^2K\approx 10^{15}$) So let's try to reduce the time complexity!

First, let $pre_i=a_1+a_2+\dots+a_i$ (with $pre_0=0$), so that $cost(k+1, i)=(pre_i-pre_k)^2$, and rewrite the transition equation: \begin{align*} dp_{i, j} &= \min_{0\le k\lt i} \left( dp_{k, j - 1} + cost(k + 1, i) \right) \\ &= \min_{0\le k\lt i} \left( dp_{k, j - 1} + (pre_i - pre_k)^2 \right) \\ &= \min_{0\le k\lt i} \left( dp_{k, j - 1} + pre_i^2 + pre_k^2 - 2\cdot pre_i\cdot pre_k \right) \\ &= pre_i^2 + \min_{0\le k\lt i} \left( -2\cdot pre_k \cdot pre_i + dp_{k, j - 1} + pre_k^2 \right) \\ &= pre_i^2 + \min_{0\le k\lt i} \left( m_k\cdot pre_i + b_k \right), m_k=-2\cdot pre_k, b_k=dp_{k, j - 1} + pre_k^2 \end{align*} The last line is exactly the form that can be optimized to $O(N\log N)$ for each $j$ with convex hull optimization. Furthermore, we can optimize it to $O(N)$ for each $j$, as the queries ($pre_i$) are increasing and the slopes ($-2\cdot pre_k$) are decreasing. So now the time complexity is $O(NK)$, which is better but still unacceptable.
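In case the monotone version of convex hull optimization is unfamiliar, here is a minimal sketch of it for minimum queries. It is not the structure used in my solution (that one also carries the part counter); `Line` and `MonotoneCHT` are names made up for this illustration, and I assume that all products in `useless` fit in a signed 64-bit integer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Line {
    long long m, b;  // represents y = m * x + b
    long long at(long long x) const { return m * x + b; }
};

// Lower envelope for minimum queries. Assumes lines arrive with non-increasing
// slopes and queries arrive with non-decreasing x.
struct MonotoneCHT {
    std::vector<Line> dq;
    std::size_t head = 0;  // lines before `head` can never be optimal again

    // True when `mid` is never strictly below both `a` and `c`
    // (slopes satisfy a.m > mid.m > c.m).
    static bool useless(const Line& a, const Line& mid, const Line& c) {
        return (c.b - a.b) * (a.m - mid.m) <= (mid.b - a.b) * (a.m - c.m);
    }
    void add(Line l) {
        assert(dq.size() == head || l.m <= dq.back().m);
        if (dq.size() > head && dq.back().m == l.m) {
            if (dq.back().b <= l.b) return;  // same slope, old line is at least as good
            dq.pop_back();
        }
        while (dq.size() - head >= 2 && useless(dq[dq.size() - 2], dq.back(), l)) dq.pop_back();
        dq.push_back(l);
    }
    long long query(long long x) {
        assert(dq.size() > head);
        // advance while the next line is at least as good at x
        while (dq.size() - head >= 2 && dq[head].at(x) >= dq[head + 1].at(x)) ++head;
        return dq[head].at(x);
    }
};
```

Every line is pushed and popped at most once, which is where the amortized $O(N)$ bound for one layer comes from.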

Second, let $f(x)=dp_{N, x}$. It's not hard to notice that $f(x)$ is convex. As it is convex, we can use wqs binary search! But we still need to construct an efficient algorithm $\mathcal{A}$ to calculate the minimum value of $f(x)+px$ in order to use wqs. We can use another DP as $\mathcal{A}$:

• DP state : $dp^{'}_i$ represents the minimum cost of splitting $a_1, a_2, \dots, a_i$ into any number of parts, where every part costs an extra $p$
• DP transition : $dp^{'}_i=pre_i^2 + p + \min_{0\le k\lt i} \left( m_k\cdot pre_i + b_k \right), m_k=-2\cdot pre_k, b_k=dp^{'}_{k} + pre_k^2$
• Final answer : $dp^{'}_N$

Note that the new DP transition is almost the same as the old one: it has only one dimension and an extra term $p$. This DP can be calculated in $O(N)$ with the same convex hull optimization as above.
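To make the role of $p$ and of the part counter concrete, here is a deliberately slow $O(N^2)$ sketch of the penalized DP, without convex hull optimization. The name `penalized_split` is hypothetical, the input is 0-indexed, and unlike my solution code it returns the value first and the number of parts second. Ties are broken towards more parts.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Returns {dp'_N, number of parts used by that optimum} for penalty p.
std::pair<long long, int> penalized_split(const std::vector<long long>& a, long long p) {
    const int n = (int)a.size();
    assert(n >= 1);
    std::vector<long long> pre(n + 1, 0), dp(n + 1, 0);
    std::vector<int> cnt(n + 1, 0);
    for (int i = 1; i <= n; ++i) pre[i] = pre[i - 1] + a[i - 1];
    for (int i = 1; i <= n; ++i) {
        bool first = true;
        for (int k = 0; k < i; ++k) {
            long long seg = pre[i] - pre[k];
            long long val = dp[k] + seg * seg + p;  // one more part, so one more p
            int c = cnt[k] + 1;
            // smaller value wins; on equal values prefer the larger part count
            if (first || val < dp[i] || (val == dp[i] && c > cnt[i])) {
                dp[i] = val;
                cnt[i] = c;
                first = false;
            }
        }
    }
    return {dp[n], cnt[n]};
}
```

For $a=(1, 2, 3)$ and $p=10$ this returns the value $38$ with $2$ parts, and $38-10\cdot 2=18$ is indeed the optimal cost for $K=2$.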

Finally, we've successfully reduced the time complexity from $O(N^2K)$ to $O(N\log C)$!

#### Problem solution

Note that I set the range of binary search as $[0, pre_n^2]$. Also note that when calculating the new DP, we need to store the number of parts each state uses (the $cnt$ array in the code), and when two transition points give the same $dp^{'}$ value, we choose the one with the larger $cnt$ (whether larger or smaller is correct depends on how you implement the binary search). More details are in the code.

code
#include <bits/stdc++.h>
#define F first
#define S second
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;

const int N = (int)1e5 + 5;
const ll inf = (ll)1e15;

ll n, k, a[N], pre[N], dp[N], cnt[N];

struct CHT {
struct Line {
ll m, b, c;

ll operator()(ll x) { return m * x + b; }
} dq[N * 2];
int l, r;

void init() {
dq[0] = {0, 0, 0};
l = 0, r = 1;
}
bool better_ins(Line& L, Line& L1, Line& L2) {
ll b1 = (L.b - L2.b) * (L2.m - L1.m), b2 = (L2.m - L.m) * (L1.b - L2.b);
return b1 < b2 || (b1 == b2 && L.c > L1.c);
}
bool better_qry(Line& L1, Line& L2, ll x) {
ll b1 = L1(x), b2 = L2(x);
return b1 > b2 || (b1 == b2 && L1.c < L2.c);
}
void insert(Line L) {
while(r - l >= 2 && better_ins(L, dq[r - 1], dq[r - 2])) --r;
dq[r++] = L;
}
pll query(ll x) {
while(r - l >= 2 && better_qry(dq[l], dq[l + 1], x)) ++l;
return {dq[l](x), dq[l].c};
}
} cht;
pll calc(ll x) {
dp[0] = cnt[0] = 0;
cht.init();
for(int i = 1 ; i <= n ; ++i) {
pll qry = cht.query(pre[i]);
dp[i] = pre[i] * pre[i] + x + qry.F;
cnt[i] = qry.S + 1;
cht.insert({-2LL * pre[i], dp[i] + pre[i] * pre[i], cnt[i]});
}
return {cnt[n], dp[n]};
}

void init() {
cin >> n >> k;
for(int i = 1 ; i <= n ; ++i) cin >> a[i], pre[i] = pre[i - 1] + a[i];
}
void solve() {
ll l = 0, r = pre[n] * pre[n], ans = -1;
while(l <= r) {
ll mid = l + (r - l) / 2;
pll res = calc(mid);
if(res.F >= k) ans = res.S - k * mid, l = mid + 1;
else r = mid - 1;
}
cout << ans << '\n';
}

int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}



### 2018 ACM-ICPC Nanjing Regional pB - Tournament

#### Problem description

Given $N$ points $a_1, a_2, \dots, a_N$ on a number line, please find $K$ points $s_1, s_2, \dots, s_K$ such that $$\sum_{i=1}^{N}\left( \min_{1\le j \le K} D(a_i, s_j) \right)$$ is minimized, where $D(a_i, s_j)$ is the distance between them.

Note that you only need to calculate the minimum value of it i.e. you don't need to output the points you choose.

$1\le K\le N\le 3\times 10^5, 0=a_1\lt a_2\lt \dots \lt a_N\le 10^9$

#### Problem analysis

First, we write down the DP equations:

• DP state : $dp_{i, j}$ represents the minimum value of $\sum_{p=1}^{i}\left( \min_{1\le q \le j} D(a_p, s_q) \right)$
• DP transition : $dp_{i, j}=\min_{0\le k\lt i} \left( dp_{k, j - 1} + cost(k + 1, i) \right)$
• Final answer : $dp_{N, K}$

There're several problems to handle in the DP above. First, we need to know how to calculate $cost(i, j)$. Then, we need to reduce the time complexity.

We know that $cost(i, j)$ represents $\min_s\left(\sum_{p=i}^{j} D(a_p, s)\right)$. It's easy to observe that the best $s$ is the median of $a_i, a_{i+1}, \dots, a_j$, which implies that $cost(i, j)$ can be calculated in $O(1)$ with prefix sums. Also, one can observe that the best transition points of $dp$ are monotonically non-decreasing, i.e. $$H(i, j)=\mathop{\arg\min}_k \left( dp_{k, j - 1} + cost(k + 1, i) \right) \implies H(i, j) \le H(i + 1, j)$$

which is the 1D/1D convex case of Knuth Optimization. Thus, the time complexity of the DP above can be reduced to $O(KN\log N)$.
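The $O(1)$ formula for $cost(i, j)$ is easy to get wrong by one index, so here is a small sketch that spells it out next to a brute force. `median_cost` and `brute_cost` are hypothetical names; both assume 1-indexed, increasingly sorted points with an unused slot at index $0$, and `pre[i]` is the sum of the first $i$ points.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// O(1) cost of serving a[l..r] from their median, using prefix sums.
long long median_cost(const std::vector<long long>& pre, int l, int r) {
    assert(1 <= l && l <= r && r < (int)pre.size());
    int m = (l + r) / 2;                // index of the (lower) median
    long long right = pre[r] - pre[m];  // sum of a[m+1..r]
    // Odd length: the median itself contributes 0, so the left part is a[l..m-1].
    // Even length: both halves have equal size and the left part is a[l..m].
    int left_end = ((r - l) % 2 == 1) ? m : m - 1;
    long long left = pre[left_end] - pre[l - 1];
    return right - left;                // the median terms cancel out
}

// O(r - l) brute force for the same quantity, for cross-checking only.
long long brute_cost(const std::vector<long long>& a, int l, int r) {
    long long s = a[(l + r) / 2], total = 0;
    for (int p = l; p <= r; ++p) total += std::llabs(a[p] - s);
    return total;
}
```

With the points $0, 2, 3, 7$, both functions give $cost(1, 4)=8$ and $cost(1, 3)=3$.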

Then, let $f(x)=dp_{N, x}$. Again, we can discover that $f(x)$ is convex either by intuition or other methods. So we can use wqs binary search to remove the last dimension of $dp$, and the time complexity will be $O(N\log N\log C)$.

#### Problem solution

Note that I added these two lines in order to pass the tight time limit :/

#pragma GCC optimize ("O3,unroll-loops,no-stack-protector")
#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")

code
#pragma GCC optimize ("O3,unroll-loops,no-stack-protector")
#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include <bits/stdc++.h>
#define F first
#define S second
using namespace std;
typedef long long ll;

const int N = (int)3e5 + 5;

ll n, k, a[N], pre[N], dp[N], cnt[N];

inline ll cost(ll l, ll r) {
ll m = (l + r) / 2;
return pre[r] - pre[m] - pre[((r - l) & 1) ? m : m - 1] + pre[l - 1];
}
inline pair<ll, ll> calc(ll i, ll j) {
return {dp[i] + cost(i + 1, j), -cnt[i] - 1};
}
pair<ll, ll> calc(ll x) {
struct Node { ll p, l, r; };
deque<Node> dq;
dp[0] = cnt[0] = 0;
dq.push_back({0, 1, n});
for(int i = 1 ; i <= n ; ++i) {
dp[i] = calc(dq.front().p, i).F + x, cnt[i] = cnt[dq.front().p] + 1;
if(dq.front().r == i) dq.pop_front();
else dq.front().l++;

while(!dq.empty() && calc(i, dq.back().l) < calc(dq.back().p, dq.back().l)) dq.pop_back();
if(dq.empty()) dq.push_back({i, i + 1, n});
else {
ll l = dq.back().l, r = dq.back().r;
while(l < r) {
ll mid = r - (r - l) / 2;
if(calc(i, mid) < calc(dq.back().p, mid)) r = mid - 1;
else l = mid;
}
dq.back().r = l;
if(l != n) dq.push_back({i, l + 1, n});
}
}
return {cnt[n], dp[n]};
}

void init() {
cin >> n >> k;
for(int i = 1 ; i <= n ; ++i) cin >> a[i], pre[i] = pre[i - 1] + a[i];
}
void solve() {
ll l = 0, r = (ll)1e15, ans = -1;
while(l <= r) {
ll mid = l + (r - l) / 2;
auto [cc, val] = calc(mid);
if(cc >= k) l = mid + 1, ans = val - mid * k;
else r = mid - 1;
}
cout << ans << '\n';
}

int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}



### CF739E - Gosha is hunting

#### Problem description

You want to catch some Pokemons. You're given $a$ Poke Balls and $b$ Ultra Balls. There are $N$ Pokemons out there numbered $1$ to $N$. You know that the probability of catching the $i^{th}$ Pokemon using a Poke Ball is $p_i$, and the probability of catching it using an Ultra Ball is $u_i$. Also, you can throw at most one Ball of each type at any Pokemon.

Now, you can choose at most $a$ Pokemons to throw Poke Balls at and at most $b$ Pokemons to throw Ultra Balls at. You can throw both types of Balls at the same Pokemon. The Pokemon is caught iff it is caught by any of these Balls, and the outcome of a throw doesn't depend on the other throws. Please maximize the expected number of Pokemons you can catch if you choose optimally.

$2\le N\le 2000, 0\le a, b\le N, 0\le p_i, u_i\le 1$

#### Problem analysis

Let's write down the DP equations first:

• DP state : $dp_{i, j, k}$ represents the maximum expected number of Pokemons you can catch considering Pokemons from $1$ to $i$ using at most $j$ Poke Balls and at most $k$ Ultra Balls
• DP transition : $$dp_{i, j, k}=\max \begin{cases} dp_{i-1, j, k}&, \text{don't spend any Balls on Pokemon } i \\ dp_{i-1, j-1, k} + p_i&, \text{spend one Poke Ball on Pokemon } i \\ dp_{i-1, j, k-1} + u_i&, \text{spend one Ultra Ball on Pokemon } i \\ dp_{i-1, j-1, k-1} + p_i + u_i - p_iu_i&, \text{spend one Poke Ball and one Ultra Ball on Pokemon } i \end{cases}$$
• Final answer : $dp_{N, a, b}$
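As a reference for small inputs, the equations above can be coded directly. This is only a sketch, not the solution used later: `naive_expected_catches` is a hypothetical name, the vectors are 0-indexed, and the first dimension is rolled away by iterating $j$ and $k$ downwards. It takes $O(N\cdot a\cdot b)$ time, which is far too slow for the real limits.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

double naive_expected_catches(const std::vector<double>& p, const std::vector<double>& u,
                              int a, int b) {
    assert(p.size() == u.size() && a >= 0 && b >= 0);
    // dp[j][k]: best expectation using at most j Poke Balls and at most k Ultra Balls
    std::vector<std::vector<double>> dp(a + 1, std::vector<double>(b + 1, 0.0));
    for (std::size_t i = 0; i < p.size(); ++i) {
        // Going downwards means dp[j-1][*] and dp[j][k-1] still hold the values
        // from before Pokemon i, so each Pokemon is counted at most once.
        for (int j = a; j >= 0; --j) {
            for (int k = b; k >= 0; --k) {
                double best = dp[j][k];
                if (j > 0) best = std::max(best, dp[j - 1][k] + p[i]);
                if (k > 0) best = std::max(best, dp[j][k - 1] + u[i]);
                if (j > 0 && k > 0)
                    best = std::max(best, dp[j - 1][k - 1] + p[i] + u[i] - p[i] * u[i]);
                dp[j][k] = best;
            }
        }
    }
    return dp[a][b];
}
```

For $p=(1, 0, 0.5)$, $u=(0, 1, 0.5)$ and $a=b=2$ the result is $2.75$: one Ball of the right type on each of the first two Pokemons and both Balls on the third.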

Let $f(x)=dp_{N, a, x}$. Note that $f(x)$ is a concave function, i.e. $$f(x)-f(x-1)\ge f(x+1)-f(x), x\in \mathbb{N}$$ so we can use wqs binary search to remove the last dimension and achieve a time complexity of $O(N^2\log C)$.

The time complexity above is enough to pass this problem. However, we can do better. Let $g(x, y)=dp_{N, x, y}$. It can be observed that $g(x, y)$ is concave as well. Thus, $dp_{N, a, b}$ can be calculated in $O(N\log^2 C)$ time by nesting two wqs binary searches (so $N$ can actually be $10^5$ in this problem).

#### Problem solution

The time complexity of first solution is $O(N^2\log C)$, where the range of binary search is $[0, 1]$.

code
#include <bits/stdc++.h>
#define F first
#define S second
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double ld;

const int N = 2000 + 5;
const ld inf = 1e9;
const ld eps = 1e-8;

int n, a, b, cnt[N][N];
ld p[N], u[N], dp[N][N];

inline int dcmp(ld x, ld y) {
if(abs(x - y) < eps) return 0;
return x < y ? -1 : 1;
}
void upd(int i, int j, ld val, int cc) {
if(dcmp(dp[i][j], val) < 0 || (dcmp(dp[i][j], val) == 0 && cnt[i][j] < cc)) {
dp[i][j] = val;
cnt[i][j] = cc;
}
}
pair<int, ld> calc(ld x) {
fill(dp[0], dp[0] + N, -inf);
dp[0][0] = cnt[0][0] = 0;
for(int i = 1 ; i <= n ; ++i) {
for(int j = 0 ; j <= a ; ++j) {
dp[i][j] = dp[i - 1][j], cnt[i][j] = cnt[i - 1][j];
upd(i, j, dp[i - 1][j] + u[i] - x, cnt[i - 1][j] + 1);
if(j > 0) {
upd(i, j, dp[i - 1][j - 1] + p[i], cnt[i - 1][j - 1]);
upd(i, j, dp[i - 1][j - 1] + p[i] + u[i] - p[i] * u[i] - x, cnt[i - 1][j - 1] + 1);
}
}
}
return {cnt[n][a], dp[n][a]};
}

void init() {
cin >> n >> a >> b;
for(int i = 1 ; i <= n ; ++i) cin >> p[i];
for(int i = 1 ; i <= n ; ++i) cin >> u[i];
}
void solve() {
ld L = 0, R = 1, ans = -1;
for(int i = 0 ; i < 100 ; ++i) {
ld M = (L + R) / 2;
pair<int, ld> res = calc(M);
if(res.F >= b) ans = res.S + M * b, L = M;
else R = M;
}
cout << fixed << setprecision(10) << ans << '\n';
}

int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}



The time complexity of the solution below is $O(N\log ^2C)$.

code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;

const int N = (int)1e5 + 5;
const ld eps = 1e-15;

inline int dcmp(ld x, ld y) {
if(abs(x - y) < eps) return 0;
return x < y ? -1 : 1;
}

int n, a, b;
ld p[N], u[N];
struct Node {
ld val;
int cp, cu;

bool operator<(const Node& rhs) const {
return dcmp(val, rhs.val) < 0 || (dcmp(val, rhs.val) == 0 && (cp < rhs.cp || (cp == rhs.cp && cu < rhs.cu)));
}
} dp[N];

inline void upd(int i, Node o) {
if(dp[i] < o) dp[i] = o;
}
Node calc(ld x, ld y) {
dp[0] = {0, 0, 0};
for(int i = 1 ; i <= n ; ++i) {
dp[i] = dp[i - 1];
upd(i, {dp[i - 1].val + p[i] - x, dp[i - 1].cp + 1, dp[i - 1].cu});
upd(i, {dp[i - 1].val + u[i] - y, dp[i - 1].cp, dp[i - 1].cu + 1});
upd(i, {dp[i - 1].val + p[i] + u[i] - p[i] * u[i] - x - y, dp[i - 1].cp + 1, dp[i - 1].cu + 1});
}
return dp[n];
}

void init() {
cin >> n >> a >> b;
for(int i = 1 ; i <= n ; ++i) cin >> p[i];
for(int i = 1 ; i <= n ; ++i) cin >> u[i];
}
void solve() {
ld l = 0, r = 1, ans = 0; // initialized so that ans is never read uninitialized
while(r - l > eps) {
ld mid = (l + r) / 2;
Node res;
ld l2 = 0, r2 = 1;
while(r2 - l2 > eps) {
ld mid2 = (l2 + r2) / 2;
res = calc(mid, mid2);
if(res.cu >= b) l2 = mid2;
else r2 = mid2;
}
res = calc(mid, l2);
if(res.cp >= a) l = mid, ans = res.val + mid * a + l2 * b;
else r = mid;
}
cout << fixed << setprecision(15) << ans << '\n';
}

int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}



#### Problem description

Given a weighted tree with $N$ nodes, please choose $K+1$ non-intersecting paths such that the sum of weights on those paths is maximized.

$0\le K\lt N\le 3\times 10^5$

#### Problem analysis

Let's try to solve this problem with DP on tree:

• DP state : $dp_{u, i, k}, k=0, 1, 2$ represents the maximum weights choosing $i$ disjoint paths in the subtree of $u$ and the degree of $u$ is $k$ (not considering the edge connecting the father node).
• DP transition : suppose that we are handling child $v$, let $mx_k=\max \left\{ dp_{v, k, 0}, dp_{v, k, 1}, dp_{v, k, 2} \right\}$, then: $$\begin{cases} dp_{u, i, 0} = dp_{u, i, 0} + \max_{0\le k\le i} mx_k \\ dp_{u, i, 1} = dp_{u, i, 1} + \max_{0\le k\le i} \left\{ \max \left( mx_k, dp_{u, i-k, 0} + dp_{v, k, 0} + w, dp_{u, i-k-1, 0} + dp_{v, k+1, 1} + w \right) \right\} \\ dp_{u, i, 2} = dp_{u, i, 2} + \max_{0\le k\le i} \left\{ \max \left( mx_k, dp_{u, i-k+1, 1} + dp_{v, k-1, 0} + w, dp_{u, i-k, 1} + dp_{v, k, 1} + w \right) \right\} \\ \end{cases}$$
• Final answer : $\max_{0\le k\le 2} \left( dp_{1, K+1, k} \right)$, since we have to choose $K+1$ paths

Obviously, the time complexity is $O(NK)$. Then, let $f(x, k)=dp_{1, x, k}, k=0,1,2$. We can observe that all three of them $f(x, 0), f(x, 1), f(x, 2)$ are concave. So we can reduce time complexity to $O(N\log C)$ using wqs binary search.

#### Problem solution

The transition part above may not be clear, so please refer to the code below. Also, this problem cannot be submitted on the judge; thus, the test data and a simple judge script are offered here.

code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N = (int)3e5 + 5;
const ll inf = (ll)1e15;

struct Node {
ll v, cnt;

Node operator+(const Node& rhs) const {
return {v + rhs.v, cnt + rhs.cnt};
}
bool operator<(const Node& rhs) const {
return v < rhs.v || (v == rhs.v && cnt < rhs.cnt);
}
};

int n, k;
vector<pair<int, int> > G[N];
Node dp[3][N];

Node Max(int u) {
return max(dp[0][u], max(dp[1][u], dp[2][u]));
}
void dfs(int u, int p, ll x) {
dp[0][u] = {0, 0}, dp[1][u] = {-inf, 0}, dp[2][u] = {-x, 1};
for(auto [v, w] : G[u]) if(v != p) {
dfs(v, u, x);
Node tmp = Max(v);
dp[2][u] = max(dp[2][u] + tmp, dp[1][u] + max(dp[0][v] + (Node){w, 0}, dp[1][v] + (Node){w + x, -1}));
dp[1][u] = max(dp[1][u] + tmp, dp[0][u] + max(dp[0][v] + (Node){w - x, 1}, dp[1][v] + (Node){w, 0}));
dp[0][u] = dp[0][u] + tmp;
}
}

void init() {
cin >> n >> k; k++;
for(int i = 1 ; i < n ; ++i) {
int u, v, w; cin >> u >> v >> w;
G[u].push_back({v, w});
G[v].push_back({u, w});
}
}
void solve() {
ll l = -inf, r = inf, ans = 0; // initialized so that ans is never read uninitialized
while(l <= r) {
ll mid = l + (r - l) / 2;
dfs(1, 0, mid);
Node res = Max(1);
if(res.cnt >= k) l = mid + 1, ans = res.v + mid * k;
else r = mid - 1;
}
cout << ans << '\n';
}

int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}