DP optimization - WQS Binary Search Optimization
tags: icpc, algorithm, dp, dp-optimization, binary-search, wqs
Introduction
In short, this technique can usually “remove” a dimension from the DP state. Before we dive in, let’s look at a related problem first:
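Convex / Concave function evaluation problem
Suppose $f(x)$ is convex, but we cannot evaluate it directly. Instead, we have an algorithm $\mathcal{A}$ that, for a given $p$, computes $M(p)=\min_x \left( f(x)+px \right)$ together with $V(p)$, the largest $x$ attaining this minimum. As $p$ increases, $V(p)$ never increases, so we can binary search for the largest $p$ with $V(p)\ge need$, and then $f(need)=M(p)-p\cdot need$. The concave case (maximizing $f(x)-px$) is symmetric.
Implementation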
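int l, r, ans; // [l, r]: search range of the penalty p (problem-dependent); need: the target x
while(l <= r) {
int mid = l + (r - l) / 2;
// calc() is algorithm A: first is M(mid) = min_x (f(x) + mid * x), second is V(mid), the largest x attaining it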
pair<int, int> res = calc(mid);
if(res.second >= need) ans = res.first - mid * need, l = mid + 1;
else r = mid - 1;
}
cout << ans << '\n';
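As a minimal sketch of the template on a toy input (an assumption for illustration, not taken from any problem below), `calcToy` plays the role of algorithm $\mathcal{A}$ for an explicitly stored convex array `f`, returning the largest minimizer on ties, and `wqs` is the same binary search wrapped in a function. The names `calcToy`, `wqs`, `lo` and `hi` are hypothetical; `[lo, hi]` must contain every slope of `f`.

```cpp
#include <cassert>
#include <utility>
#include <vector>

using ll = long long;

// Hypothetical algorithm A for an explicitly stored convex f:
// returns {M(p), V(p)} with M(p) = min_x (f[x] + p * x) and
// V(p) = the LARGEST x attaining that minimum.
std::pair<ll, ll> calcToy(const std::vector<ll>& f, ll p) {
    ll best = f[0], arg = 0;
    for (ll x = 1; x < (ll)f.size(); ++x) {
        ll val = f[x] + p * x;
        if (val <= best) best = val, arg = x;  // "<=" keeps the larger x on ties
    }
    return {best, arg};
}

// WQS binary search: recovers f[need] using only calls to calcToy.
// [lo, hi] must contain every slope -(f[x] - f[x - 1]) of f.
ll wqs(const std::vector<ll>& f, ll need, ll lo, ll hi) {
    ll l = lo, r = hi, ans = 0;
    while (l <= r) {
        ll mid = l + (r - l) / 2;
        std::pair<ll, ll> res = calcToy(f, mid);
        if (res.second >= need) ans = res.first - mid * need, l = mid + 1;
        else r = mid - 1;
    }
    return ans;
}
```

Because `calcToy` keeps the largest minimizer, the `res.second >= need` test still succeeds on a flat stretch of `f` (for example `f = {10, 6, 3, 1, 0, 0, 1}` has `f[4] == f[5]`), which is why the tie-breaking direction has to match the comparison in the binary search.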
Example problems
BZOJ2654 - Tree
Problem description
You’re given a weighted, undirected and connected graph $G$ with $N$ nodes and $M$ edges. Every edge has a weight $w$ and a color $c$, which is either black or white ($c=0$ means white in the code below). Please calculate the total weight of the minimum spanning tree that contains exactly $K$ white edges.
$1\le N\le 5\times 10^4, N-1\le M\le 10^5, 1\le w\le 100, 0\le c\le 1$
Problem analysis
This problem has nothing to do with DP. It’s just a practice problem for wqs binary search :)
Let $f(x)$ be the total weight of the minimum spanning tree that contains exactly $x$ white edges. Now, let’s show that $f(x)$ is a convex function, i.e. $$ f(x)-f(x-1)\le f(x+1)-f(x), x\in \mathbb{N} $$ Every time we increase the number of white edges by $1$, we choose the smallest white edge that is still valid. This implies that our “net benefit” (the amount of decrease in total weight) from adding the $i^{th}$ white edge decreases as $i$ goes up, which is exactly the inequality above.
So now we need an algorithm $\mathcal{A}$ to calculate the minimum value of $f(x)+px$ for a given $p$. This can be done by reweighting the white edges (adding $p$ to their original weights) and running a Minimum Spanning Tree (MST) algorithm on the result; on equal weights, white edges are taken first so that the number of white edges is maximized. I used Kruskal’s algorithm in my implementation.
Problem solution
Note that the range of binary search is $[-100, 100]$, since the weights satisfy $1\le w\le 100$. The time complexity is $O(M\log M\log C)$ (the edges are sorted in every step), where $C=200$ is the size of the binary search range.
code
#include <bits/stdc++.h>
#define F first
#define S second
using namespace std;
typedef long long ll;
const int N = (int)5e4 + 5;
struct DSU {
int fa[N], sz[N];
void init(int n) {
for(int i = 0 ; i < n ; ++i) fa[i] = i, sz[i] = 1;
}
int find(int x) {
return x == fa[x] ? x : fa[x] = find(fa[x]);
}
bool merge(int x, int y) {
x = find(x), y = find(y);
if(x == y) return false;
if(sz[x] < sz[y]) swap(x, y);
sz[x] += sz[y];
fa[y] = x;
sz[y] = 0;
return true;
}
} dsu;
struct Edge {
int u, v, w, c;
bool operator<(const Edge& rhs) const {
return w < rhs.w || (w == rhs.w && c < rhs.c);
}
};
vector<Edge> edges;
int n, m, need;
pair<int, int> calc(int k) {
for(int i = 0 ; i < m ; ++i) if(!edges[i].c) edges[i].w += k;
sort(edges.begin(), edges.end());
dsu.init(n);
int cost = 0, cnt = 0;
for(int i = 0 ; i < m ; ++i) {
if(dsu.merge(edges[i].u, edges[i].v)) {
cost += edges[i].w;
if(!edges[i].c) cnt++;
}
}
for(int i = 0 ; i < m ; ++i) if(!edges[i].c) edges[i].w -= k;
return {cost, cnt};
}
void init() {
cin >> n >> m >> need;
for(int i = 0 ; i < m ; ++i) {
int u, v, w, c; cin >> u >> v >> w >> c;
edges.push_back({u, v, w, c});
}
}
void solve() {
int l = -100, r = 100, ans = 0;
while(l <= r) {
int mid = l + (r - l) / 2;
pair<int, int> res = calc(mid);
if(res.second >= need) ans = res.first - mid * need, l = mid + 1;
else r = mid - 1;
}
cout << ans << '\n';
}
int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}
Luogu U72600 - Commando EX
Problem description
You’re given a sequence $a_1, a_2, \dots, a_N$, please split it in $K$ parts $[l_1, r_1], [l_2, r_2], \dots, [l_K, r_K]$ such that the following cost is minimized: $$ \sum_{i=1}^{K}\left( \sum_{j=l_i}^{r_i} a_j \right)^2 $$
$2\le K\le N\le 10^5, a_i \le 50$
Problem analysis
Let’s first write down the DP equations:
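- DP state : $dp_{i, j}$ represents the minimum cost of splitting $a_1, a_2, \dots, a_i$ into $j$ parts
- DP transition : $dp_{i, j}=\min_{0\le k\lt i} \left( dp_{k, j - 1} + cost(k + 1, i) \right)$, where $cost(l, r)=\left( \sum_{t=l}^{r} a_t \right)^2=(pre_r-pre_{l-1})^2$ and $pre_i=a_1+a_2+\dots+a_i$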
- Final answer : $dp_{N, K}$
If we calculate it directly, the time complexity is $O(N^2K)$, which is really bad… ($N^2K\approx 10^{15}$) So let’s try to reduce the time complexity!
First, let’s rewrite the transition equation: $$ \begin{align*} dp_{i, j} &= \min_{0\le k\lt i} \left( dp_{k, j - 1} + cost(k + 1, i) \right) \\ &= \min_{0\le k\lt i} \left( dp_{k, j - 1} + (pre_i - pre_k)^2 \right) \\ &= \min_{0\le k\lt i} \left( dp_{k, j - 1} + pre_i^2 + pre_k^2 - 2\cdot pre_i\cdot pre_k \right) \\ &= pre_i^2 + \min_{0\le k\lt i} \left( -2\cdot pre_k \cdot pre_i + dp_{k, j - 1} + pre_k^2 \right) \\ &= pre_i^2 + \min_{0\le k\lt i} \left( m_k\cdot pre_i + b_k \right), m_k=-2\cdot pre_k, b_k=dp_{k, j - 1} + pre_k^2 \end{align*} $$ where the last line is exactly the form that convex hull optimization handles in $O(N\log N)$ per layer. Furthermore, since the queries ($pre_i$) are increasing and the slopes ($-2\cdot pre_k$) are decreasing, each layer can be computed in $O(N)$ with a monotone deque. So now the time complexity is $O(NK)$, which is better but still unacceptable.
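The rewritten transition can be checked with a sketch like the following (hypothetical names `bruteSplit` and `chtSplit`; it assumes every $a_i\ge 1$, so slopes strictly decrease and queries strictly increase): one layer of $dp$ per $j$, each computed in $O(N)$ with a monotone convex hull, compared against the direct $O(N^2K)$ DP. `__int128` is a GCC/Clang extension used only to keep the intersection test overflow-free.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <utility>
#include <vector>

using ll = long long;
const ll INF = LLONG_MAX / 4;

std::vector<ll> prefix(const std::vector<ll>& a) {
    std::vector<ll> pre(a.size() + 1, 0);
    for (size_t i = 0; i < a.size(); ++i) pre[i + 1] = pre[i] + a[i];
    return pre;
}

// Direct O(N^2 K) DP: dp[j][i] = min cost of splitting a_1..a_i into j parts.
ll bruteSplit(const std::vector<ll>& a, int K) {
    int n = a.size();
    std::vector<ll> pre = prefix(a);
    std::vector<std::vector<ll>> dp(K + 1, std::vector<ll>(n + 1, INF));
    dp[0][0] = 0;
    for (int j = 1; j <= K; ++j)
        for (int i = 1; i <= n; ++i)
            for (int k = 0; k < i; ++k)
                if (dp[j - 1][k] < INF) {
                    ll s = pre[i] - pre[k];
                    dp[j][i] = std::min(dp[j][i], dp[j - 1][k] + s * s);
                }
    return dp[K][n];
}

// O(NK): layer j queries lines m_k * x + b_k (m_k = -2 pre_k,
// b_k = dp_{k,j-1} + pre_k^2) at x = pre_i. Assumes a_i >= 1.
ll chtSplit(const std::vector<ll>& a, int K) {
    int n = a.size();
    std::vector<ll> pre = prefix(a);
    std::vector<ll> prev(n + 1, INF), cur(n + 1, INF), M, B;
    prev[0] = 0;
    for (int j = 1; j <= K; ++j) {
        M.clear(), B.clear();
        size_t head = 0;  // lines before `head` were popped from the front
        std::fill(cur.begin(), cur.end(), INF);
        for (int i = 1; i <= n; ++i) {
            int k = i - 1;  // transition point k = i - 1 becomes available
            if (prev[k] < INF) {
                ll m = -2 * pre[k], b = prev[k] + pre[k] * pre[k];
                // drop the last line while the new line makes it useless
                while (M.size() - head >= 2) {
                    size_t s = M.size();
                    __int128 lhs = (__int128)(b - B[s - 2]) * (M[s - 2] - M[s - 1]);
                    __int128 rhs = (__int128)(B[s - 1] - B[s - 2]) * (M[s - 2] - m);
                    if (lhs > rhs) break;
                    M.pop_back(), B.pop_back();
                }
                M.push_back(m), B.push_back(b);
            }
            if (M.size() == head) continue;  // no valid transition yet
            ll x = pre[i];
            // queries increase, so the optimal line only moves forward
            while (M.size() - head >= 2 && M[head + 1] * x + B[head + 1] <= M[head] * x + B[head]) ++head;
            cur[i] = x * x + M[head] * x + B[head];
        }
        std::swap(prev, cur);
    }
    return prev[n];
}
```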
Second, let $f(x)=dp_{N, x}$. It’s not hard to notice that $f(x)$ is convex, so we can use wqs binary search! But we still need an efficient algorithm $\mathcal{A}$ that calculates the minimum value of $f(x)+px$. We can use another DP as $\mathcal{A}$, in which every part costs an extra $p$:
- DP state : $dp^{’}_i$ represents the minimum cost of splitting $a_1, a_2, \dots, a_i$ into any number of parts, where each part costs an extra $p$
- DP transition : $dp^{’}_i=pre_i^2 + p + \min_{0\le k\lt i} \left( m_k\cdot pre_i + b_k \right), m_k=-2\cdot pre_k, b_k=dp^{’}_{k} + pre_k^2$
- Final answer : $dp^{’}_N$
Note that the new DP transition is almost the same as the old one: it has only one dimension and an extra term $p$. This DP can be calculated in $O(N)$ with the same convex hull optimization as above.
Finally, we’ve successfully reduced the time complexity from $O(N^2K)$ to $O(N\log C)$!
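Putting the pieces together, here is a small end-to-end sketch of the wqs idea for this problem. To keep it short, algorithm $\mathcal{A}$ (`penalized`, a hypothetical name) is the plain $O(N^2)$ version of the $dp^{’}$ transition without convex hull optimization, and it breaks ties toward the larger $cnt$, as described in the solution below; `splitWqs` searches $p$ in $[0, pre_N^2]$, and `bruteSplit` is the exact $O(N^2K)$ DP used for comparison.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <utility>
#include <vector>

using ll = long long;

// Hypothetical O(N^2) algorithm A: dp'_i = min_k (dp'_k + (pre_i - pre_k)^2) + p.
// Returns {dp'_N, cnt_N}; among equal values it keeps the larger part count.
std::pair<ll, ll> penalized(const std::vector<ll>& pre, ll p) {
    int n = pre.size() - 1;
    std::vector<ll> dp(n + 1, LLONG_MAX), cnt(n + 1, 0);
    dp[0] = 0;
    for (int i = 1; i <= n; ++i)
        for (int k = 0; k < i; ++k) {
            ll d = pre[i] - pre[k];
            ll val = dp[k] + d * d + p;  // dp[k] is already finite for k < i
            if (val < dp[i] || (val == dp[i] && cnt[k] + 1 > cnt[i]))
                dp[i] = val, cnt[i] = cnt[k] + 1;
        }
    return {dp[n], cnt[n]};
}

// wqs over p in [0, pre_N^2], mirroring solve() in the code below.
ll splitWqs(const std::vector<ll>& a, ll K) {
    std::vector<ll> pre(a.size() + 1, 0);
    for (size_t i = 0; i < a.size(); ++i) pre[i + 1] = pre[i] + a[i];
    ll l = 0, r = pre.back() * pre.back(), ans = -1;
    while (l <= r) {
        ll mid = l + (r - l) / 2;
        std::pair<ll, ll> res = penalized(pre, mid);
        if (res.second >= K) ans = res.first - K * mid, l = mid + 1;
        else r = mid - 1;
    }
    return ans;
}

// Exact O(N^2 K) reference: dp[j][i] over j parts.
ll bruteSplit(const std::vector<ll>& a, int K) {
    int n = a.size();
    std::vector<ll> pre(n + 1, 0);
    for (int i = 0; i < n; ++i) pre[i + 1] = pre[i] + a[i];
    const ll INF = LLONG_MAX / 4;
    std::vector<std::vector<ll>> dp(K + 1, std::vector<ll>(n + 1, INF));
    dp[0][0] = 0;
    for (int j = 1; j <= K; ++j)
        for (int i = 1; i <= n; ++i)
            for (int k = 0; k < i; ++k)
                if (dp[j - 1][k] < INF) {
                    ll s = pre[i] - pre[k];
                    dp[j][i] = std::min(dp[j][i], dp[j - 1][k] + s * s);
                }
    return dp[K][n];
}
```

The array `{2, 2, 2, 2}` is a useful case: $f(2)-f(3)=f(3)-f(4)=8$, so at $p=8$ several part counts tie, and only the "larger $cnt$" rule lets the search stop at $p=8$ for $K=3$.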
Problem solution
Note that I set the range of binary search to $[0, pre_N^2]$. Also note that when calculating the new DP, we need to store the number of parts ($cnt$ array in the code), and when two transition points have the same $dp^{’}$ value, we choose the one with the larger $cnt$ (actually, larger or smaller depends on how you implement the binary search). More details in the code.
code
#include <bits/stdc++.h>
#define F first
#define S second
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;
const int N = (int)1e5 + 5;
const ll inf = (ll)1e15;
ll n, k, a[N], pre[N], dp[N], cnt[N];
struct CHT {
struct Line {
ll m, b, c;
ll operator()(ll x) { return m * x + b; }
} dq[N * 2];
int l, r;
void init() {
dq[0] = {0, 0, 0};
l = 0, r = 1;
}
bool better_ins(Line& L, Line& L1, Line& L2) {
__int128 b1 = (__int128)(L.b - L2.b) * (L2.m - L1.m), b2 = (__int128)(L2.m - L.m) * (L1.b - L2.b); // __int128: these products can exceed 64 bits
return b1 < b2 || (b1 == b2 && L.c > L1.c);
}
bool better_qry(Line& L1, Line& L2, ll x) {
ll b1 = L1(x), b2 = L2(x);
return b1 > b2 || (b1 == b2 && L1.c < L2.c);
}
void insert(Line L) {
while(r - l >= 2 && better_ins(L, dq[r - 1], dq[r - 2])) --r;
dq[r++] = L;
}
pll query(ll x) {
while(r - l >= 2 && better_qry(dq[l], dq[l + 1], x)) ++l;
return {dq[l](x), dq[l].c};
}
} cht;
pll calc(ll x) {
dp[0] = cnt[0] = 0;
cht.init();
for(int i = 1 ; i <= n ; ++i) {
pll qry = cht.query(pre[i]);
dp[i] = pre[i] * pre[i] + x + qry.F;
cnt[i] = qry.S + 1;
cht.insert({-2LL * pre[i], dp[i] + pre[i] * pre[i], cnt[i]});
}
return {cnt[n], dp[n]};
}
void init() {
cin >> n >> k;
for(int i = 1 ; i <= n ; ++i) cin >> a[i], pre[i] = pre[i - 1] + a[i];
}
void solve() {
ll l = 0, r = pre[n] * pre[n], ans = -1;
while(l <= r) {
ll mid = l + (r - l) / 2;
pll res = calc(mid);
if(res.F >= k) ans = res.S - k * mid, l = mid + 1;
else r = mid - 1;
}
cout << ans << '\n';
}
int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}
2018 ACM-ICPC Nanjing Regional pB - Tournament
Problem description
Given $N$ points $a_1, a_2, \dots, a_N$ on a number line, please find $K$ points $s_1, s_2, \dots, s_K$ such that $$ \sum_{i=1}^{N}\left( \min_{1\le j \le K} D(a_i, s_j) \right) $$ is minimized, where $D(a_i, s_j)$ is the distance between them.
Note that you only need to calculate the minimum value, i.e. you don’t need to output the points you choose.
$1\le K\le N\le 3\times 10^5, 0=a_1\lt a_2\lt \dots\lt a_N\le 10^9$
Problem analysis
First, we write down the DP equations:
- DP state : $dp_{i, j}$ represents the minimum value of $\sum_{p=1}^{i}\left( \min_{1\le q \le j} D(a_p, s_q) \right)$
- DP transition : $dp_{i, j}=\min_{0\le k\lt i} \left( dp_{k, j - 1} + cost(k + 1, i) \right)$
- Final answer : $dp_{N, K}$
There are two problems to handle in the DP above. First, we need to know how to calculate $cost(i, j)$. Then, we need to reduce the time complexity.
We know that $cost(i, j)$ represents $\min_s\left(\sum_{p=i}^{j} D(a_p, s)\right)$. It’s easy to observe that the best $s$ is the median of $a_i, a_{i+1}, \dots, a_j$, so $cost(i, j)$ can be calculated in $O(1)$ with prefix sums. Also, one can observe that the best transition points of $dp$ are monotone increasing, i.e. $$ H(i, j)=\mathop{\arg\min}_k \left( dp_{k, j - 1} + cost(k + 1, i) \right) \implies H(i, j) \le H(i + 1, j) $$
which is the 1D/1D convex case of Knuth Optimization. Thus, each layer can be computed with a deque of candidate transition points plus binary search, and the time complexity of the DP above is reduced to $O(KN\log N)$.
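The $O(1)$ median cost can be sketched as follows, assuming $a$ is sorted and 1-indexed with prefix sums `pre`. `medianCost` spells out the two sums around the median position $m=\lfloor (l+r)/2\rfloor$ explicitly (the `cost` function in the code below simplifies the same expression with a parity trick), and `bruteCost` is a hypothetical brute-force check that tries every point as $s$.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

using ll = long long;

// cost(l, r) = min_s sum_{p=l}^{r} |a_p - s| for sorted a (1-indexed, a[0] unused),
// using s = a_m with m = floor((l + r) / 2), a median position.
ll medianCost(const std::vector<ll>& a, const std::vector<ll>& pre, int l, int r) {
    int m = (l + r) / 2;
    ll left = a[m] * (m - l + 1) - (pre[m] - pre[l - 1]);  // sum of a_m - a_p for p <= m
    ll right = (pre[r] - pre[m]) - a[m] * (r - m);         // sum of a_p - a_m for p > m
    return left + right;
}

// Brute force: the optimum of a sum of absolute values is attained at a data point.
ll bruteCost(const std::vector<ll>& a, int l, int r) {
    ll best = -1;
    for (int s = l; s <= r; ++s) {
        ll sum = 0;
        for (int p = l; p <= r; ++p) sum += std::llabs(a[p] - a[s]);
        if (best < 0 || sum < best) best = sum;
    }
    return best;
}
```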
Then, let $f(x)=dp_{N, x}$. Again, $f(x)$ is convex (by intuition or by other methods), so we can use wqs binary search to remove the last dimension of $dp$, and the time complexity becomes $O(N\log N\log C)$.
Problem solution
Note that I added these two lines in order to pass the tight time limit :/
#pragma GCC optimize ("O3,unroll-loops,no-stack-protector")
#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
code
#pragma GCC optimize ("O3,unroll-loops,no-stack-protector")
#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include <bits/stdc++.h>
#define F first
#define S second
using namespace std;
typedef long long ll;
const int N = (int)3e5 + 5;
ll n, k, a[N], pre[N], dp[N], cnt[N];
inline ll cost(ll l, ll r) {
ll m = (l + r) / 2;
return pre[r] - pre[m] - pre[((r - l) & 1) ? m : m - 1] + pre[l - 1];
}
inline pair<ll, ll> calc(ll i, ll j) {
return {dp[i] + cost(i + 1, j), -cnt[i] - 1};
}
pair<ll, ll> calc(ll x) {
struct Node { ll p, l, r; };
deque<Node> dq;
dp[0] = cnt[0] = 0;
dq.push_back({0, 1, n});
for(int i = 1 ; i <= n ; ++i) {
dp[i] = calc(dq.front().p, i).F + x, cnt[i] = cnt[dq.front().p] + 1;
if(dq.front().r == i) dq.pop_front();
else dq.front().l++;
while(!dq.empty() && calc(i, dq.back().l) < calc(dq.back().p, dq.back().l)) dq.pop_back();
if(dq.empty()) dq.push_back({i, i + 1, n});
else {
ll l = dq.back().l, r = dq.back().r;
while(l < r) {
ll mid = r - (r - l) / 2;
if(calc(i, mid) < calc(dq.back().p, mid)) r = mid - 1;
else l = mid;
}
dq.back().r = l;
if(l != n) dq.push_back({i, l + 1, n});
}
}
return {cnt[n], dp[n]};
}
void init() {
cin >> n >> k;
for(int i = 1 ; i <= n ; ++i) cin >> a[i], pre[i] = pre[i - 1] + a[i];
}
void solve() {
ll l = 0, r = (ll)1e15, ans = -1;
while(l <= r) {
ll mid = l + (r - l) / 2;
auto [cc, val] = calc(mid);
if(cc >= k) l = mid + 1, ans = val - mid * k;
else r = mid - 1;
}
cout << ans << '\n';
}
int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}
CF739E - Gosha is hunting
Problem description
You want to catch some Pokemons. You’re given $a$ Poke Balls and $b$ Ultra Balls. There are $N$ Pokemons out there, numbered $1$ to $N$. You know that the probability of catching the $i^{th}$ Pokemon using a Poke Ball is $p_i$, and the probability of catching it using an Ultra Ball is $u_i$. Also, you can throw at most one Ball of each type at any Pokemon.
Now, you can choose at most $a$ Pokemons to throw Poke Balls at and at most $b$ Pokemons to throw Ultra Balls at. You can throw both types of Balls at the same Pokemon. The Pokemon is caught iff it is caught by any of these Balls, and the outcome of a throw doesn’t depend on the other throws. Please maximize the expected number of Pokemons you can catch if you choose optimally.
$2\le N\le 2000, 0\le a, b\le N, 0\le p_i, u_i\le 1$
Problem analysis
Let’s write down the DP equations first:
- DP state : $dp_{i, j, k}$ represents the maximum expected number of Pokemons you can catch considering Pokemons from $1$ to $i$ using at most $j$ Poke Balls and at most $k$ Ultra Balls
- DP transition : $$ dp_{i, j, k}=\max \begin{cases} dp_{i-1, j, k}&, \text{don’t spend any Balls on Pokemon } i \\ dp_{i-1, j-1, k} + p_i&, \text{spend one Poke Ball on Pokemon } i \\ dp_{i-1, j, k-1} + u_i&, \text{spend one Ultra Ball on Pokemon } i \\ dp_{i-1, j-1, k-1} + p_i + u_i - p_iu_i&, \text{spend one Poke Ball and one Ultra Ball on Pokemon } i \end{cases} $$
- Final answer : $dp_{N, a, b}$
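For small inputs, the DP above can be run directly as a reference. The sketch below (hypothetical names `directDP` and `bruteForce`) keeps only the $(j, k)$ dimensions in a rolling table, iterating both downward like a 0/1 knapsack so that every transition reads values from Pokemon $i-1$, and checks the result against enumerating all $4^N$ ways to throw the balls. It runs in $O(N\cdot a\cdot b)$, far too slow for the real limits, which is why wqs is needed.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Direct O(N * A * B) version of the DP with "at most j Poke / k Ultra Balls".
// Iterating j and k downward means dp[j-1][*] and dp[j][k-1] still hold
// the values for the previous Pokemon when dp[j][k] is updated.
double directDP(const std::vector<double>& p, const std::vector<double>& u, int A, int B) {
    std::vector<std::vector<double>> dp(A + 1, std::vector<double>(B + 1, 0.0));
    for (size_t i = 0; i < p.size(); ++i)
        for (int j = A; j >= 0; --j)
            for (int k = B; k >= 0; --k) {
                double best = dp[j][k];                                  // no ball
                if (j > 0) best = std::max(best, dp[j - 1][k] + p[i]);   // Poke Ball
                if (k > 0) best = std::max(best, dp[j][k - 1] + u[i]);   // Ultra Ball
                if (j > 0 && k > 0)                                      // both balls
                    best = std::max(best, dp[j - 1][k - 1] + p[i] + u[i] - p[i] * u[i]);
                dp[j][k] = best;
            }
    return dp[A][B];
}

// Brute force over all 4^N choices: bit 0 = Poke Ball, bit 1 = Ultra Ball.
double bruteForce(const std::vector<double>& p, const std::vector<double>& u, int A, int B) {
    int n = p.size(), total = 1;
    for (int i = 0; i < n; ++i) total *= 4;
    double best = 0;
    for (int code = 0; code < total; ++code) {
        int c = code, usedA = 0, usedB = 0;
        double val = 0;
        for (int i = 0; i < n; ++i, c /= 4) {
            bool P = c % 4 & 1, U = c % 4 & 2;
            usedA += P, usedB += U;
            val += 1 - (1 - (P ? p[i] : 0.0)) * (1 - (U ? u[i] : 0.0));
        }
        if (usedA <= A && usedB <= B) best = std::max(best, val);
    }
    return best;
}
```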
Let $f(x)=dp_{N, a, x}$. Note that $f(x)$ is a concave function, i.e. $$ f(x)-f(x-1)\ge f(x+1)-f(x), x\in \mathbb{N} $$ so we can use wqs binary search to remove the last dimension and achieve a time complexity of $O(N^2\log C)$.
The time complexity above is enough to pass this problem. However, we can do better. Let $g(x, y)=dp_{N, x, y}$. It can be observed that $g(x, y)$ is also concave, so $dp_{N, a, b}$ can be calculated in $O(N\log^2 C)$ time with two nested wqs binary searches (so $N$ could actually be as large as $10^5$ in this problem).
Problem solution
The time complexity of first solution is $O(N^2\log C)$, where the range of binary search is $[0, 1]$.
code
#include <bits/stdc++.h>
#define F first
#define S second
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double ld;
const int N = 2000 + 5;
const ld inf = 1e9;
const ld eps = 1e-8;
int n, a, b, cnt[N][N];
ld p[N], u[N], dp[N][N];
inline int dcmp(ld x, ld y) {
if(abs(x - y) < eps) return 0;
return x < y ? -1 : 1;
}
void upd(int i, int j, ld val, int cc) {
if(dcmp(dp[i][j], val) < 0 || (dcmp(dp[i][j], val) == 0 && cnt[i][j] < cc)) {
dp[i][j] = val;
cnt[i][j] = cc;
}
}
pair<int, ld> calc(ld x) {
fill(dp[0], dp[0] + N, -inf);
dp[0][0] = cnt[0][0] = 0;
for(int i = 1 ; i <= n ; ++i) {
for(int j = 0 ; j <= a ; ++j) {
dp[i][j] = dp[i - 1][j], cnt[i][j] = cnt[i - 1][j];
upd(i, j, dp[i - 1][j] + u[i] - x, cnt[i - 1][j] + 1);
if(j > 0) {
upd(i, j, dp[i - 1][j - 1] + p[i], cnt[i - 1][j - 1]);
upd(i, j, dp[i - 1][j - 1] + p[i] + u[i] - p[i] * u[i] - x, cnt[i - 1][j - 1] + 1);
}
}
}
return {cnt[n][a], dp[n][a]};
}
void init() {
cin >> n >> a >> b;
for(int i = 1 ; i <= n ; ++i) cin >> p[i];
for(int i = 1 ; i <= n ; ++i) cin >> u[i];
}
void solve() {
ld L = 0, R = 1, ans = -1;
for(int i = 0 ; i < 100 ; ++i) {
ld M = (L + R) / 2;
pair<int, ld> res = calc(M);
if(res.F >= b) ans = res.S + M * b, L = M;
else R = M;
}
cout << fixed << setprecision(10) << ans << '\n';
}
int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}
The time complexity of the solution below is $O(N\log ^2C)$.
code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
const int N = (int)1e5 + 5;
const ld eps = 1e-15;
inline int dcmp(ld x, ld y) {
if(abs(x - y) < eps) return 0;
return x < y ? -1 : 1;
}
int n, a, b;
ld p[N], u[N];
struct Node {
ld val;
int cp, cu;
bool operator<(const Node& rhs) const {
return dcmp(val, rhs.val) < 0 || (dcmp(val, rhs.val) == 0 && (cp < rhs.cp || (cp == rhs.cp && cu < rhs.cu)));
}
} dp[N];
inline void upd(int i, Node o) {
if(dp[i] < o) dp[i] = o;
}
Node calc(ld x, ld y) {
dp[0] = {0, 0, 0};
for(int i = 1 ; i <= n ; ++i) {
dp[i] = dp[i - 1];
upd(i, {dp[i - 1].val + p[i] - x, dp[i - 1].cp + 1, dp[i - 1].cu});
upd(i, {dp[i - 1].val + u[i] - y, dp[i - 1].cp, dp[i - 1].cu + 1});
upd(i, {dp[i - 1].val + p[i] + u[i] - p[i] * u[i] - x - y, dp[i - 1].cp + 1, dp[i - 1].cu + 1});
}
return dp[n];
}
void init() {
cin >> n >> a >> b;
for(int i = 1 ; i <= n ; ++i) cin >> p[i];
for(int i = 1 ; i <= n ; ++i) cin >> u[i];
}
void solve() {
ld l = 0, r = 1, ans;
while(r - l > eps) {
ld mid = (l + r) / 2;
Node res;
ld l2 = 0, r2 = 1;
while(r2 - l2 > eps) {
ld mid2 = (l2 + r2) / 2;
res = calc(mid, mid2);
if(res.cu >= b) l2 = mid2;
else r2 = mid2;
}
res = calc(mid, l2);
if(res.cp >= a) l = mid, ans = res.val + mid * a + l2 * b;
else r = mid;
}
cout << fixed << setprecision(15) << ans << '\n';
}
int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}
BZOJ5252 - Link-Cut Tree
Problem description
Given a weighted tree with $N$ nodes, please choose $K+1$ vertex-disjoint paths such that the total edge weight of those paths is maximized (a single node also counts as a path).
$0\le K\lt N\le 3\times 10^5$
Problem analysis
Let’s try to solve this problem with DP on tree:
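- DP state : $dp_{u, i, k}, k=0, 1, 2$ represents the maximum total weight of choosing exactly $i$ disjoint paths in the subtree of $u$, where $k$ describes $u$ (not considering the edge connecting the father node): $k=0$ means $u$ is on no path, $k=1$ means $u$ is an endpoint of a path that may still be extended to the father, and $k=2$ means the path containing $u$ is closed ($u$ has degree $2$, or $u$ alone is a path). A path is counted when it is started, so initially $dp_{u, 0, 0}=dp_{u, 1, 2}=0$ and all other values are $-\infty$.
- DP transition : suppose that we are merging child $v$, connected to $u$ by an edge of weight $w$, and let $mx_k=\max \left\{ dp_{v, k, 0}, dp_{v, k, 1}, dp_{v, k, 2} \right\}$. Writing $dp_{u, \cdot, \cdot}$ on the right-hand side for the values before $v$ is merged and maximizing over all valid $k$: $$ \begin{cases} dp_{u, i, 0} = \max_{k} \left( dp_{u, i-k, 0} + mx_k \right) \\ dp_{u, i, 1} = \max_{k} \max \left( dp_{u, i-k, 1} + mx_k, dp_{u, i-k-1, 0} + dp_{v, k, 0} + w, dp_{u, i-k, 0} + dp_{v, k, 1} + w \right) \\ dp_{u, i, 2} = \max_{k} \max \left( dp_{u, i-k, 2} + mx_k, dp_{u, i-k, 1} + dp_{v, k, 0} + w, dp_{u, i-k+1, 1} + dp_{v, k, 1} + w \right) \end{cases} $$
- Final answer : $\max_{0\le k\le 2} \left( dp_{1, K+1, k} \right)$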
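To make the transition concrete, here is a sketch (hypothetical names `solveTree`, `bestPaths` and `brutePaths`; vertices are 0-indexed and the tree is rooted at $0$ instead of $1$) that keeps the number of paths $i$ as an explicit dimension and merges children exactly as in the equations above, so it runs in roughly $O(N^2)$ time without wqs. `brutePaths` checks it on tiny trees: any edge subset in which every vertex has degree at most $2$ splits into paths, and the remaining isolated vertices can optionally be added as single-vertex paths.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <utility>
#include <vector>

using ll = long long;
using Graph = std::vector<std::vector<std::pair<int, ll>>>;  // (neighbor, weight)
using Table = std::array<std::vector<ll>, 3>;                // Table[state][paths]
const ll NEG = LLONG_MIN / 4;                                // "minus infinity"

void relax(std::vector<ll>& row, int i, ll val) {
    if (i >= 0 && i < (int)row.size() && val > row[i]) row[i] = val;
}

// state 0: u on no path; 1: u is an open endpoint; 2: u's path is closed.
Table solveTree(const Graph& g, int u, int parent) {
    Table A = {std::vector<ll>{0, NEG}, std::vector<ll>{NEG, NEG}, std::vector<ll>{NEG, 0}};
    for (const auto& [v, w] : g[u]) {
        if (v == parent) continue;
        Table V = solveTree(g, v, u);
        int su = A[0].size() - 1, sv = V[0].size() - 1;
        Table N;
        for (auto& row : N) row.assign(su + sv + 1, NEG);
        for (int a = 0; a <= su; ++a)
            for (int b = 0; b <= sv; ++b) {
                ll mx = std::max({V[0][b], V[1][b], V[2][b]});  // edge u-v unused
                for (int s = 0; s < 3; ++s)
                    if (A[s][a] > NEG && mx > NEG) relax(N[s], a + b, A[s][a] + mx);
                if (A[0][a] > NEG && V[0][b] > NEG) relax(N[1], a + b + 1, A[0][a] + V[0][b] + w);  // new path u-v
                if (A[0][a] > NEG && V[1][b] > NEG) relax(N[1], a + b, A[0][a] + V[1][b] + w);      // extend v's path
                if (A[1][a] > NEG && V[0][b] > NEG) relax(N[2], a + b, A[1][a] + V[0][b] + w);      // u's path ends at v
                if (A[1][a] > NEG && V[1][b] > NEG) relax(N[2], a + b - 1, A[1][a] + V[1][b] + w);  // join two paths
            }
        A = N;
    }
    return A;
}

// Best total weight of exactly t vertex-disjoint paths (t = K + 1 in the problem).
ll bestPaths(const Graph& g, int t) {
    Table R = solveTree(g, 0, -1);
    if (t < 0 || t >= (int)R[0].size()) return NEG;
    return std::max({R[0][t], R[1][t], R[2][t]});
}

// Brute force over edge subsets where every vertex has degree <= 2.
ll brutePaths(int n, const std::vector<std::array<ll, 3>>& edges, int t) {
    int m = edges.size();
    ll best = NEG;
    for (int mask = 0; mask < (1 << m); ++mask) {
        std::vector<int> deg(n, 0);
        ll w = 0;
        int used = 0;
        for (int i = 0; i < m; ++i)
            if (mask >> i & 1) {
                ++deg[edges[i][0]], ++deg[edges[i][1]];
                w += edges[i][2], ++used;
            }
        if (*std::max_element(deg.begin(), deg.end()) > 2) continue;
        int covered = 0;
        for (int d : deg) covered += (d > 0);
        int withEdges = covered - used;  // components of a forest = vertices - edges
        int isolated = n - covered;
        if (withEdges <= t && t <= withEdges + isolated) best = std::max(best, w);
    }
    return best;
}
```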
Obviously, the time complexity is $O(NK)$. Then, let $f(x, k)=dp_{1, x, k}, k=0,1,2$. We can observe that all three of them $f(x, 0), f(x, 1), f(x, 2)$ are concave. So we can reduce time complexity to $O(N\log C)$ using wqs binary search.
Problem solution
The transition part above may not be clear, so please kindly refer to the code below. Also, this problem cannot be submitted on the judge; thus, the testdata and a simple judge script are offered here.
code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = (int)3e5 + 5;
const ll inf = (ll)1e15;
struct Node {
ll v, cnt;
Node operator+(const Node& rhs) const {
return {v + rhs.v, cnt + rhs.cnt};
}
bool operator<(const Node& rhs) const {
return v < rhs.v || (v == rhs.v && cnt < rhs.cnt);
}
};
int n, k;
vector<pair<int, int> > G[N];
Node dp[3][N];
Node Max(int u) {
return max(dp[0][u], max(dp[1][u], dp[2][u]));
}
void dfs(int u, int p, ll x) {
dp[0][u] = {0, 0}, dp[1][u] = {-inf, 0}, dp[2][u] = {-x, 1};
for(auto [v, w] : G[u]) if(v != p) {
dfs(v, u, x);
Node tmp = Max(v);
dp[2][u] = max(dp[2][u] + tmp, dp[1][u] + max(dp[0][v] + (Node){w, 0}, dp[1][v] + (Node){w + x, -1}));
dp[1][u] = max(dp[1][u] + tmp, dp[0][u] + max(dp[0][v] + (Node){w - x, 1}, dp[1][v] + (Node){w, 0}));
dp[0][u] = dp[0][u] + tmp;
}
}
void init() {
cin >> n >> k; k++;
for(int i = 1 ; i < n ; ++i) {
int u, v, w; cin >> u >> v >> w;
G[u].push_back({v, w});
G[v].push_back({u, w});
}
}
void solve() {
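ll l = -(ll)1e12, r = (ll)1e12, ans = 0; // assumes |w| <= 1e6 (original statement): every slope fits, and n * |mid| stays far from 64-bit overflow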
while(l <= r) {
ll mid = l + (r - l) / 2;
dfs(1, 0, mid);
Node res = Max(1);
if(res.cnt >= k) l = mid + 1, ans = res.v + mid * k;
else r = mid - 1;
}
cout << ans << '\n';
}
int main() {
ios_base::sync_with_stdio(0), cin.tie(0);
init();
solve();
}