跳转至

树分块

树分块,就是像对序列分块一样设一个阈值 \(B\),然后在树上随机撒 \(\dfrac{n}{B}\) 个关键点,满足任意一个点到距离其最近的关键点距离不超过 \(\mathcal O(B)\) 级别,这样我们就可以预处理关键点两两之间的信息。

在询问两个点路径上的信息时,我们直接将预处理的信息拿出来使用,再额外加上两个端点到距离它们最近的关键点之间的路径的贡献即可算出答案,复杂度 \(\mathcal O(B^2+qB+\dfrac{n^2}{B})\),一般 \(B\)\(\sqrt{n}\)

对于“随机撒点”,我们有一个很显然的想法是将深度 \(\bmod B=0\) 设为关键点,不过则可能会被卡:

一条长度为 \(B\) 的链下面挂 \(n-B\) 个叶子。

因此我们进行优化:我们记 \(dis_i\)\(i\) 到距离它最**远**的叶子节点的距离,那么我们将深度 \(\bmod B=0\)\(dis_i\ge B\) 的点 \(i\) 设为关键节点即可,可以证明在这种构造方法下关键点个数是严格 \(\dfrac{n}{B}\) 级别的。

例题 #1 Count on a tree II

给定一个 \(n\) 个节点的树,每个节点上有一个整数,\(i\) 号点的整数为 \(val_i\)

