Outline

Introduction

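SoS (Sum over Subsets) DP is a technique to solve the following problem: given $n=2^m$ integers $a_0, a_1, \dots, a_{n-1}$, indexed by $m$-bit masks, we want to calculate $$ f(mask) = \sum_{i\subseteq mask} a_i $$ for every $mask=0, 1, \dots, n-1$. Here $i \subseteq mask$ means every bit set in $i$ is also set in $mask$. For example, for $n=2^2$ and $a=\{1, 2, 3, 4\}$ (i.e. $a[00]=1, a[01]=2, a[10]=3, a[11]=4$), $$ f(10) = a[00]+a[10] = 1 + 3 = 4 $$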

Implementation

A brute-force calculation takes $O(4^m)$ time if it checks every pair $(i, mask)$, or $O(3^m)$ if it enumerates only the submasks of each mask. Going further requires some observations. Let $S(mask)$ be the set of all submasks of $mask$, i.e. $$ S(mask) = \{submask \mid submask \subseteq mask \} $$ We then split $S(mask)$ into smaller parts. Define $S(mask, i)$ as the set of submasks of $mask$ that differ from $mask$ ONLY in the lowest $i+1$ bits, i.e. bits $0, 1, \dots, i$ (0-based). More formally, $$ S(mask, i) = \{submask \mid submask\subseteq mask \land mask \oplus submask < 2^{i+1} \} $$ For example, $S(110)=\{000, 010, 100, 110\}$ and $S(110, 1)=\{100, 110\}$. Under this definition, $S(mask)=S(mask, m-1)$ and $S(mask, -1)=\{mask\}$. Furthermore, we have the following relationship: $$ S(mask, i) = \begin{cases} S(mask, i-1) &, \text{bit $i$ of $mask$ is 0} \\ S(mask, i-1) \cup S(mask \oplus 2^i, i-1) &, \text{bit $i$ of $mask$ is 1} \\ \end{cases} $$ For example, $S(110, 1)=S(110, 0) \cup S(100, 0) = \{110\} \cup \{100\}$.

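With this relationship, each of the $m \cdot 2^m$ partial sums $\sum_{s \in S(mask, i)} a_s$ is obtained in $O(1)$ from partial sums for $i-1$. This reduces the complexity to $O(m2^m)$!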

// f(mask, k) = sum of a[sub] over every sub in S(mask, k), i.e. over the
// submasks of `mask` that differ from `mask` only in bits 0..k.
// f(mask, m - 1) is the final answer for `mask`; f(mask, -1) = a[mask].
//
// The memo MUST be keyed on both mask and k: f(mask, k) has a different value
// for every k, so caching by mask alone would hand back the full sum of a
// submask where only a partial sum S(submask, k - 1) is wanted.
//
// Requires (declared elsewhere): int m; int a[1 << m]; int dp[1 << m][m] with
// every dp entry initialised to -1 (so -1 must not be a legitimate partial sum).
int f(int mask, int k) {
  if (k == -1) return a[mask];          // S(mask, -1) = {mask}; checked first so dp is never indexed with k = -1
  int &memo = dp[mask][k];
  if (memo != -1) return memo;
  memo = f(mask, k - 1);                // submasks that keep bit k exactly as in mask
  if (mask & (1 << k))                  // bit k is set: also count the submasks with bit k cleared
    memo += f(mask ^ (1 << k), k - 1);
  return memo;
}
// Driver: afterwards f(i, m - 1) == dp[i][m - 1] is the answer for mask i.
for (int i = 0; i < (1<<m); i++) f(i, m-1);

or calculate it iteratively (faster)

// Iterative SoS. After round i, f[mask] = sum of a[sub] over sub in S(mask, i).
// Visiting masks in increasing order within a round is safe: f[mask ^ (1 << i)]
// has bit i clear, so it is never modified during round i.
for (int i = 0; i < (1<<m); i++) f[i] = a[i];
for (int i = 0; i < m; i++) {
  for (int mask = 0; mask < (1<<m); mask++) {
    if (mask & (1 << i))
      f[mask] += f[mask ^ (1 << i)];   // add the submasks that have bit i cleared
  }
}

Example problems

CF165E - Compatible Numbers

Problem description

Problem analysis

Problem solution

code
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Limits: room for up to 1e6 values; every value is assumed to fit in M = 22 bits.
const int N = 1000000 + 5;   // capacity of the input array
const int M = 22;            // bit width used for masks

int n;          // number of input values
int a[N];       // input values, in input order
int f[1 << M];  // f[mask]: an input value that is a submask of mask, or -1 (filled by solve)
int e[1 << M];  // e[v] == 1 iff value v was read

// Reads n and the n values, marking each value as present in e.
// NOTE(review): assumes every value is in [0, 2^M) so it indexes e safely.
void init() {
  cin >> n;
  int idx = 0;
  while (idx < n) {
    cin >> a[idx];
    e[a[idx]] = 1;
    ++idx;
  }
}
// For every a[i], prints an input value v with (v & a[i]) == 0, or -1 if none exists.
void solve() {
  const int full = (1 << M) - 1;
  // Seed: a mask is its own witness exactly when it was read.
  for (int msk = 0; msk <= full; ++msk) f[msk] = e[msk] ? msk : -1;
  // SoS with max: f[msk] becomes the largest input value that is a submask of msk, or -1.
  for (int bit = 0; bit < M; ++bit)
    for (int msk = 0; msk <= full; ++msk)
      if ((msk >> bit) & 1) f[msk] = max(f[msk], f[msk ^ (1 << bit)]);
  // (v & a[i]) == 0  <=>  v is a submask of the M-bit complement of a[i].
  for (int i = 0; i < n; ++i) cout << f[full ^ a[i]] << " \n"[i == n - 1];
}

int main() {
  // Detach C++ streams from C stdio and untie cin/cout for fast bulk I/O.
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);
  init();
  solve();
  return 0;
}

CF383E - Vowels

Problem description

Problem analysis

Problem solution

code
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int N = 10000 + 5;  // input-size bound (not referenced by this program's code)
const int M = 24;         // number of letters, i.e. mask width: bit c <-> letter 'a' + c

int n;          // number of words
int f[1 << M];  // f[mask]: word count by exact letter set; after SoS, by letter subset

// Reads n words and counts them by letter set: bit c of the mask is set when
// letter 'a' + c occurs among the word's first three characters.
void init() {
  cin >> n;
  for (int w = 0; w < n; ++w) {
    string word;
    cin >> word;
    int letters = 0;
    for (int pos = 0; pos < 3; ++pos) letters |= 1 << (word[pos] - 'a');
    ++f[letters];
  }
}
// SoS turns f[mask] into the number of words whose letters all lie inside mask,
// so n - f[mask] counts the words using at least one letter outside mask
// (presumably the "correct" words for the vowel set = complement of mask).
// Prints the XOR over all masks of that count squared.
void solve() {
  const int total = 1 << M;
  for (int bit = 0; bit < M; ++bit) {
    const int b = 1 << bit;
    for (int msk = 0; msk < total; ++msk)
      if (msk & b) f[msk] += f[msk ^ b];
  }
  int ans = 0;
  for (int msk = 0; msk < total; ++msk) {
    const int outside = n - f[msk];
    ans ^= outside * outside;
  }
  cout << ans << '\n';
}

int main() {
  // Detach C++ streams from C stdio and untie cin/cout for fast bulk I/O.
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);
  init();
  solve();
  return 0;
}

CF449D - Jzzhu and Numbers

Problem description

Problem analysis

Problem solution

code
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int N = 1000000 + 5;  // capacity of the input array and of the power table
const int M = 1000000007;   // modulus, 1e9 + 7
const int B = 20;           // bit width of every value

int n;          // number of input values
int a[N];       // input values
int f[1 << B];  // f[mask]: count of values equal to mask; after SoS, of values containing mask
int pw[N];      // pw[k] = 2^k mod M

// Reads the n values and tabulates pw[k] = 2^k mod M for k = 0..n.
void init() {
  cin >> n;
  pw[0] = 1;
  for (int i = 1; i <= n; ++i) {
    cin >> a[i - 1];
    pw[i] = pw[i - 1] * 2 % M;  // 2 * (M - 1) still fits in int
  }
}
void solve() {
  for(int i = 0 ; i < n ; ++i) f[a[i]]++;
  for(int i = 0 ; i < B ; ++i) {
    for(int j = 0 ; j < (1 << B) ; ++j) if(!(j & (1 << i))) {
      f[j] += f[j ^ (1 << i)];
    }
  }
  int ans = 0;
  for(int i = 0 ; i < (1 << B) ; ++i) {
    ans = ((ll)ans + (pw[f[i]] - 1) * (__builtin_popcount(i) & 1 ? -1 : 1) + M) % M;
  }
  cout << ans << '\n';
}

int main() {
  // Detach C++ streams from C stdio and untie cin/cout for fast bulk I/O.
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);
  init();
  solve();
  return 0;
}

Codechef - STR_FUNC

Problem description

Problem analysis

Problem solution

code
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int N = 600000 + 5;  // capacity: indices (used as masks) 0..N-1
const int M = 1000000007;  // modulus, 1e9 + 7
const int B = 21;          // dp layers per index: layer k accounts for bits 0..k-1

int n;           // number of input values
int a[N];        // input values
int f[N];        // per-index value f[i], filled by solve() in increasing i
int g[N];        // per-index value g[i], filled by solve() in increasing i
int dpf[N][B];   // layered subset sums over g
int dpg[N][B];   // layered subset sums over f[.]^2

// Returns x * x reduced mod M, computed in 64 bits so it cannot overflow.
inline ll sqr(int x) {
  const ll wide = x;
  return wide * wide % M;
}

// Reads n followed by the n input values into a[0..n-1].
void init() {
  cin >> n;
  for (int *p = a; p != a + n; ++p) cin >> *p;
}
void solve() {
  int ans = 0;
  for(int i = 0 ; i < n ; ++i) {
    // SoS over g
    for(int j = 0 ; j < B - 1 ; ++j) {
      dpf[i][j + 1] = dpf[i][j];
      if(i & (1 << j)) dpf[i][j + 1] = ((ll)dpf[i][j + 1] + dpf[i ^ (1 << j)][j]) % M;
    }
    // Get f[i] (dpf[i][B - 1] = \subsetneq over g[j], j \in i)
    f[i] = (sqr(a[i]) + sqr(dpf[i][B - 1])) % M;
    // SoS over f
    dpg[i][0] = sqr(f[i]);
    for(int j = 0 ; j < B - 1 ; ++j) {
      dpg[i][j + 1] = dpg[i][j];
      if(i & (1 << j)) dpg[i][j + 1] = ((ll)dpg[i][j + 1] + dpg[i ^ (1 << j)][j]) % M;
    }
    // Get g[i]
    g[i] = dpg[i][B - 1];
    // update dpf[i][B - 1] to \subseteq i.e. include g[i] into dpf
    dpf[i][0] = g[i];
    for(int j = 0 ; j < B - 1 ; ++j) {
      dpf[i][j + 1] = dpf[i][j];
      if(i & (1 << j)) dpf[i][j + 1] = ((ll)dpf[i][j + 1] + dpf[i ^ (1 << j)][j]) % M;
    }
    // add result to ans
    ans = ((ll)ans + (ll)i * f[i] % M * g[i] % M) % M;
  }
  cout << ans << '\n';
}

int main() {
  // Detach C++ streams from C stdio and untie cin/cout for fast bulk I/O.
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);
  init();
  solve();
  return 0;
}

CF800D - Varying Kibibits

Problem description

Problem analysis

Problem solution

code
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int N = 1000000 + 5;  // one slot per 6-digit value (plus slack); also sizes two[]
const int M = 1000000007;   // modulus, 1e9 + 7
const int B = 6;            // number of decimal digits per value
const int ten[] = {1, 10, 100, 1000, 10000, 100000, 1000000};  // ten[k] = 10^k

int n;        // number of input values
int cnt[N];   // per value: number of inputs (digit-wise >= after the SoS in solve)
int sum[N];   // per value: sum of those inputs mod M
int sum2[N];  // per value: sum of their squares mod M
int two[N];   // two[k] = 2^k mod M
int ans[N];   // per value: the answer accumulator built by solve

// Modular addition: for x + y in [-M, 2M) (y may be negative), returns a value in [0, M).
inline int add(int x, int y) {
  const ll z = (ll)x + y;
  if (z >= M) return z - M;
  if (z < 0) return z + M;
  return z;
}
// Modular multiplication, done in 64 bits.
inline int mul(int x, int y) {
  const ll prod = (ll)x * y;
  return prod % M;
}

// Reads n values, buckets each value x into cnt/sum/sum2 (count, sum, sum of squares),
// and tabulates two[k] = 2^k mod M for k = 0..n.
// NOTE(review): assumes 0 <= x < N so x indexes the bucket arrays safely.
void init() {
  cin >> n;
  two[0] = 1;
  for (int i = 1; i <= n; ++i) {
    two[i] = add(two[i - 1], two[i - 1]);
    int x;
    cin >> x;
    ++cnt[x];
    sum[x] = add(sum[x], x);
    sum2[x] = add(sum2[x], mul(x, x));
  }
}
void solve() {
  // calculate f(S) >= x
  for(int j = 0 ; j < B ; ++j) {
    for(int i = ten[B] - 1 ; i >= 0 ; --i) if(i / ten[j] % 10 != 9) {
      cnt[i] += cnt[i + ten[j]];
      sum[i] = add(sum[i], sum[i + ten[j]]);
      sum2[i] = add(sum2[i], sum2[i + ten[j]]);
    }
  }
  // 2^(sz(S) - 2) * (\sigma x^2 + (\sigma x)^2), x \in S, sz(S) > 1
  for(int i = ten[B] - 1 ; i >= 0 ; --i) {
    if(cnt[i] == 1) ans[i] = sum2[i];
    else if(cnt[i] > 1) ans[i] = mul(two[cnt[i] - 2], add(sum2[i], mul(sum[i], sum[i])));
  }
  // now, exclude f(S) > x
  // with inclusion-exclusion,
  /*for(int i = 0 ; i < ten[B] - 1 ; ++i) {
    for(int j = 1 ; j < (1 << B) ; ++j) {
      int msk = i, flag = 1;
      for(int k = 0 ; k < B ; ++k) if(j & (1 << k)) {
        if(i / ten[k] % 10 != 9) msk += ten[k];
        else flag = 0;
      }
      if(flag) ans[i] = add(ans[i], (__builtin_popcount(j) & 1 ? -ans[msk] : ans[msk]));
    }
  }*/
  // or with another SoS
  for(int j = 0 ; j < B ; ++j) {
    for(int i = 0 ; i < ten[B] - 1 ; ++i) {
      if(i / ten[j] % 10 != 9) ans[i] = add(ans[i], -ans[i + ten[j]]);
    }
  }
  // ans
  ll res = 0;
  for(int i = ten[B] - 1 ; i >= 0 ; --i) res ^= (ll)i * ans[i];
  cout << res << '\n';
}

int main() {
  // Detach C++ streams from C stdio and untie cin/cout for fast bulk I/O.
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);
  init();
  solve();
  return 0;
}

More problems

Reference