Binlifting <3
Jul 23 2024
10:13 PM

I very recently finished a USACO Gold problem (Milk Visits), and a friend was interested in my (mildly scuffed) binary lifting approach. Here it is, I guess :P


Summary of the problem:

  • We have a tree with N <= 1e5 nodes
  • Each node has a color between 1 and N (colors can repeat)
  • Answer M <= 1e6 queries of the following form: does the path between nodes a_i and b_i (endpoints included) contain a node of color c_i?
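Before any of the clever stuff, a brute-force reference is handy for stress-testing faster solutions on small trees. Here's a minimal sketch (the Tree struct and pathHasColor are hypothetical names, not part of the solutions below): root the tree at node 0, then keep stepping the deeper of a and b up to its parent, checking each node's color, until the two meet at their LCA. It costs O(n) per query, so it's only for testing.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Brute-force reference (hypothetical helper, not the post's solution).
// O(n) per query: only meant for stress-testing faster code on small trees.
struct Tree {
    std::vector<std::vector<int>> adj;
    std::vector<int> color, par, depth;

    explicit Tree(const std::vector<int>& colors)
        : adj(colors.size()), color(colors),
          par(colors.size(), -1), depth(colors.size(), 0) {}

    void addEdge(int a, int b) {
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // Root the tree at node 0: an iterative DFS fills par[] and depth[]
    // (iterative so a long path can't overflow the call stack).
    void root() {
        std::vector<int> stack = {0};
        std::vector<bool> seen(adj.size(), false);
        seen[0] = true;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            for (int v : adj[u]) {
                if (seen[v]) continue;
                seen[v] = true;
                par[v] = u;
                depth[v] = depth[u] + 1;
                stack.push_back(v);
            }
        }
    }

    // Does the a-b path (both endpoints included) contain a node of color c?
    bool pathHasColor(int a, int b, int c) const {
        while (a != b) {
            if (depth[a] < depth[b]) std::swap(a, b);
            if (color[a] == c) return true;  // a is on the path, below the LCA
            a = par[a];                      // step the deeper endpoint up
        }
        return color[a] == c;  // a == b is the LCA
    }
};
```

For example, on the tree with edges 0-1, 0-2, 1-3, 1-4 and colors {1, 2, 3, 1, 2}, pathHasColor(3, 4, 2) is true (nodes 1 and 4 both have color 2), while pathHasColor(3, 4, 3) is false.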

Now, I initially misread the problem and thought there was only one node of each color. So my first approach was to find the unique node x_i of color c_i and check whether dist(a_i, b_i) == dist(a_i, x_i) + dist(b_i, x_i), i.e. whether x_i lies on the path. Because I am mildly obsessed with binlifting, I implemented it like so:

#include <iostream>
#include <fstream>
#include <vector>
#include <set>
const int MAXN = 100000;
// binlift structures
int at[MAXN], depth[MAXN], par[MAXN], jump[MAXN];
std::vector<int> adj[MAXN];
// single pass and accumulate jump structures
void dfs(int i, int p){
    for(int x : adj[i]) if(x != p){
        if(depth[i] + depth[jump[jump[i]]] == depth[jump[i]]*2) jump[x] = jump[jump[i]];
        else jump[x] = i;
        depth[x] = depth[i]+1, par[x] = i, dfs(x, i);
    }
}
// compute distance with binlift (via the LCA)
int dist(int a, int b){
    if(depth[a] < depth[b]) return dist(b, a);
    int res = depth[a] + depth[b];
    // first equalize depth
    while(depth[a] > depth[b]){
        if(depth[jump[a]] < depth[b]) a = par[a];
        else a = jump[a];
    }
    // then move to LCA
    while(a != b){
        if(jump[a] == jump[b]) a = par[a], b = par[b];
        else a = jump[a], b = jump[b];
    }
    return res - depth[a]*2;
}
int main(){
    // DEFINITELY weirdest I/O I've ever used, and that's saying something
    // (remove the next line to read/write the USACO files instead of stdin/stdout)
    #define fin std::cin
    #ifndef fin
    std::ifstream fin("milkvisits.in");
    std::ofstream fout("milkvisits.out");
    #else
    #define fout std::cout
    #endif
    int n, m; fin >> n >> m;
    for(int i=0; i<n; ++i){ int t; fin >> t; at[t-1] = i; }
    for(int i=1; i<n; ++i){
        int a, b; fin >> a >> b, --a, --b;
        adj[a].push_back(b), adj[b].push_back(a);
    }
    dfs(0, -1);
    while(m--){
        int a, b, c; fin >> a >> b >> c, --a, --b, c = at[c-1];
        fout << (dist(a, b) == dist(a, c) + dist(b, c));
    }
    fout << '\n';
}

This solution doesn't pass samples, due to my misreading :P

But can we adapt this solution to multiple nodes of each color? With up to n nodes of the same color, checking each one individually costs up to n distance computations per query, so O(nm) in the worst case. We need a different approach.
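To make that blow-up concrete, here's a minimal sketch of the obvious generalization: keep a list of nodes per color and test each one with the same dist(a, x) + dist(x, b) == dist(a, b) criterion. It's correct, but every query does one distance check per node of color c. The names are hypothetical, colors are 0-indexed here, and dist uses a plain parent walk instead of the jump pointers above, just to keep it short.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical sketch: the single-node idea generalized to many nodes per color.
// Nodes are 0..n-1, colors are 0-indexed (0..n-1), and the tree is rooted at node 0.
struct NaiveMultiColor {
    std::vector<int> par, depth;
    std::vector<std::vector<int>> byColor;  // byColor[c] = every node with color c

    NaiveMultiColor(int n, const std::vector<std::pair<int, int>>& edges,
                    const std::vector<int>& colors)
        : par(n, -1), depth(n, 0), byColor(n) {
        std::vector<std::vector<int>> adj(n);
        for (const auto& e : edges) {
            adj[e.first].push_back(e.second);
            adj[e.second].push_back(e.first);
        }
        // Iterative DFS from the root to fill par[] and depth[].
        std::vector<int> stack = {0};
        std::vector<bool> seen(n, false);
        seen[0] = true;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            for (int v : adj[u]) {
                if (seen[v]) continue;
                seen[v] = true;
                par[v] = u;
                depth[v] = depth[u] + 1;
                stack.push_back(v);
            }
        }
        for (int i = 0; i < n; ++i) byColor[colors[i]].push_back(i);
    }

    // Tree distance with a naive O(depth) LCA walk (the post uses jump pointers here).
    int dist(int a, int b) const {
        int res = depth[a] + depth[b];
        while (a != b) {
            if (depth[a] < depth[b]) std::swap(a, b);
            a = par[a];  // step the deeper endpoint up
        }
        return res - 2 * depth[a];  // a is now the LCA
    }

    // x is on the a-b path iff detouring through x adds no extra distance.
    bool query(int a, int b, int c) const {
        int d = dist(a, b);
        for (int x : byColor[c])  // one distance check per node of color c: the blow-up
            if (dist(a, x) + dist(x, b) == d) return true;
        return false;
    }
};
```

On the tree with edges 0-1, 1-2, 2-3, 1-4 and colors {0, 1, 0, 2, 0}, query(3, 4, 0) is true via node 2, while query(0, 1, 2) is false because the only color-2 node, 3, is off that path. With a single color on all n nodes, each query scans all n of them.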

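This is actually where my weird binary lifting comes in handy: there's exactly one jump per node, so I only have to compute O(n) jumps, and along any single root-to-leaf path the average jump covers only O(log n) nodes. It turns out that I was able to get away with storing an std::set of the colors covered by each jump (along[x] holds the colors on the path from x up to, but not including, jump[x]) and accumulating all of them within my single DFS pass. The time complexity works out on the tests because, along any one root-to-leaf path, each node is covered by only about O(log n) jumps. That isn't a true worst-case bound for general trees, though: every node at depth 2^k - 1 jumps straight to the root, so many leaves hanging off one long path at such a depth could push the total set size toward n^2. Anyway, here's the code: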

#include <iostream>
#include <fstream>
#include <vector>
#include <set>
#include <utility>
const int MAXN = 100000;
// binlift structures; along[x] = colors on the path from x up to (not including) jump[x]
int at[MAXN], depth[MAXN], par[MAXN], jump[MAXN];
std::set<int> along[MAXN];
std::vector<int> adj[MAXN];
// single pass and accumulate not only jump structure locations,
// but also update sets (somehow this avoids tle?)
void dfs(int i, int p){
    for(int x : adj[i]) if(x != p){
        if(depth[i] + depth[jump[jump[i]]] == depth[jump[i]]*2){
            jump[x] = jump[jump[i]];
            for(int e : along[jump[i]]) along[x].insert(e);
            for(int e : along[i]) along[x].insert(e);
        }else jump[x] = i;
        along[x].insert(at[x]);
        depth[x] = depth[i]+1, par[x] = i, dfs(x, i);
    }
}
int main(){
    // (remove the next line to read/write the USACO files instead of stdin/stdout)
    #define fin std::cin
    #ifndef fin
    std::ifstream fin("milkvisits.in");
    std::ofstream fout("milkvisits.out");
    #else
    #define fout std::cout
    #endif
    int n, m; fin >> n >> m;
    for(int i=0; i<n; ++i) fin >> at[i], --at[i];
    for(int i=1; i<n; ++i){
        int a, b; fin >> a >> b, --a, --b;
        adj[a].push_back(b), adj[b].push_back(a);
    }
    dfs(0, -1);
    while(m--){
        int a, b, c; fin >> a >> b >> c, --a, --b, --c;
        int out = 0;
        // binlift pass from a, b -> LCA and check all relevant
        // sets along the pass for type c
        if(depth[a] < depth[b]) std::swap(a, b);
        out |= (at[a] == c | at[b] == c);
        while(depth[a] > depth[b]){
            if(depth[jump[a]] < depth[b]) a = par[a];
            else out |= along[a].count(c), a = jump[a];
            out |= at[a] == c;
        }
        while(a != b){
            if(jump[a] == jump[b]) a = par[a], b = par[b];
            else out |= along[a].count(c), out |= along[b].count(c), a = jump[a], b = jump[b];
            out |= (at[a] == c | at[b] == c);
        }
        fout << out;
    }
    fout << '\n';
}

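AC; with O(log n) jumps per query and an O(log n) std::set lookup on each, queries cost O(log^2 n), which is barely slower than plain O(log n) binlifting :D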
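Where does the O(log n) jump count come from? The dfs rule only compares depths, so every node at the same depth jumps to the same depth, and the structure can be studied on a bare path. Here's a small sketch (pathJumps and climbSteps are hypothetical names, not from the post) that replays the same rule on a path 0 - 1 - ... - (n-1), where node d has depth d, and counts the par/jump steps that the "equalize depth" loop takes. The jump lengths come out as 2^k - 1 (depth 7 jumps to 0, depth 14 to 7, depth 15 to 0), and climbing to the root only ever takes O(log n) steps.

```cpp
#include <cassert>
#include <vector>

// Jump pointers built with the post's dfs rule, replayed on a path
// 0 - 1 - ... - (n-1) rooted at 0, so node d has depth d.
// jump[0] stays 0, matching the zero-initialized globals in the post.
std::vector<int> pathJumps(int n) {
    std::vector<int> jump(n, 0);
    for (int x = 1; x < n; ++x) {
        int i = x - 1;  // parent of x; depth[i] == i on a path
        if (i + jump[jump[i]] == 2 * jump[i]) jump[x] = jump[jump[i]];
        else jump[x] = i;
    }
    return jump;
}

// Steps the post's "equalize depth" loop takes to climb from depth `from` to depth `to`:
// take the jump unless it overshoots, otherwise step to the parent.
int climbSteps(const std::vector<int>& jump, int from, int to) {
    int steps = 0;
    while (from > to) {
        from = (jump[from] < to) ? from - 1 : jump[from];
        ++steps;
    }
    return steps;
}
```

For instance, climbing from depth 15 to depth 8 goes 15, 14, 13, 10, 9, 8: two parent steps, one jump, one parent step, one jump, so 5 steps in total.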

tags: programming usaco