CSES Fixed-Length Paths I
Oct 4 2024
4:10 PM

Problem statement

This is very similar to IOI11p2 (Race), which I did last weekend: centroid decomposition, then a naive DP over depths at each centroid. I decided to write this one in C*, which was surprisingly easy.
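For checking a solution on small trees, it helps to pin down what is being counted: the number of unordered node pairs whose distance is exactly k edges. Here is a rough brute-force sketch in Python. It is not part of the original solution; the function name `count_paths_brute` and the 0-indexed edge list are assumptions made for illustration, and it assumes k >= 1.

```python
from collections import deque


def count_paths_brute(n, k, edges):
    """Count paths with exactly k edges in a tree, O(n^2). Assumes k >= 1."""
    adj = [[] for _ in range(n)]
    for a, b in edges:              # edges are 0-indexed pairs (assumption)
        adj[a].append(b)
        adj[b].append(a)

    ordered = 0                     # counts (src, dst) pairs, both directions
    for src in range(n):
        dist = [-1] * n
        dist[src] = 0
        queue = deque([src])
        while queue:
            u = queue.popleft()
            if dist[u] == k:
                ordered += 1
                continue            # nothing deeper than k matters
            for v in adj[u]:
                if dist[v] == -1:
                    dist[v] = dist[u] + 1
                    queue.append(v)
    return ordered // 2             # each path was seen from both endpoints
```

This is far too slow for large inputs, but it makes a handy oracle when testing a faster solution on random small trees.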

#include <iframe>
#include <vec>
// kind of a hack; c* does not have foreach
#define nxt adj[i].d[j]
@cgmain
@vec[int][2e5] adj; @m int n, k;
bool[2e5] vis;
// centroid decmp
int[2e5] w;
fn pick(int i, int p, int sz) -> int:
    w[i] = 1;
    int res = sz-1; bool work = 1;
    int j; for(j=0; j<adj[i].sz; ++j)
        if(nxt!=p && !vis[nxt])
            int r = pick(nxt, i, sz);
            if(r!=-1) return r end
            w[i] = w[i] + w[nxt];
            work = work & (w[nxt] <= sz/2);
            res = res - w[nxt];
        end
    end
    if(res <= (sz/2) && work) return i end
    return -1;
end
fn count(int i, int p) -> int:
    int res = 1;
    int j; for(j=0; j<adj[i].sz; ++j)
        if(nxt!=p && !vis[nxt])
            res = res + count(nxt, i);
        end
    end
    return res;
end
fn inline centroid(int i) -> int:
    int sz = count(i, -1);
    return pick(i, -1, sz);
end
// naive dp; for each subtree "going away" from
// some node i, count # of nodes at depth 0..k.
// we use only two arrays to save mem
ll[2e5+1] build; ll[2e5+1] final;
int maxd; ll res;
// divide and conquer thing
fn depths(int i, int p, int d):
    if(maxd < d) maxd = d end
    if(d > k) return end
    ++build[d];
    int j; for(j=0; j<adj[i].sz; ++j)
        if(nxt!=p && !vis[nxt])
            depths(nxt, i, d+1);
        end
    end
end
fn relax(int i):
    i = centroid(i);
    vis[i] = 1;
    if(count(i, -1) >= k)
        int j; for(j=0; j<=k; ++j) final[j] = 0 end
        // for all subtrees 'connected' to i, process depths
        for(j=0; j<adj[i].sz; ++j)
            if(!vis[nxt])
                maxd = 0;
                depths(nxt, i, 1);
                // do some PIE stuff to avoid overcount (path from
                // node x of depth 3 to node y of depth 2 won't result
                // in path of length 5 if x,y in same subtree etc
                int l; for(l=0; l<=maxd; ++l)
                    if(l+l>k) break end
                    res = res - build[l]*build[k-l] * (1 + (l+l!=k))
                end
                for(l=0; l<=maxd; ++l)
                    final[l] = final[l] + build[l];
                    build[l] = 0;
                end
            end
        end
        ++final[0];
        // the rest of the PIE stuff
        for(j=0; j+j<=k; ++j) res = res + final[j]*final[k-j] * (1 + (j+j!=k)) end
        for(j=0; j<adj[i].sz; ++j)
            if(!vis[nxt]) relax(nxt); end
        end
    end
end
fn inline main:
    @poll n k;
    for(i=1; i<n; ++i)
        @m int a, b; @poll a b;
        adj[a-1]:push(b-1); adj[b-1]:push(a-1);
    end
    relax(0);
    @fmt res/2;
end
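
For readers who don't know C*, the same idea can be sketched in Python. This is a rough port, not a line-for-line translation. The names `count_paths_centroid`, `comp_size`, `find_centroid`, `collect`, and `pairs_summing_to_k` are made up for this sketch. It finds the centroid by walking toward the heavy child instead of using `pick`, and it uses an explicit stack in place of the recursive `relax`. The counting is the same inclusion-exclusion: add ordered pairs of depths summing to k over the whole component, subtract the pairs that fall inside a single child subtree, and halve at the end. It assumes k >= 1 and 0-indexed edges.

```python
def count_paths_centroid(n, k, edges):
    """Count paths with exactly k edges via centroid decomposition (k >= 1)."""
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    removed = [False] * n           # plays the role of vis[] above
    size = [0] * n

    def comp_size(u, p):
        # Subtree sizes of the current component, rooted at the first call.
        size[u] = 1
        for v in adj[u]:
            if v != p and not removed[v]:
                size[u] += comp_size(v, u)
        return size[u]

    def find_centroid(u, p, total):
        # Walk toward any child holding more than half of the component.
        for v in adj[u]:
            if v != p and not removed[v] and size[v] > total // 2:
                return find_centroid(v, u, total)
        return u

    def collect(u, p, d, build):
        # build[d] = number of nodes at depth d below the centroid (d <= k).
        if d > k:
            return
        while len(build) <= d:
            build.append(0)
        build[d] += 1
        for v in adj[u]:
            if v != p and not removed[v]:
                collect(v, u, d + 1, build)

    def pairs_summing_to_k(hist):
        # Ordered pairs (x, y) with depth(x) + depth(y) == k.
        return sum(hist[d] * hist[k - d]
                   for d in range(len(hist)) if 0 <= k - d < len(hist))

    ordered = 0
    stack = [0]
    while stack:
        root = stack.pop()
        total = comp_size(root, -1)
        c = find_centroid(root, -1, total)
        removed[c] = True
        final = [1]                 # the centroid itself sits at depth 0
        for v in adj[c]:
            if removed[v]:
                continue
            build = [0]
            collect(v, c, 1, build)
            ordered -= pairs_summing_to_k(build)   # same-subtree overcount
            while len(final) < len(build):
                final.append(0)
            for d, cnt in enumerate(build):
                final[d] += cnt
            stack.append(v)
        ordered += pairs_summing_to_k(final)
    return ordered // 2
```

The helpers are recursive, so as written this is only suitable for checking small trees, not for submitting at the full input size.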

kinda clean I hope?

tags: programming cses