\(m\) 次询问,每次给出 \(u',v\),您需要将其解密得到 \(u,v\),并查询 \(u\)\(v\) 的路径上有多少个不同的整数。

解密方式:\(u=u' \operatorname{xor} lastans\)

\(lastans\) 为上一次询问的答案,若无询问则为 \(0\)

对于 \(100\%\) 的数据,\(1\le u,v\le n\le 4\times 10^4\)\(1\le m\le 10^5\)\(0\le u',val_i<2^{31}\)


首先对于计数问题,我们使用 bitset来统计出现的不同整数个数(离散化)。我们使用树分块后,预处理出每两个相邻关键点之间的答案,然后在询问时将预处理答案和散点的答案(bitset形式)位或即可。

/*                                                                                
                      Keyblinds Guide
                    ###################
      @Ntsc 2024

      - Ctrl+Alt+G then P : Enter luogu problem details
      - Ctrl+Alt+B : Run all cases in CPH
      - ctrl+D : choose this and dump to the next
      - ctrl+Shift+L : choose all like this
      - ctrl+K then ctrl+W: close all
      - Alt+la/ra : move mouse to lstans/nxt pos'

*/
#include <bits/stdc++.h>
#include <queue>
using namespace std;

#define rep(i, l, r) for (int i = l, END##i = r; i <= END##i; ++i)
#define per(i, r, l) for (int i = r, END##i = l; i >= END##i; --i)
#define pb push_back
#define mp make_pair
#define int long long
#define pii pair<int, int>
#define ps second
#define pf first

// #define innt int
#define itn int
// #define inr intw
// #define mian main
// #define iont int

#define rd read()
int read(){
    int xx = 0, ff = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-')
            ff = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
      xx = xx * 10 + (ch - '0'), ch = getchar();
    return xx * ff;
}


void write(int out) {
    if (out < 0)
        putchar('-'), out = -out;
    if (out > 9)
        write(out / 10);
    putchar(out % 10 + '0');
}



#define ell dbg('\n')
const char el='\n';
const bool enable_dbg = 1;
template <typename T,typename... Args>
void dbg(T s,Args... args) {
    if constexpr (enable_dbg){
    cerr << s;
    if(1)cerr<<' ';
        if constexpr (sizeof...(Args))
            dbg(args...);
    }
}



#define zerol = 1
#ifdef zerol
#define cdbg(x...) do { cerr << #x << " -> "; err(x); } while (0)
void err() { cerr << endl; }
template<template<typename...> class T, typename t, typename... A>
void err(T<t> a, A... x) { for (auto v: a) cerr << v << ' '; err(x...); }
template<typename T, typename... A>
void err(T a, A... x) { cerr << a << ' '; err(x...); }
#else
#define dbg(...)
#endif


const int MAXN=4e4;
const int N=4e4+5;
const int M=205;


int n,Q,a[N],key[N],uni[N],num=0;
int h[N],to[MAXN*2+5],nxt[MAXN*2+5],tot=0;


int id[N],pcnt=0,B,buc[N];
int dis[N],dep[N],fa[N][16+2];
bitset<N> col[21000],vis;


void add(int u,int v){
    to[++tot]=v;nxt[tot]=h[u];h[u]=tot;
    to[++tot]=u;nxt[tot]=h[v];h[v]=tot;
}

void dfs(int x,int f){
    fa[x][0]=f;
    for(int e=h[x];e;e=nxt[e]){
        int y=to[e];
        if(y==f)continue;

        dep[y]=dep[x]+1;
        dfs(y,x);dis[x]=max(dis[x],dis[y]+1);
    } if(dep[x]%B==0&&dis[x]>=B) id[x]=++pcnt;
}


int Bcol[M][M],cc=0;

void dfs(int x,int f,int rt){
    buc[a[x]]++;if(buc[a[x]]==1) vis.set(a[x]);
    if(id[x]) col[Bcol[id[rt]][id[x]]]=vis;
    for(int e=h[x];e;e=nxt[e]){
        int y=to[e];
        if(y==f)continue;
        dfs(y,x,rt);
    } buc[a[x]]--;if(!buc[a[x]]) vis.reset(a[x]);
}


int lca(int x,int y){
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=16;~i;i--) if(dep[x]-(1<<i)>=dep[y]) x=fa[x][i];
    if(x==y) return x;
    for(int i=16;~i;i--) if(fa[x][i]^fa[y][i]) x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}


int jump(int x){
    while(x){
        if(id[x]) return x;
        x=fa[x][0];
    } 
    return 0;
}


int get_kanc(int x,int k){
    for(int i=16;~i;i--) if(k>>i&1) x=fa[x][i];
    return x;
}


int lstans;


void solve(int x,int y){
    int l=lca(x,y),fx=jump(x),fy=jump(y);
    if(dep[fx]<dep[l]&&dep[fy]<dep[l]){
        while(x!=l) vis.set(a[x]),x=fa[x][0];
        while(y!=l) vis.set(a[y]),y=fa[y][0];
        vis.set(a[l]);

        lstans=vis.count();
        write(lstans);
        puts("");
    } else if(dep[fx]>=dep[l]&&dep[fy]>=dep[l]){

        vis=col[Bcol[id[fx]][id[fy]]];
        while(x!=fx) vis.set(a[x]),x=fa[x][0];
        while(y!=fy) vis.set(a[y]),y=fa[y][0];

        lstans=vis.count();
        write(lstans);
        puts("");
    } else{

        if(dep[fy]>=dep[l]) swap(x,y),swap(fx,fy);

        int z=get_kanc(x,max(dep[x]-dep[l]-(B<<1|1),0ll));
        int near=-1;
        while(z!=l){
            if(id[z]) near=z;
            z=fa[z][0];
        } if(id[l]) near=l;

        if(fx!=near) vis=col[Bcol[id[fx]][id[near]]];
        while(x!=fx) vis.set(a[x]),x=fa[x][0];
        while(near!=l) vis.set(a[near]),near=fa[near][0];
        while(y!=l) vis.set(a[y]),y=fa[y][0];
        vis.set(a[l]);

        lstans=vis.count();
        write(lstans);
        puts("");
    }
}

void solve(){
    n=rd,Q=rd;

    for(int i=1;i<=n;i++) {
        a[i]=rd,key[i]=a[i];
    }
    key[0]=-1;

    sort(key+1,key+n+1);
    for(int i=1;i<=n;i++) 
        if(key[i]^key[i-1]) uni[++num]=key[i];//
    for(int i=1;i<=n;i++) 
        a[i]=lower_bound(uni+1,uni+num+1,a[i])-uni;
    for(int i=1,u,v;i<n;i++) 
        add(rd,rd);
    B=sqrt(n);

    dfs(1,0);

    dep[0]=-1;
    for(int i=1;i<=pcnt;i++) 
        for(int j=1;j<=i;j++) 
            Bcol[i][j]=Bcol[j][i]=++cc;
    for(int i=1;i<=n;i++) {
        if(id[i]) memset(buc,0,sizeof(buc));
        dfs(i,0,i);
    }
    for(int i=1;i<=16;i++) 
        for(int j=1;j<=n;j++)
            fa[j][i]=fa[fa[j][i-1]][i-1];


    while(Q--){
        int x=rd,y=rd;

        x^=lstans;
        vis.reset();
        solve(x,y);

    }
}


signed main() {
    // freopen(".in","r",stdin);
    // freopen(".in","w",stdout);

    int T=1;
    while(T--){
        solve();
    }
    return 0;
}