A Simple Introduction to SoS (Sum over Subsets) Dynamic Programming
tags:icpc
algorithm
dp
sum-over-subset
under-construction
Outline
Introduction
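SoS DP is a technique to solve the following problem: Given $n=2^m$ integers $a_0, a_1, \dots, a_{n-1}$, indexed by $m$-bit masks, we want to calculate $$ f(mask) = \sum_{i\subseteq mask} a_i $$ for $mask=0, 1, \dots, n-1$. Here $i \subseteq mask$ means that every bit set in $i$ is also set in $mask$ (i.e. `(i & mask) == i`). For example, for $n=2^2$ and $a=\{1, 2, 3, 4\}$ (so $a_{00}=1, a_{01}=2, a_{10}=3, a_{11}=4$), $$ f(10) = a_{00}+a_{10} = 1 + 3 = 4 $$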
Implementation
The brute-force calculation costs $O(4^m)$ if, for every mask, we test all $2^m$ candidates, or $O(3^m)$ if we enumerate only the actual submasks of each mask. Further optimization requires some observations. Let $S(mask)$ be the set of all submasks of $mask$, i.e. $$ S(mask) = \{submask \mid submask \subseteq mask \} $$ We then build $S(mask)$ up one bit at a time. Define $S(mask, i)$ as the set of submasks of $mask$ that differ from $mask$ ONLY in bits $0, 1, \dots, i$ (0-based), with the base case $S(mask, -1) = \{mask\}$. More formally, $$ S(mask, i) = \{submask \mid submask\subseteq mask \land mask \oplus submask < 2^{i+1} \} $$ For example, $S(110)=\{000, 010, 100, 110\}$ and $S(110, 1)=\{100, 110\}$. Under this definition, $S(mask)=S(mask, m-1)$. Furthermore, we have the following relationship: $$ S(mask, i) = \begin{cases} S(mask, i-1) &, \text{if bit $i$ of $mask$ is } 0 \\ S(mask, i-1) \cup S(mask \oplus 2^i, i-1) &, \text{if bit $i$ of $mask$ is } 1 \\ \end{cases} $$ In the second case the two sets are disjoint: members of the first keep bit $i$ set, and members of the second have it cleared. Therefore sums over them simply add. For example, $S(110, 1)=S(110, 0) \cup S(100, 0) = \{110\} \cup \{100\}$.
With this relationship, we can reduce the complexity to $O(m2^m)$!
// Top-down (memoized) SoS.
// f(mask, k) = sum of a[sub] over S(mask, k), i.e. over every submask `sub`
// of `mask` that differs from `mask` only in bits 0..k; f(mask, -1) = a[mask].
// The memo must be keyed by BOTH mask and k: f(mask, k) and f(mask, k - 1)
// are different values, so caching on `mask` alone returns results from the
// wrong level (e.g. a = {1,2,3,4} would give f(11, 1) = 8 instead of 7).
// Expects dp[1 << m][m] with every entry initialized to -1 before the first
// call (-1 is the "not computed" sentinel, so no true sum may equal -1).
int f(int mask, int k) {
    if (k == -1) return a[mask];             // S(mask, -1) = {mask}
    int &memo = dp[mask][k];
    if (memo != -1) return memo;
    memo = f(mask, k - 1);                   // submasks that keep bit k as in mask
    if (mask & (1 << k))
        memo += f(mask ^ (1 << k), k - 1);   // submasks with bit k cleared
    return memo;
}
for (int i = 0; i < (1<<m); i++) f(i, m-1);  // afterwards dp[i][m-1] holds f(i, m-1)
Or calculate it iteratively (faster, with no recursion and only one array of $2^m$ values):
// Bottom-up, in place: after the pass for bit i, f[mask] is the sum over S(mask, i).
for (int i = 0; i < (1<<m); i++) f[i] = a[i];
for (int i = 0; i < m; i++) {
    for (int mask = 0; mask < (1<<m); mask++) {
        // mask ^ (1 << i) has bit i cleared, so it is not written in this pass
        // and still holds its value for S(., i - 1).
        if (mask & (1 << i))
            f[mask] += f[mask ^ (1 << i)];   // accumulate into f[mask], not f[i]
    }
}
Example problems
CF165E - Compatible Numbers
Problem description
Problem analysis
Problem solution
code
// CF165E - Compatible Numbers.
// For every a[i], print some input value that shares no set bit with a[i],
// or -1 if none exists.
// f[mask] = largest input value that is a submask of mask (-1 if none),
// computed with a "max over subsets" SoS. A value y is compatible with a[i]
// exactly when y is a submask of the complement of a[i] within M bits.
// NOTE(review): assumes every a[i] < 2^M; confirm against the problem bounds.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = (int)1e6 + 5;
const int M = 22;
int n, a[N], f[1 << M], e[1 << M];

// Reads n and the values, marking each value that is present in e[].
void init() {
    cin >> n;
    for (int idx = 0; idx < n; ++idx) {
        int &v = a[idx];
        cin >> v;
        e[v] = 1;
    }
}

// Builds f[] and prints one answer per input value, space-separated.
void solve() {
    const int full = (1 << M) - 1;
    for (int msk = 0; msk <= full; ++msk) f[msk] = e[msk] ? msk : -1;
    for (int bit = 0; bit < M; ++bit) {
        const int b = 1 << bit;
        for (int msk = 0; msk <= full; ++msk) {
            if (!(msk & b)) continue;
            f[msk] = max(f[msk], f[msk ^ b]);
        }
    }
    for (int idx = 0; idx < n; ++idx) {
        cout << f[full ^ a[idx]];
        cout << (idx + 1 == n ? '\n' : ' ');
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    init();
    solve();
}
CF383E - Vowels
Problem description
Problem analysis
Problem solution
code
// CF383E - Vowels.
// Each word is reduced to the bitmask of its three letters. After a subset-sum
// SoS, f[msk] is the number of words whose letters all lie inside msk, so
// n - f[msk] words contain a letter outside msk (i.e. a "vowel" when the
// vowel set is the complement of msk). Complementing is a bijection on all
// masks, so XOR-ing over msk equals XOR-ing over every vowel set.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = (int)1e4 + 5;
const int M = 24;
int n, f[1 << M];

// Reads the words and counts how many produce each letter mask.
void init() {
    cin >> n;
    for (int w = 0; w < n; ++w) {
        string s;
        cin >> s;
        int msk = 0;
        for (int p = 0; p < 3; ++p) msk |= 1 << (s[p] - 'a');
        ++f[msk];
    }
}

// Runs the subset-sum SoS and prints the XOR of squared answers.
void solve() {
    for (int bit = 0; bit < M; ++bit) {
        const int b = 1 << bit;
        for (int msk = 0; msk < (1 << M); ++msk) {
            if (!(msk & b)) continue;
            f[msk] += f[msk ^ b];
        }
    }
    int ans = 0;
    for (int msk = 0; msk < (1 << M); ++msk) {
        const int hit = n - f[msk];   // words with at least one letter outside msk
        ans ^= hit * hit;
    }
    cout << ans << '\n';
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    init();
    solve();
}
CF449D - Jzzhu and Numbers
Problem description
Problem analysis
Problem solution
code
// CF449D - Jzzhu and Numbers.
// Counts non-empty index subsets whose bitwise AND is 0, modulo M.
// After a superset-sum SoS, f[mask] = number of inputs containing every bit
// of mask, so 2^f[mask] - 1 non-empty groups have an AND that contains mask.
// Inclusion-exclusion over mask with sign (-1)^popcount(mask) leaves exactly
// the groups whose AND is 0.
// NOTE(review): assumes every a[i] < 2^B; confirm against the problem bounds.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = (int)1e6 + 5;
const int M = (int)1e9 + 7;
const int B = 20;
int n, a[N], f[1 << B], pw[N];

// Reads the input and tabulates pw[k] = 2^k mod M for k = 0..n.
void init() {
    cin >> n;
    pw[0] = 1;
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
        pw[i + 1] = pw[i] * 2 % M;
    }
}

// Superset SoS followed by the signed sum described above.
void solve() {
    for (int i = 0; i < n; ++i) ++f[a[i]];
    for (int bit = 0; bit < B; ++bit) {
        const int b = 1 << bit;
        for (int msk = 0; msk < (1 << B); ++msk) {
            if (msk & b) continue;
            f[msk] += f[msk | b];   // bit is clear, so msk | b == msk ^ b
        }
    }
    int ans = 0;
    for (int msk = 0; msk < (1 << B); ++msk) {
        const int groups = pw[f[msk]] - 1;
        const int term = (__builtin_popcount(msk) & 1) ? -groups : groups;
        ans = ((ll)ans + term + M) % M;
    }
    cout << ans << '\n';
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    init();
    solve();
}
Codechef - STR_FUNC
Problem description
Problem analysis
Problem solution
code
// Codechef STR_FUNC.
// Visits i = 0..n-1 in increasing order and maintains two online SoS tables
// (a row is final once its index has been visited), all modulo M:
//   f[i] = a[i]^2 + (sum of g[j] over proper submasks j of i)^2
//   g[i] = sum of f[j]^2 over all submasks j of i (including i itself)
// and prints the sum over i of i * f[i] * g[i].
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = (int)6e5 + 5;
const int M = (int)1e9 + 7;
const int B = 21;   // a row covers bits 0..B-2; enough because every i < N < 2^(B-1)
int n, a[N], f[N], g[N], dpf[N][B], dpg[N][B];

inline ll sqr(int x) {
    const ll v = x;
    return v * v % M;
}

// Rebuilds row i of an online SoS table from dp[i][0]:
// dp[i][j + 1] extends dp[i][j] with the submasks of i that clear bit j.
// Row i ^ (1 << j) is smaller than i and therefore already final.
static void sosRow(int (*dp)[B], int i) {
    for (int j = 0; j + 1 < B; ++j) {
        int v = dp[i][j];
        if ((i >> j) & 1) v = ((ll)v + dp[i ^ (1 << j)][j]) % M;
        dp[i][j + 1] = v;
    }
}

void init() {
    cin >> n;
    for (int i = 0; i < n; ++i) cin >> a[i];
}

void solve() {
    int ans = 0;
    for (int i = 0; i < n; ++i) {
        // dpf[i][0] is still 0 here, so the top of the row sums g over the
        // proper submasks of i only.
        sosRow(dpf, i);
        f[i] = (sqr(a[i]) + sqr(dpf[i][B - 1])) % M;
        dpg[i][0] = sqr(f[i]);
        sosRow(dpg, i);
        g[i] = dpg[i][B - 1];
        // Now that g[i] is known, rebuild the row so later indices see
        // submasks including i itself.
        dpf[i][0] = g[i];
        sosRow(dpf, i);
        ans = ((ll)ans + (ll)i * f[i] % M * g[i] % M) % M;
    }
    cout << ans << '\n';
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    init();
    solve();
}
CF800D - Varying Kibibits
Problem description
Problem analysis
Problem solution
code
// CF800D - Varying Kibibits.
// Works on 6-digit decimal "masks": y dominates x when every decimal digit of
// y is >= the matching digit of x.
//  1. A digit-wise suffix-sum SoS gathers, for every x, the count, sum and sum
//     of squares of all inputs dominating x.
//  2. For a set S of k such inputs, summing (sum of T)^2 over the non-empty
//     subsets T of S gives sum2 when k == 1, and
//     2^(k-2) * (sum of x^2 + (sum of x)^2) when k > 1.
//  3. The inverse transform (a digit-wise difference, equivalent to
//     inclusion-exclusion over which digits get bumped) keeps only the
//     subsets whose digit-wise minimum is exactly x.
//  4. Prints the XOR over x of x * ans[x].
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = (int)1e6 + 5;
const int M = (int)1e9 + 7;
const int B = 6;
const int ten[] = {1, 10, 100, 1000, 10000, 100000, 1000000};
int n, cnt[N], sum[N], sum2[N], two[N], ans[N];

// Modular addition for operands in (-M, M); result is in [0, M).
inline int add(int x, int y) {
    ll z = (ll)x + y;
    if (z >= M) z -= M;
    else if (z < 0) z += M;
    return (int)z;
}

inline int mul(int x, int y) {
    return static_cast<int>((ll)x * y % M);
}

// Reads the inputs, tabulates two[k] = 2^k mod M, and buckets count, sum and
// sum of squares by value.
void init() {
    cin >> n;
    two[0] = 1;
    for (int i = 0; i < n; ++i) {
        two[i + 1] = add(two[i], two[i]);
        int x;
        cin >> x;
        ++cnt[x];
        sum[x] = add(sum[x], x);
        sum2[x] = add(sum2[x], mul(x, x));
    }
}

// Step 1: descending i, so i + step already includes everything dominating it.
static void accumulateDominating() {
    for (int d = 0; d < B; ++d) {
        const int step = ten[d];
        for (int i = ten[B] - 1; i >= 0; --i) {
            if (i / step % 10 == 9) continue;
            const int up = i + step;
            cnt[i] += cnt[up];
            sum[i] = add(sum[i], sum[up]);
            sum2[i] = add(sum2[i], sum2[up]);
        }
    }
}

void solve() {
    accumulateDominating();
    // Step 2: total over non-empty subsets of the dominating inputs.
    for (int i = 0; i < ten[B]; ++i) {
        if (cnt[i] == 1) {
            ans[i] = sum2[i];
        } else if (cnt[i] > 1) {
            ans[i] = mul(two[cnt[i] - 2], add(sum2[i], mul(sum[i], sum[i])));
        }
    }
    // Step 3: ascending i, so ans[i + step] still holds its pre-pass value.
    for (int d = 0; d < B; ++d) {
        const int step = ten[d];
        for (int i = 0; i < ten[B] - 1; ++i) {
            if (i / step % 10 == 9) continue;
            ans[i] = add(ans[i], -ans[i + step]);
        }
    }
    // Step 4.
    ll res = 0;
    for (int i = 0; i < ten[B]; ++i) res ^= (ll)i * ans[i];
    cout << res << '\n';
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    init();
    solve();
}
More problems
- Codechef - COVERING
- COCI 2011/2012 Problem KOSARE
- Hackerrank - Vim War
- Hackerrank - Subset
- Csacademy - Good Subpermutation