跳转至

D12 Luogu P3384【模板】轻重链剖分/树链剖分_哔哩哔哩_bilibili

树链剖分

重链

image.png

注意,重儿子只能有**一个**

图中标绿的路径即为重链,注意,每个叶子节点都是一条特殊的重链

注意,只要子节点是重儿子,它就是重链,不对父节点有特殊要求

特殊性质:

  • 所有重链所包含的节点相加,正好是所有的节点,不重不漏

  • 注意标红的结论3:eg.对于路径5-2-1-4-8-12,就剖分为了4条重链

使用重链剖分解决LCA问题

数组 fa[u]:存 u 的父结点 dep[u]:存 u 的深度 son[u]:存 u 的重儿子 sz[u]:存以 u 为根的子树的结点数 top[u]:存 u 所在重链的顶点 流程 1.第一遍 dfs,搞出 fa,dep,son 数组 2.第二遍 dfs,搞出 top 数组 3.让两个游标沿着各自的重链向上跳,跳到同一条重链上时,深度较小的那个游标所指向的点,就是 LCA

未验证的代码

/*////////ACACACACACACAC///////////
       . Code by Ntsc .
       . WHY NOT????? .
/*////////ACACACACACACAC///////////

#include<bits/stdc++.h>
#define ll long long
#define db double
#define rtn return
using namespace std;

const int N=1e5;
const int M=1e5;
const int Mod=1e5;
const int INF=1e5;

int sz[N],fa[N],dep[N],son[N],top[N];
int n,Q;
vector<int> e[N];

void add(int a,int b){
    e[a].push_back(b);
}

void dfs1(int u,int faa){//初始值1,0 
    fa[u]=faa;dep[u]=dep[faa]+1;sz[u]=1;//更新fa,dep,设置sz初始值 
    for(int i=0;i<e[u].size();i++){
        int v=e[u][i];
        if(v==faa)continue;
        dfs1(v,u);//注意不是dfs1(v,faa) 
        sz[u]+=sz[v];//回溯,sz[u]原本为1 ,加上已经处理了v的子树大小
        if(sz[v]>sz[son[u]])son[u]=v;//更新重儿子,son[u]存到是之前已经扫描过的u的儿子中最重的那个,把它的sz与v的sz比对 
    }   
    return ;
} 

void dfs2(int u,int t){
    top[u]=t;
    if(!son[u])return;//没有重儿子,那么说明u为叶子节点. 
    dfs2(son[u],t);//先走重儿子,继承u的top(即t) 
    for(int i=0;i<e[i].size();i++){//扫描轻儿子 
        int v=e[u][i];
        if(v==fa[u]||v==son[u])continue;//筛选 
        dfs2(v,v);//轻儿子不能继承t,只能开一条新的重链 
    }
    return ;
}

int lca(int u,int v){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);//交换,使得u所在的重链top恒比v的深 ,即保证是更深的在往上跳而不是浅的一直在往上跳 
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])return v;
    else return u;//返回更浅的那个,就是原来u,v的LCA 
}
signed main(){
    cin>>n>>Q;
    for(int i=1;i<=n;i++){
        int a,b;
        cin>>a>>b;
        add(a,b);add(b,a);
    }
    //......
    return 0;
}

例题 #1 树剖的简单应用 | Qtree3

题目描述

给出 \(N\) 个点的一棵树(\(N-1\) 条边),节点有白有黑,初始全为白。

有两种操作:

0 i:改变某点的颜色(原来是黑的变白,原来是白的变黑)。

1 v:询问 \(1\)\(v\) 的路径上的第一个黑点,若无,输出 \(-1\)

对于 \(1/3\) 的数据有 \(N=5000,Q=400000\)

对于 \(1/3\) 的数据有 \(N=10000,Q=300000\)

对于 \(1/3\) 的数据有 \(N=100000, Q=100000\)

此外,有\(1 \le i,v \le N\)


/*                                                                                
                      Keyblinds Guide
                    ###################
      @Ntsc 2024

      - Ctrl+Alt+G then P : Enter luogu problem details
      - Ctrl+Alt+B : Run all cases in CPH
      - ctrl+D : choose this and dump to the next
      - ctrl+Shift+L : choose all like this
      - ctrl+K then ctrl+W: close all
      - Alt+la/ra : move mouse to pre/nxt pos'

*/
#include <bits/stdc++.h>
#include <queue>
using namespace std;

#define rep(i, l, r) for (int i = l, END##i = r; i <= END##i; ++i)
#define per(i, r, l) for (int i = r, END##i = l; i >= END##i; --i)
#define pb push_back
#define mp make_pair
#define int long long
#define ull unsigned long long
#define pii pair<int, int>
#define ps second
#define pf first

// #define innt int
#define itn int
// #define inr intw
// #define mian main
// #define iont int

#define rd read()
int read(){
    int xx = 0, ff = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-')
            ff = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
      xx = xx * 10 + (ch - '0'), ch = getchar();
    return xx * ff;
}
void write(int out) {
    if (out < 0)
        putchar('-'), out = -out;
    if (out > 9)
        write(out / 10);
    putchar(out % 10 + '0');
}

#define ell dbg('\n')
const char el='\n';
const bool enable_dbg = 1;
template <typename T,typename... Args>
void dbg(T s,Args... args) {
    if constexpr (enable_dbg){
    cerr << s;
    if(1)cerr<<' ';
        if constexpr (sizeof...(Args))
            dbg(args...);
    }
}

#define zerol = 1
#ifdef zerol
#define cdbg(x...) do { cerr << #x << " -> "; err(x); } while (0)
void err() { cerr << endl; }
template<template<typename...> class T, typename t, typename... A>
void err(T<t> a, A... x) { for (auto v: a) cerr << v << ' '; err(x...); }
template<typename T, typename... A>
void err(T a, A... x) { cerr << a << ' '; err(x...); }
#else
#define dbg(...)
#endif


const int N = 5e5 + 5;
const int INF = 1e18;
const int M = 1e7;
const int MOD = 1e9 + 7;

vector<int> e[N];

int sz[N];
int dep[N];
int son[N];
itn fa[N];

int _dfn[N],tid[N],top[N];
itn tim;


void add(itn x,int y){
    e[x].pb(y);
    e[y].pb(x);
}

void dfs(int x,int f){
    sz[x]=1;
    dep[x]=dep[f]+1;

    for(auto v:e[x]){
        if(v==f)continue;
        fa[v]=x;
        dfs(v,x);
        sz[x]+=sz[v];
        if(!son[x]||sz[son[x]]<sz[v])son[x]=v;
    }
}
set<int> st[N];


void query(int x){
    int ans=INF;
    while(x){
        itn k=*st[top[x]].begin();
        if(st[top[x]].size()){
            if(dep[_dfn[k]]<=dep[x])ans=_dfn[k];
        }
        x=fa[top[x]];// !!
    }
    write(ans==INF?-1:ans);
    puts("");
}

int col[N];

void change(int x){
    col[x]^=1;
    if(col[x])st[top[x]].insert(tid[x]);
    else st[top[x]].erase(tid[x]);
}

void dfs2(itn x,int t){
    _dfn[++tim]=x;
    tid[x]=tim;
    top[x]=t;
    if(!son[x])return ;
    dfs2(son[x],t);
    for(auto v:e[x]){
        if(v==fa[x]||v==son[x])continue;
        dfs2(v,v);
    }

}

void solve(){
    itn n=rd,q=rd;
    for(itn i=1;i<n;i++){
        add(rd,rd);
    }


    dfs(1,0);
    dfs2(1,1);

    // cerr<<"OK";
    while(q--){
        int op=rd,x=rd;
        if(op){
            query(x);
        }else{
            change(x);
        }
    }
}

signed main() {
    // freopen(".in","r",stdin);
    // freopen(".in","w",stdout);

    int T=1;
    while(T--){
        solve();
    }
    return 0;
}

重链剖分+线段树实现树上修改与查询

重链剖分部分

324 树上修改与查询 树链剖分_哔哩哔哩_bilibili

top[u]: 存 u 所在重链的顶点

id[u]: 存 u 剖分后的新编号

nw[cnt]: 存新编号在树中所对应节点的权值

在剖分之后我们将每个节点都一一映射一个新的编号。按照代码中dfs2的顺序,一条一条重链上的点深度由浅入深依次编号。

直观图,新的编号如图中的标红数字

image.png

代码

int sz[N],fa[N],dep[N],son[N],top[N];
int n,Q,cnt;
vector<int> e[N];

void add(int a,int b){
    e[a].push_back(b);
}

void dfs1(int u,int faa){//初始值1,0 
    fa[u]=faa;dep[u]=dep[faa]+1;sz[u]=1;//更新fa,dep,设置sz初始值 
    for(int i=0;i<e[u].size();i++){
        int v=e[u][i];
        if(v==faa)continue;
        dfs1(v,u);//注意不是dfs1(v,faa) 
        sz[u]+=sz[v];//回溯,sz[u]原本为1 ,加上已经处理了v的子树大小
        if(sz[v]>sz[son[u]])son[u]=v;//更新重儿子,son[u]存到是之前已经扫描过的u的儿子中最重的那个,把它的sz与v的sz比对 
    }   
    return ;
} 

void dfs2(int u,int t){
    top[u]=t;id[u]=++cnt;nw[cnt]=w[u];
    if(!son[u])return;//没有重儿子,那么说明u为叶子节点. 
    dfs2(son[u],t);//先走重儿子,继承u的top(即t) 
    for(int i=0;i<e[i].size();i++){//扫描轻儿子 
        int v=e[u][i];
        if(v==fa[u]||v==son[u])continue;//筛选 
        dfs2(v,v);//轻儿子不能继承t,只能开一条新的重链 
    }
    return ;
}

线段树部分

我们将原来的树映射为序列,然后使用线段树来维护这段序列

代码

struct tree{
    int l,r;//左右节点代表的区间
    ll add,sum;//标记与区间和
}tr[N<<2];

void pushup(int u){
    tr[u].sum=tr[u*2].sum+tr[u*2+1].sum;
}
void build(int u,int l,int r){
    tr[u]={l,r,0,nw[r]};
    if(l==r)return ;
    int mid=l+r;
    build(u*2,l,mid);
    build(u*2+1,mid+1,r);
    pushup(u);//回溯时更新
}

查询

目的:求树上从x到y最短路径上的节点和

举个例子

image.png

我们要查询树上节点10-12的路径权值和,首先使用类似LCA算法找到路径经过了那几条重链,分别是10-6-2,1-4-8,12

然后映射到线段树中,分别是区间[7,9][1,3][5]

我们使用线段树快速求出这些区间和即可,注意,其中的[7,9][5]都是一段完整的重链,我们在LCA时已经知道了其l,r,直接线段树即可,但[1,3]不是一段完整的重链。幸好,在LCA算法结束时,u=8,v=1,映射后就是3和1.因此在LCA结束后我们还需要把ans加上这一条不完整的链

int query_lca(int u,int v){
    ll res=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);//交换,使得u所在的重链top恒比v的深 ,即保证是更深的在往上跳而不是浅的一直在往上跳 
        res+=query(id[top[u]],id[u]);//注意id[top[u]]<=id[u]
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    res+=query(id[u],id[v])
    return res;
}

修改

目的:将树上从x到y的最短路径上每个节点都加上k

代码和上方query_lca基本上一模一样。注意这里省略了**线段树**相关代码如:

  • update()change()

  • query()

  • pushdown()懒标记

请自我回顾!

void updt_lca(int u,int v,int k){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);//交换,使得u所在的重链top恒比v的深 ,即保证是更深的在往上跳而不是浅的一直在往上跳 
        update(1,id[top[u]],id[u],k);//注意id[top[u]]<=id[u]
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    update(1,id[u],id[v],k);
    return ;
}

例题

www.luogu.com.cn

如题,已知一棵包含 \(N\) 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

  • 1 x y z,表示将树从 \(x\)\(y\) 结点最短路径上所有节点的值都加上 \(z\)

  • 2 x y,表示求树从 \(x\)\(y\) 结点最短路径上所有节点的值之和。

  • 3 x z,表示将以 \(x\) 为根节点的子树内所有节点值都加上 \(z\)

  • 4 x 表示求以 \(x\) 为根节点的子树内所有节点值之和

对于 \(100\%\) 的数据: \(1\le N \leq {10}^5\)\(1\le M \leq {10}^5\)\(1\le R\le N\)\(1\le P \le 2^{31}-1\)

补充说明:观察可知,重链剖分后任意一棵子树的所有点会被映射到一个连续的区间中。这一点和dfs序颇有类似。

/*////////ACACACACACACAC///////////
       . Code  by  Ntsc .
       . Earn knowledge .
/*////////ACACACACACACAC///////////

#include<bits/stdc++.h>
#define int long long
#define db double
#define rtn return
using namespace std;

const int N=1e5+5;
const int M=1e5;
int MOD;
const int INF=1e5;

int n,m,p,q,r,T,s[N],ans;

int sz[N],fa[N],dep[N],son[N],top[N];
int w[N],nw[N],tail[N],id[N],Q,cnt;
int du[N],rt;
vector<int> e[N];

void add(int a,int b){
    e[a].push_back(b);
    e[b].push_back(a);//记得双向边 
    du[b]++;
}

void dfs1(int u,int faa){//初始值1,0 
    fa[u]=faa;dep[u]=dep[faa]+1;sz[u]=1;//更新fa,dep,设置sz初始值 
    for(int i=0;i<e[u].size();i++){
        int v=e[u][i];
        if(v==faa)continue;
        dfs1(v,u);//注意不是dfs1(v,faa) 
        sz[u]+=sz[v];//回溯,sz[u]原本为1 ,加上已经处理了v的子树大小
        if(sz[v]>sz[son[u]])son[u]=v;//更新重儿子,son[u]存到是之前已经扫描过的u的儿子中最重的那个,把它的sz与v的sz比对 
    }   
    return ;
} 

void dfs2(int u,int t){
    top[u]=t;id[u]=++cnt;nw[cnt]=w[u];
    tail[u]=id[u];
    if(!son[u])return;//没有重儿子,那么说明u为叶子节点. 
    dfs2(son[u],t);//先走重儿子,继承u的top(即t) 
    tail[u]=max(id[u],tail[son[u]]);
    for(int i=0;i<e[u].size();i++){//扫描轻儿子 
        int v=e[u][i];
        if(v==fa[u]||v==son[u])continue;//筛选 
        dfs2(v,v);//轻儿子不能继承t,只能开一条新的重链 
        tail[u]=max(id[u],tail[v]);//记录子树区间 
    }
    return ;
}

//SGT
struct tree{
    int l,r;//左右节点代表的区间
    int add,sum;//标记与区间和
}tr[N<<2];

void pushup(int u){
    tr[u].sum=(tr[u*2].sum+tr[u*2+1].sum)%MOD;
}

void addtag(int x,int tg){
    tr[x].add+=tg;
    tr[x].add%=MOD;
    tr[x].sum+=tg*(tr[x].r-tr[x].l+1);
    tr[x].sum%=MOD;
}

void pushdown(int x){
    if(tr[x].add){
        addtag(x<<1,tr[x].add);
        addtag(x<<1|1,tr[x].add);
        tr[x].add=0;
    }
}

void build(int u,int l,int r){
    tr[u]={l,r,0,0};
    if(l==r){
        tr[u]={l,r,0,nw[r]};return ;
    }
    int mid=l+r>>1;
    build(u*2,l,mid);
    build(u*2+1,mid+1,r);
    pushup(u);//回溯时更新
}


void update(int x,int pl,int pr,int k){
    if(tr[x].l>=pl&&tr[x].r<=pr){
        addtag(x,k);
        return ;
    }
    if(tr[x].l>pr||tr[x].r<pl)return ;

    pushdown(x);
    int mid=tr[x].l+tr[x].r>>1;
    if(pl<=mid)update(x<<1,pl,pr,k);
    if(pr>mid)update(x<<1|1,pl,pr,k);

    pushup(x);
}

int query(int x,int pl,int pr){
    if(tr[x].l>=pl&&tr[x].r<=pr){
        return tr[x].sum;
    }


    if(tr[x].l>pr||tr[x].r<pl)return 0;

    pushdown(x);
    int res=0;
    int mid=tr[x].l+tr[x].r>>1;
    if(pl<=mid)res+=query(x<<1,pl,pr);
    res%=MOD;
    if(pr>mid)res+=query(x<<1|1,pl,pr);
    res%=MOD;

    return res;
}

//spou
int query_road(int u,int v){
    int res=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);//交换,使得u所在的重链top恒比v的深 ,即保证是更深的在往上跳而不是浅的一直在往上跳 
        res+=query(1,id[top[u]],id[u]);//注意id[top[u]]<=id[u]
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    res+=query(1,id[u],id[v]);
    res%=MOD;
    return res;
}

void updt_road(int u,int v,int k){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);//交换,使得u所在的重链top恒比v的深 ,即保证是更深的在往上跳而不是浅的一直在往上跳 
        update(1,id[top[u]],id[u],k);//注意id[top[u]]<=id[u]
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    update(1,id[u],id[v],k);
    return ;
}

//debug

void debug(){
//  cerr<<"\ndebug = ";
//  for(int i=1;i<=n;i++){
//      cerr<<query(1,id[i],id[i])<<' ';
//  }
//  cerr<<"debug end\n";
//  
//  cerr<<"node v=";
//  for(int i=1;i<=n;i++)cerr<<nw[id[i]]<<' ';
//  cerr<<"node v end\n";
}


signed main(){
    cin>>n>>m>>rt>>MOD;
    for(int i=1;i<=n;i++){
        cin>>w[i];
    }

    for(int i=1;i<n;i++){
        int a,b;
        cin>>a>>b;
        add(a,b);
    }

//  for(int i=1;i<=n;i++){
//      if(!du[i]){
//          rt=i;break;//找根 
//      }
//  }

    dfs1(rt,0);
    dfs2(rt,rt);

//  cerr<<"sz[]=";
//  for(int i=1;i<=n;i++)cerr<<sz[i]<<' ';
//  cerr<<"sz[] endl\n"; 

//  cerr<<"rt= "<<rt<<endl;

//  for(int i=1;i<=n;i++)cerr<<"node "<<i<<" 's id = "<<id[i]<<" tail = "<<tail[i]<<endl;


//  cerr<<"start building...\n";

    build(1,1,n);

//  cerr<<"build finish\ninit ";
    debug();

    while(m--){
        int op,x,y,z;
        cin>>op>>x;
        if(op==1){
            cin>>y>>z;
            updt_road(x,y,z);
        }if(op==2){
            cin>>y;
            cout<<query_road(x,y)%MOD<<endl;
        }if(op==3){
            cin>>z;
            update(1,id[x],tail[x],z);

        }if(op==4){
            cout<<query(1,id[x],tail[x])%MOD<<endl;
        }

        debug();
    }

    return 0;
}

实链

见 LCT


重链剖分

把树拆分成若干条重链,用线段树这种静态数据结构来维护重链。通过对重链的拆分与组合,构造答案

实链剖分

把树拆分成若干条实链,用 splay 这种动态数据结构来维护实链。通过对实链的拆分与组合,构造答案

构建

一个节点只能选(任意)一个儿子做实儿子,其他都是虚儿子。

实边:父节点与实儿子之间的边,是双向边。

虚边:由虚儿子指向父节点的边,是单向边。(认父不认子,下面会说及是一个splay的根节点指向另外一颗splay中的某个节点)

实链:由实边构成的链。每条实链的节点深度是严格递增的。


摘抄学习笔记 | 树剖(树链剖分)

练习 重链剖分

重链剖分部分

324 树上修改与查询 树链剖分_哔哩哔哩_bilibili

top[u]: 存 u 所在重链的顶点

id[u]: 存 u 剖分后的新编号

nw[cnt]: 存新编号在树中所对应节点的权值

在剖分之后我们将每个节点都一一映射一个新的编号。按照代码中dfs2的顺序,一条一条重链上的点深度由浅入深依次编号。

直观图,新的编号如图中的标红数字

image.png

代码

int sz[N],fa[N],dep[N],son[N],top[N];
int n,Q,cnt;
vector<int> e[N];

void add(int a,int b){
    e[a].push_back(b);
}

void dfs1(int u,int faa){//初始值1,0 
    fa[u]=faa;dep[u]=dep[faa]+1;sz[u]=1;//更新fa,dep,设置sz初始值 
    for(int i=0;i<e[u].size();i++){
        int v=e[u][i];
        if(v==faa)continue;
        dfs1(v,u);//注意不是dfs1(v,faa) 
        sz[u]+=sz[v];//回溯,sz[u]原本为1 ,加上已经处理了v的子树大小
        if(sz[v]>sz[son[u]])son[u]=v;//更新重儿子,son[u]存到是之前已经扫描过的u的儿子中最重的那个,把它的sz与v的sz比对 
    }   
    return ;
} 

void dfs2(int u,int t){
    top[u]=t;id[u]=++cnt;nw[cnt]=w[u];
    if(!son[u])return;//没有重儿子,那么说明u为叶子节点. 
    dfs2(son[u],t);//先走重儿子,继承u的top(即t) 
    for(int i=0;i<e[i].size();i++){//扫描轻儿子 
        int v=e[u][i];
        if(v==fa[u]||v==son[u])continue;//筛选 
        dfs2(v,v);//轻儿子不能继承t,只能开一条新的重链 
    }
    return ;
}

线段树部分

我们将原来的树映射为序列,然后使用线段树来维护这段序列

代码

struct tree{
    int l,r;//左右节点代表的区间
    ll add,sum;//标记与区间和
}tr[N<<2];

void pushup(int u){
    tr[u].sum=tr[u*2].sum+tr[u*2+1].sum;
}
void build(int u,int l,int r){
    tr[u]={l,r,0,nw[r]};
    if(l==r)return ;
    int mid=l+r;
    build(u*2,l,mid);
    build(u*2+1,mid+1,r);
    pushup(u);//回溯时更新
}

查询

目的:求树上从x到y最短路径上的节点和

举个例子

image.png

我们要查询树上节点10-12的路径权值和,首先使用类似LCA算法找到路径经过了那几条重链,分别是10-6-2,1-4-8,12

然后映射到线段树中,分别是区间[7,9][1,3][5]

我们使用线段树快速求出这些区间和即可,注意,其中的[7,9][5]都是一段完整的重链,我们在LCA时已经知道了其l,r,直接线段树即可,但[1,3]不是一段完整的重链。幸好,在LCA算法结束时,u=8,v=1,映射后就是3和1.因此在LCA结束后我们还需要把ans加上这一条不完整的链

int query_lca(int u,int v){
    ll res=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);//交换,使得u所在的重链top恒比v的深 ,即保证是更深的在往上跳而不是浅的一直在往上跳 
        res+=query(id[top[u]],id[u]);//注意id[top[u]]<=id[u]
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    res+=query(id[u],id[v])
    return res;
}

修改

目的:将树上从x到y的最短路径上每个节点都加上k

代码和上方query_lca基本上一模一样。注意这里省略了**线段树**相关代码如:

  • update()change()

  • query()

  • pushdown()懒标记

请自我回顾!

void updt_lca(int u,int v,int k){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);//交换,使得u所在的重链top恒比v的深 ,即保证是更深的在往上跳而不是浅的一直在往上跳 
        update(1,id[top[u]],id[u],k);//注意id[top[u]]<=id[u]
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    update(1,id[u],id[v],k);
    return ;
}

练习 #1 [NOI2015] 软件包管理器

题目背景

Linux 用户和 OSX 用户一定对软件包管理器不会陌生。通过软件包管理器,你可以通过一行命令安装某一个软件包,然后软件包管理器会帮助你从软件源下载软件包,同时自动解决所有的依赖(即下载安装这个软件包的安装所依赖的其它软件包),完成所有的配置。Debian/Ubuntu 使用的 apt-get,Fedora/CentOS 使用的 yum,以及 OSX 下可用的 homebrew 都是优秀的软件包管理器。

题目描述

你决定设计你自己的软件包管理器。不可避免地,你要解决软件包之间的依赖问题。如果软件包 \(a\) 依赖软件包 \(b\),那么安装软件包 \(a\) 以前,必须先安装软件包 \(b\)。同时,如果想要卸载软件包 \(b\),则必须卸载软件包 \(a\)

现在你已经获得了所有的软件包之间的依赖关系。而且,由于你之前的工作,除 \(0\) 号软件包以外,在你的管理器当中的软件包都会依赖一个且仅一个软件包,而 \(0\) 号软件包不依赖任何一个软件包。且依赖关系不存在环(即不会存在 \(m\) 个软件包 \(a_1,a_2, \dots , a_m\),对于 \(i<m\)\(a_i\) 依赖 \(a_{i+1}\),而 \(a_m\) 依赖 \(a_1\) 的情况)。

现在你要为你的软件包管理器写一个依赖解决程序。根据反馈,用户希望在安装和卸载某个软件包时,快速地知道这个操作实际上会改变多少个软件包的安装状态(即安装操作会安装多少个未安装的软件包,或卸载操作会卸载多少个已安装的软件包),你的任务就是实现这个部分。

注意,安装一个已安装的软件包,或卸载一个未安装的软件包,都不会改变任何软件包的安装状态,即在此情况下,改变安装状态的软件包数为 \(0\)


思路

考虑暴力,很显然是一棵树,并且每个软件包的依赖就是其父亲。那么首先我们知道0号软件是树根。我们先按照依赖建树,给每个节点标记vis为是否安装。起初除0以外的点的vis都是0。

  • 如果我们要安装一个软件v,那么答案就是从v到0的路径上vis为0的点数。

  • 如果我们要卸载一个软件v,那么答案就是v的子树中所有vis为1的点数。

并且在操作完成后,记得更新被修改点的状态。

首先我们在每个点维护子树中被选中的点数cnt。那么对于安装我们可以O(\log n)解决,即在往上走的过程中维护vis和cnt和ans。

问题在于对于卸载操作,我们需要把一条链上的cnt同时减去v的子树大小,并且把v的子树内节点的信息全部置0。前面的简单,但是后面的最劣是O(n)的。

对于维护答案,安装操作的答案就是路径上0的个数,卸载操作的答案就是子树中1的个数。

好了,到此为止,很像树剖的板子。但是树剖是区间加,这里是区间修改,所以还需要一些修改。

树剖就是把树划分成若干个区间然后用类似线段树维护。那么本题就是线段树的区间赋值区间查询问题,小小修改一下addtag函数即可。

注意点编号的偏移!


acwing:TLE

luogu:AC

/*////////ACACACACACACAC///////////
       . Code  by  Ntsc .
       . Earn knowledge .
/*////////ACACACACACACAC///////////

#include<bits/stdc++.h>
#define int long long
#define db double
#define rtn return
using namespace std;

const int N=5e5+5;
const int M=1e5;
const int INF=1e5;

int n,m,p,q,r,T,s[N],ans;

int sz[N],fa[N],dep[N],son[N],top[N];
int w[N],nw[N],tail[N],id[N],Q,cnt;
int du[N],rt;
vector<int> e[N];

void add(int a,int b){
    e[a].push_back(b);
    e[b].push_back(a);//记得双向边 
    du[b]++;
}

void dfs1(int u,int faa){//初始值1,0 
    fa[u]=faa;dep[u]=dep[faa]+1;sz[u]=1;//更新fa,dep,设置sz初始值 
    for(int i=0;i<e[u].size();i++){
        int v=e[u][i];
        if(v==faa)continue;
        dfs1(v,u);//注意不是dfs1(v,faa) 
        sz[u]+=sz[v];//回溯,sz[u]原本为1 ,加上已经处理了v的子树大小
        if(sz[v]>sz[son[u]])son[u]=v;//更新重儿子,son[u]存到是之前已经扫描过的u的儿子中最重的那个,把它的sz与v的sz比对 
    }   
    return ;
} 

void dfs2(int u,int t){
    top[u]=t;id[u]=++cnt;nw[cnt]=w[u];
    tail[u]=id[u];
    if(!son[u])return;//没有重儿子,那么说明u为叶子节点. 
    dfs2(son[u],t);//先走重儿子,继承u的top(即t) 
    tail[u]=max(id[u],tail[son[u]]);
    for(int i=0;i<e[u].size();i++){//扫描轻儿子 
        int v=e[u][i];
        if(v==fa[u]||v==son[u])continue;//筛选 
        dfs2(v,v);//轻儿子不能继承t,只能开一条新的重链 
        tail[u]=max(id[u],tail[v]);//记录子树区间 
    }
    return ;
}

//SGT
struct tree{
    int l,r;//左右节点代表的区间
    int add,sum;//标记与区间和
}tr[N<<2];

void pushup(int u){
    tr[u].sum=(tr[u*2].sum+tr[u*2+1].sum);
}

void addtag(int x,int tg){
    tr[x].add=tg;
    tr[x].sum=tg*(tr[x].r-tr[x].l+1);
}

void pushdown(int x){
    if(-1!=tr[x].add){
        addtag(x<<1,tr[x].add);
        addtag(x<<1|1,tr[x].add);
        tr[x].add=-1;
    }
}

void build(int u,int l,int r){
    // if(l>r)return ;
    cerr<<u<<' ';
    tr[u]={l,r,-1,0};
    if(l==r){
        // tr[u]={l,r,-1,nw[r]};
        return ;
    }
    int mid=l+r>>1;
    build(u*2,l,mid);
    build(u*2+1,mid+1,r);
    pushup(u);//回溯时更新
}


void update(int x,int pl,int pr,int k){
    if(tr[x].l>=pl&&tr[x].r<=pr){
        addtag(x,k);
        return ;
    }
    if(tr[x].l>pr||tr[x].r<pl)return ;

    pushdown(x);
    int mid=tr[x].l+tr[x].r>>1;
    if(pl<=mid)update(x<<1,pl,pr,k);
    if(pr>mid)update(x<<1|1,pl,pr,k);

    pushup(x);
}

int query(int x,int pl,int pr){
    if(tr[x].l>=pl&&tr[x].r<=pr){
        return tr[x].sum;
    }


    if(tr[x].l>pr||tr[x].r<pl)return 0;

    pushdown(x);
    int res=0;
    int mid=tr[x].l+tr[x].r>>1;
    if(pl<=mid)res+=query(x<<1,pl,pr);
    if(pr>mid)res+=query(x<<1|1,pl,pr);

    return res;
}

//spou
int query_road(int u,int v){
    int res=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);//交换,使得u所在的重链top恒比v的深 ,即保证是更深的在往上跳而不是浅的一直在往上跳 
        res+=query(1,id[top[u]],id[u]);//注意id[top[u]]<=id[u]
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    res+=query(1,id[u],id[v]);
    return res;
}

void updt_road(int u,int v,int k){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);//交换,使得u所在的重链top恒比v的深 ,即保证是更深的在往上跳而不是浅的一直在往上跳 
        update(1,id[top[u]],id[u],k);//注意id[top[u]]<=id[u]
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    update(1,id[u],id[v],k);
    return ;
}


signed main(){

    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);//防止tle

    cin>>n;
    rt=1;


    for(int i=1;i<n;i++){
        int a;
        cin>>a;
        add(a+1,i+1);
    }

    dfs1(rt,0);
    dfs2(rt,rt);

    build(1,1,n);

    cin>>m;

    while(m--){
        int x,y,z;
        string op;
        cin>>op>>x;
        x++;//注意点的偏移!!
        if(op=="install"){
            cout<<dep[x]-query_road(1,x)<<endl;
            updt_road(1,x,1);
        }else{
            cout<<query(1,id[x],tail[x])<<endl;
            update(1,id[x],tail[x],0);
        }

    }

    return 0;
}

https://cdn.luogu.com.cn/upload/pic/1505.png

练习 #2【模板】重链剖分/树链剖分

题目描述

如题,已知一棵包含 \(N\) 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

  • 1 x y z,表示将树从 \(x\)\(y\) 结点最短路径上所有节点的值都加上 \(z\)

  • 2 x y,表示求树从 \(x\)\(y\) 结点最短路径上所有节点的值之和。

  • 3 x z,表示将以 \(x\) 为根节点的子树内所有节点值都加上 \(z\)

  • 4 x 表示求以 \(x\) 为根节点的子树内所有节点值之和


代码

/*////////ACACACACACACAC///////////
       . Code  by  Ntsc .
       . Earn knowledge .
/*////////ACACACACACACAC///////////

#include<bits/stdc++.h>
#define int long long
#define db double
#define rtn return
using namespace std;

const int N=1e5+5;
const int M=1e5;
int MOD;
const int INF=1e5;

int n,m,p,q,r,T,s[N],ans;

int sz[N],fa[N],dep[N],son[N],top[N];
int w[N],nw[N],tail[N],id[N],Q,cnt;
int du[N],rt;
vector<int> e[N];

void add(int a,int b){
    e[a].push_back(b);
    e[b].push_back(a);//记得双向边 
    du[b]++;
}

void dfs1(int u,int faa){//初始值1,0 
    fa[u]=faa;dep[u]=dep[faa]+1;sz[u]=1;//更新fa,dep,设置sz初始值 
    for(int i=0;i<e[u].size();i++){
        int v=e[u][i];
        if(v==faa)continue;
        dfs1(v,u);//注意不是dfs1(v,faa) 
        sz[u]+=sz[v];//回溯,sz[u]原本为1 ,加上已经处理了v的子树大小
        if(sz[v]>sz[son[u]])son[u]=v;//更新重儿子,son[u]存到是之前已经扫描过的u的儿子中最重的那个,把它的sz与v的sz比对 
    }   
    return ;
} 

void dfs2(int u,int t){
    top[u]=t;id[u]=++cnt;nw[cnt]=w[u];
    tail[u]=id[u];
    if(!son[u])return;//没有重儿子,那么说明u为叶子节点. 
    dfs2(son[u],t);//先走重儿子,继承u的top(即t) 
    tail[u]=max(id[u],tail[son[u]]);
    for(int i=0;i<e[u].size();i++){//扫描轻儿子 
        int v=e[u][i];
        if(v==fa[u]||v==son[u])continue;//筛选 
        dfs2(v,v);//轻儿子不能继承t,只能开一条新的重链 
        tail[u]=max(id[u],tail[v]);//记录子树区间 
    }
    return ;
}

//SGT
struct tree{
    int l,r;//左右节点代表的区间
    int add,sum;//标记与区间和
}tr[N<<2];

void pushup(int u){
    tr[u].sum=(tr[u*2].sum+tr[u*2+1].sum)%MOD;
}

void addtag(int x,int tg){
    tr[x].add+=tg;
    tr[x].add%=MOD;
    tr[x].sum+=tg*(tr[x].r-tr[x].l+1);
    tr[x].sum%=MOD;
}

void pushdown(int x){
    if(tr[x].add){
        addtag(x<<1,tr[x].add);
        addtag(x<<1|1,tr[x].add);
        tr[x].add=0;
    }
}

void build(int u,int l,int r){
    tr[u]={l,r,0,0};
    if(l==r){
        tr[u]={l,r,0,nw[r]};return ;
    }
    int mid=l+r>>1;
    build(u*2,l,mid);
    build(u*2+1,mid+1,r);
    pushup(u);//回溯时更新
}


void update(int x,int pl,int pr,int k){
    if(tr[x].l>=pl&&tr[x].r<=pr){
        addtag(x,k);
        return ;
    }
    if(tr[x].l>pr||tr[x].r<pl)return ;

    pushdown(x);
    int mid=tr[x].l+tr[x].r>>1;
    if(pl<=mid)update(x<<1,pl,pr,k);
    if(pr>mid)update(x<<1|1,pl,pr,k);

    pushup(x);
}

int query(int x,int pl,int pr){
    if(tr[x].l>=pl&&tr[x].r<=pr){
        return tr[x].sum;
    }


    if(tr[x].l>pr||tr[x].r<pl)return 0;

    pushdown(x);
    int res=0;
    int mid=tr[x].l+tr[x].r>>1;
    if(pl<=mid)res+=query(x<<1,pl,pr);
    res%=MOD;
    if(pr>mid)res+=query(x<<1|1,pl,pr);
    res%=MOD;

    return res;
}

//spou
int query_road(int u,int v){
    int res=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);//交换,使得u所在的重链top恒比v的深 ,即保证是更深的在往上跳而不是浅的一直在往上跳 【注意不是比较dep[u],dep[v]】
        res+=query(1,id[top[u]],id[u]);//注意id[top[u]]<=id[u]
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    res+=query(1,id[u],id[v]);
    res%=MOD;
    return res;
}

void updt_road(int u,int v,int k){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);//交换,使得u所在的重链top恒比v的深 ,即保证是更深的在往上跳而不是浅的一直在往上跳 
        update(1,id[top[u]],id[u],k);//注意id[top[u]]<=id[u]
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    update(1,id[u],id[v],k);
    return ;
}

//debug

void debug(){
//  cerr<<"\ndebug = ";
//  for(int i=1;i<=n;i++){
//      cerr<<query(1,id[i],id[i])<<' ';
//  }
//  cerr<<"debug end\n";
//  
//  cerr<<"node v=";
//  for(int i=1;i<=n;i++)cerr<<nw[id[i]]<<' ';
//  cerr<<"node v end\n";
}


signed main(){
    cin>>n>>m>>rt>>MOD;
    for(int i=1;i<=n;i++){
        cin>>w[i];
    }

    for(int i=1;i<n;i++){
        int a,b;
        cin>>a>>b;
        add(a,b);
    }

//  for(int i=1;i<=n;i++){
//      if(!du[i]){
//          rt=i;break;//找根 
//      }
//  }

    dfs1(rt,0);
    dfs2(rt,rt);

//  cerr<<"sz[]=";
//  for(int i=1;i<=n;i++)cerr<<sz[i]<<' ';
//  cerr<<"sz[] endl\n"; 

//  cerr<<"rt= "<<rt<<endl;

//  for(int i=1;i<=n;i++)cerr<<"node "<<i<<" 's id = "<<id[i]<<" tail = "<<tail[i]<<endl;


//  cerr<<"start building...\n";

    build(1,1,n);

//  cerr<<"build finish\ninit ";
    debug();

    while(m--){
        int op,x,y,z;
        cin>>op>>x;
        if(op==1){
            cin>>y>>z;
            updt_road(x,y,z);
        }if(op==2){
            cin>>y;
            cout<<query_road(x,y)%MOD<<endl;
        }if(op==3){
            cin>>z;
            update(1,id[x],tail[x],z);

        }if(op==4){
            cout<<query(1,id[x],tail[x])%MOD<<endl;
        }

        debug();
    }

    return 0;
}

对于 \(100\%\) 的数据: \(1\le N \leq {10}^5\)\(1\le M \leq {10}^5\)\(1\le R\le N\)\(1\le P \le 2^{31}-1\)。所有输入的数均在 int 范围内。

练习 #3 [ZJOI2008] 树的统计

题目描述

一棵树上有 \(n\) 个节点,编号分别为 \(1\)\(n\),每个节点都有一个权值 \(w\)

我们将以下面的形式来要求你对这棵树完成一些操作:

I. CHANGE u t : 把结点 \(u\) 的权值改为 \(t\)

II. QMAX u v: 询问从点 \(u\) 到点 \(v\) 的路径上的节点的最大权值。

III. QSUM u v: 询问从点 \(u\) 到点 \(v\) 的路径上的节点的权值和。

注意:从点 \(u\) 到点 \(v\) 的路径上的节点包括 \(u\)\(v\) 本身。


思路

裸的模板耶!

如果没有操作II,那么其实我们可以用欧拉序解决这个问题。欧拉序的做法如下:

参考COT2 - Count on a tree II 学习笔记 | 莫队

大概就是先用线段树维护欧拉序,对于每个点

其权值为v,就在第一个欧拉序处将其分治为v,第二个则赋值为-v。这样求x-y路径和就是求x入到y入的区间和

/*
CB Ntsc111
*/

#include <bits/stdc++.h>
using namespace std;

#define ull unsigned int
#define pii pair<int, int>
#define pf to
#define ps second
#define pb push_back
#define int long long

#define err cerr << "Error"
#define rd read()

#define ot write
#define nl putchar('\n')
int read() {
  int xx = 0, ff = 1;
  char ch = getchar();
  while (ch < '0' || ch > '9') {
    if (ch == '-')
      ff = -1;
    ch = getchar();
  }
  while (ch >= '0' && ch <= '9')
    xx = xx * 10 + (ch - '0'), ch = getchar();
  return xx * ff;
}
void write(int out) {
  if (out < 0)
    putchar('-'), out = -out;
  if (out > 9)
    write(out / 10);
  putchar(out % 10 + '0');
}

const int INF = 1e9;
const int N = 2000005;

#define N 100005
#define inf 1000000000
using namespace std;
int n, q, a[4 * N];
struct Edge {
  int u, v, next;
} G[N];
int tot = 0, head[N];
int sz[100005], wson[100005], fa[100005], d[100005], top[100005];
int tpos[100005], pre[100005], cnt = 0;
inline void add(int u, int v) {
  G[++tot].u = u;
  G[tot].v = v;
  G[tot].next = head[u];
  head[u] = tot;
  G[++tot].u = v;
  G[tot].v = u;
  G[tot].next = head[v];
  head[v] = tot;
}
void dfs1(int u, int f) {
  sz[u] = 1;
  for (int i = head[u]; i; i = G[i].next) {
    int v = G[i].v;
    if (v == f)
      continue;
    d[v] = d[u] + 1;
    fa[v] = u;
    dfs1(v, u);
    sz[u] += sz[v];
    if (sz[v] > sz[wson[u]])
      wson[u] = v;
  }
}
void dfs2(int u, int TP) {
  tpos[u] = ++cnt;
  pre[cnt] = u;
  top[u] = TP;
  if (wson[u])
    dfs2(wson[u], TP);
  for (int i = head[u]; i; i = G[i].next) {
    int v = G[i].v;
    if (v == fa[u] || v == wson[u])
      continue;
    dfs2(v, v);
  }
}
int sum[4 * N], mx[4 * N];
inline void pushup(int x) {
  sum[x] = sum[x * 2] + sum[x * 2 + 1];
  mx[x] = max(mx[x * 2], mx[x * 2 + 1]);
}
void build(int x, int l, int r) {
  int mid = (l + r) / 2;
  if (l == r) {
    sum[x] = mx[x] = a[pre[l]];
    return;
  }
  build(x * 2, l, mid);
  build(x * 2 + 1, mid + 1, r);
  pushup(x);
}
void update(int x, int l, int r, int q, int v) {
  int mid = (l + r) / 2;
  if (l == r) {
    sum[x] = mx[x] = v;
    return;
  }
  if (q <= mid)
    update(x * 2, l, mid, q, v);
  else
    update(x * 2 + 1, mid + 1, r, q, v);
  pushup(x);
}
int querysum(int x, int l, int r, int ql, int qr) {
  int mid = (l + r) / 2, ans = 0;
  if (ql <= l && r <= qr)
    return sum[x];
  if (ql <= mid)
    ans += querysum(x * 2, l, mid, ql, qr);
  if (qr > mid)
    ans += querysum(x * 2 + 1, mid + 1, r, ql, qr);
  pushup(x);
  return ans;
}
int querymax(int x, int l, int r, int ql, int qr) {
  int mid = (l + r) / 2, ans = -inf;
  if (ql <= l && r <= qr)
    return mx[x];
  if (ql <= mid)
    ans = max(ans, querymax(x * 2, l, mid, ql, qr));
  if (qr > mid)
    ans = max(ans, querymax(x * 2 + 1, mid + 1, r, ql, qr));
  pushup(x);
  return ans;
}
int Querysum(int u, int v) {
  int ans = 0;
  while (top[u] != top[v]) {
    if (d[top[u]] < d[top[v]])
      swap(u, v);
    ans += querysum(1, 1, n, tpos[top[u]], tpos[u]);
    u = fa[top[u]];
  }
  if (d[u] < d[v])
    swap(u, v);
  ans += querysum(1, 1, n, tpos[v], tpos[u]);
  return ans;
}
int Querymax(int u, int v) {
  int ans = -inf;
  while (top[u] != top[v]) {
    if (d[top[u]] < d[top[v]])
      swap(u, v);
    ans = max(ans, querymax(1, 1, n, tpos[top[u]], tpos[u]));
    u = fa[top[u]];
  }
  if (d[u] < d[v])
    swap(u, v);
  ans = max(ans, querymax(1, 1, n, tpos[v], tpos[u]));
  return ans;
}
signed main() {
  n = rd;
  for (int i = 1; i < n; i++) {
    int u = rd, v = rd;
    add(u, v);
  }
  for (int i = 1; i <= n; i++)
    a[i] = rd;
  d[1] = 1;
  fa[1] = 1;
  dfs1(1, -1);
  dfs2(1, 1);
  build(1, 1, n);
  q = rd;
  while (q--) {
    char s[10];
    scanf("%s", s);
    int x = rd, y = rd;
    if (s[1] == 'H')
      update(1, 1, n, tpos[x], y);
    if (s[1] == 'M')
      printf("%lld\n", Querymax(x, y));
    if (s[1] == 'S')
      printf("%lld\n", Querysum(x, y));
  }
  return 0;
}

对于 \(100 \%\) 的数据,保证 \(1\le n \le 3\times 10^4\)\(0\le q\le 2\times 10^5\)

中途操作中保证每个节点的权值 \(w\)\(-3\times 10^4\)\(3\times 10^4\) 之间。


长链

请先回顾重链剖分

image.png

那么什么是长链剖分呢? 对应地,我们定义:

  • 长儿子:父结点的所有儿子中子树最大深度最深的结点(如果有多个则只选择第一个)

  • 链顶为这条链深度最大的节点(?

  • 链底为这条链深度最小的节点(?

  • 其它定义类似或相同。

那么既然我们要求k级祖先,那么首先我们可以想到的较优的算法是倍增法。或者我们牺牲空间复杂度,来做到O(1)查询。

例题 #1 树上 K 级祖先

给定一棵 \(n\) 个点的有根树。

\(q\) 次询问,第 \(i\) 次询问给定 \(x_i, k_i\),要求点 \(x_i\)\(k_i\) 级祖先,答案为 \(ans_i\)。特别地,\(ans_0 = 0\)

本题中的询问将在程序内生成。

给定一个随机种子 \(s\) 和一个随机函数 \(\operatorname{get}(x)\)

#define ui unsigned int
ui s;

inline ui get(ui x) {
    x ^= x << 13;
    x ^= x >> 17;
    x ^= x << 5;
    return s = x; 
}

你需要按顺序依次生成询问。

\(d_i\) 为点 \(i\) 的深度,其中根的深度为 \(1\)

对于第 \(i\) 次询问,\(x_i = ((\operatorname{get}(s) \operatorname{xor} ans_{i-1}) \bmod n) + 1\)\(k_i = (\operatorname{get}(s) \operatorname{xor} ans_{i-1}) \bmod d_{x_i}\)

对于 \(100\%\) 的数据,\(2 \le n \le 5 \times 10^5\)\(1 \le q \le 5 \times 10^6\)\(1 \le s < 2^{32}\)

长链剖分解决

首先我们知道一个性质:任意一个点的k级祖先所在链的链长一定大于等于k。结合长链的性质易证。

我们先倍增出每个节点的\(2^t\)级祖先。

然后,我们对于每条链处理。如果链长是\(len\),那么在链头处记录链头向上的\(len\)个祖先,并记录向下的\(len\)个在链内的节点。

假设我要查询\(u\)\(k\)级祖先\(v\),那么我先跳到其\(2^{⌊log_2​k⌋}\)级祖先\(v'\)处。那么我们现在还需要往上跳\(k'=k-2^{⌊log_2​k⌋}\)层。

由性质得,\(u\)\(2^{⌊log_2​k⌋}\)级祖先(即\(v'\))所在链的链长\(len\)一定大于等于\(2^{⌊log_2​k⌋}\),那么现在我们已经预处理出了\(v'\)所在的链的链头\(t\)上下的\(len\)个节点。

我们又知道\(k'<2^{⌊log_2​k⌋}\),所以我们就已经处理出了\(v'\)上第\(k'\)个节点(我们记\(l=dep(t)-dep(v')\),如果\(l<k'\),那么我们访问\(t\)向上的\(k'-l\)个节点,否则我们就访问\(t\)向下\(l-k'\)个节点),直接访问即可。

部分数组解释

int dep[N], fa[N][22];//点的深度,父亲
int mxdp[N];//点所在子树的最大深度
int son[N],top[N];//点的长儿子,点所在链的链头
vector<int> up[N],down[N];//每个点向上/下第i点的编号

代码

/*////////ACACACACACACAC///////////
       . Coding by Ntsc .
       . Prove Yourself .
/*////////ACACACACACACAC///////////

#include <bits/stdc++.h>
#define ll long long
#define db double
#define rtn return
#define pb push_back
using namespace std;

const int N = 5e5 + 5;
const int M = 1e5;
const int Mod = 1e5;
const int INF = 1e5;
vector<int> e[N];
int n, m, sum, ans1, dep[N], fa[N][22], k, tmp, ans2;

int mxdp[N],son[N],top[N],lg[N],q,lstans;
ll ans;
vector<int> up[N],down[N];

int rt;
#define ui unsigned int
ui s;

inline ui get(ui x) {
    x ^= x << 13;
    x ^= x >> 17;
    x ^= x << 5;
    return s = x; 
}


void add(int a, int b) { e[a].push_back(b); }

void dfs1(int u) {  //求深度
    for (int i = 1;i<20; i++) fa[u][i] = fa[fa[u][i - 1]][i - 1];  // att!
    mxdp[u]=dep[u] = dep[fa[u][0]] + 1;

    for (auto v:e[u]) {
        fa[v][0] = u;
        dfs1(v);
        if(mxdp[u]<mxdp[v]){
            son[u]=v;//更新长儿子 
            mxdp[u]=mxdp[v];
        }

    }
}
void dfs2(int u, int tfa) {  //求down,up,top


    top[u]=tfa;

    if(u==tfa){//链头更新 
        for(int i=0,v=u;i<=mxdp[u]-dep[u];i++){
            up[u].pb(v);v=fa[v][0];
        }for(int i=0,v=u;i<=mxdp[u]-dep[u];i++){
            down[u].pb(v);v=son[v];
        }
    }

    if(son[u])dfs2(son[u],tfa);//同链继承 

    for (auto v:e[u]) {
        if (v == son[u])
            continue;
        dfs2(v, v);//异链重开 

    }
}
int query(int x,int k){
    if(!k)return x;

    x=fa[x][lg[k]];
    k-=1<<lg[k];
    k-=dep[x]-dep[top[x]];x=top[x];

    if(k>=0){
        return up[x][k];
    }else{
        return down[x][-k];
    }
}

signed main() {
    scanf("%d%d%d", &n, &q,&s);  
    lg[1]=0;
    for(int i=2;i<=n;i++){
        lg[i]=lg[i>>1]+1;
    }

    for(int i=1;i<=n;i++){
        int f;cin>>f;
        if(!f)rt=i;//注意根不是1 
        else add(f,i);
    }
    dfs1(rt);

//    cerr<<"OK\n";
    dfs2(rt,rt);

    for (int i = 1; i <= q; i++) {
//      cerr<<"run at="<<i<<endl;
        int x=((get(s)^lstans)%n)+1;
        int k=(get(s)^lstans)%dep[x];
        lstans=query(x,k);
        ans^=1ll*i*lstans;

    }

    cout<<ans<<endl;
    return 0;
}

例题 #2 [Cnoi2019] 雪松果树

雪松果树是一个以 \(1\) 为根有着 \(N\) 个节点的树。

除此之外, Cirno还有 \(Q\) 个询问,每个询问是一个二元组 \((u,k)\) ,表示询问 \(u\) 节点的 \(k\)-cousin 有多少个。

我们定义:

节点 \(u\)\(1\)-father 为 路径 \((1, u)\) (不含 u)上距 u 最近的节点

节点 \(u\)\(k\)-father 为 节点 「\(u\)\((k-1)\)-father」 的 1-father

节点 \(u\)\(k\)-son 为所有 \(k\)-father 为 \(u\) 的节点

节点 \(u\)\(k\)-cousin 为 节点「 \(u\)\(k\)-father」的 \(k\)-son (不包含 \(u\) 本身)

输入格式

第一行,两个整数 \(N\), \(Q\)

第二行, \(N-1\) 个整数,第 \(i\) 个表示 \(i+1\) 号节点的 1-father

以下 \(Q\) 行,每行一个二元组\((u,k)\)

输出格式

一行,\(Q\) 个数,每一个表示一个询问的答案。若 u 不存在 k-father,输出 0。

对于 70%-100% 的数据 \(N,Q \le 10^6\)

另外存在一组记 \(20\) 分的 hack 数据。


本题使用dsu on tree(重链)可以写,但是会被卡常

因此我们考虑时空复杂度更加优秀的长链剖分。

还是一样的,我们用一个数组f_i表示比当前点深i的子节点的信息。为什么不是记录当前子树内dep为i的节点的信息呢?下面再说。

那么我们使用长链剖分后,我们就可以一样把信息从我们的长儿子那里继承过来。

但是这样可以优化时间复杂度吗?

/*
                      Keyblinds Guide
                    ###################
      @Ntsc 2024

      - Ctrl+Alt+getId then P : Enter luogu problem details
      - Ctrl+Alt+B : Run all cases in CPH
      - ctrl+D : choose this and dump to the nxt
      - ctrl+Shift+L : choose all like this
      - ctrl+K then ctrl+W: close all
      - Alt+la/ra : move mouse to pre/nxt pos'

*/
#include <bits/stdc++.h>
#include <queue>
using namespace std;

#define rep(i, l, r) for (int i = l, END##i = r; i <= END##i; ++i)
#define per(i, r, l) for (int i = r, END##i = l; i >= END##i; --i)
#define pb push_back
// #define int long long
#define ull unsigned long long
#define pii pair<int, int>
#define ps second
#define pf first
#define mp make_pair

// #define innt int
#define itn int
// #define inr intw
// #define mian main
// #define iont int

#define rd read()
int read() {
    int xx = 0, ff = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-')
            ff = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
        xx = xx * 10 + (ch - '0'), ch = getchar();
    return xx * ff;
}
void write(int out) {
    if (out < 0)
        putchar('-'), out = -out;
    if (out > 9)
        write(out / 10);
    putchar(out % 10 + '0');
}

#define ell dbg('\n')
const char el='\n';
const bool enable_dbg = 1;
template <typename T,typename... Args>
void dbg(T s,Args... args) {
    if constexpr (enable_dbg) {
        cerr << s;
        if(1)cerr<<' ';
        if constexpr (sizeof...(Args))
            dbg(args...);
    }
}

#define zerol = 1
#ifdef zerol
#define cdbg(x...) do { cerr << #x << " -> "; err(x); } while (0)
void err() {
    cerr << endl;
}
template<template<typename...> class T, typename t, typename... A>
void err(T<t> a, A... x) {
    for (auto v: a) cerr << v << ' ';
    err(x...);
}
template<typename T, typename... A>
void err(T a, A... x) {
    cerr << a << ' ';
    err(x...);
}
#else
#define dbg(...)
#endif



const int N = 1e6 + 10;
const int INF = 2e17;
const int M = 1e3 + 10;


/*

*/
int n,m,fa[N],len[N],son[N],u[N],k[N],anc[N],ans[N],tmp[N<<1],*f[N],*id=tmp,sta[N],top;
int head[N],tot,h[N],cnt;

struct node
{
    int to,nxt,id;
}e[N],q[N];


inline void addedge(int x,int y)
{
    e[++tot].to=y;
    e[tot].nxt=head[x];
    head[x]=tot;
}
inline void addque(int x,int y,int z)
{
    q[++cnt].to=y;
    q[cnt].id=z;
    q[cnt].nxt=h[x];
    h[x]=cnt;
}
void dfs1(int x)
{

    /*
    使用dfs栈来离线求出每个节点的k级首先询问
    O(N)
    */
    sta[++top]=x;
    vector<pair<int,int> >::iterator it;
    for(int i=h[x];i;i=q[i].nxt)
        if(top>q[i].to) anc[q[i].id]=sta[top-q[i].to];
    for(int i=head[x];i;i=e[i].nxt) dfs1(e[i].to);
    top--;
}
void dfs2(int x)
{
    /*
    使用长链剖分+dp的形式
    x的dp可以直接从长儿子里继承,其他的包i里合并
    因为一条路径上最多只有一条短链,
    所以合并的时间复杂度加起来是O(N)的

    */
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;dfs2(y);
        if(len[y]>len[son[x]]) son[x]=y;
    }
    len[x]=len[son[x]]+1;
}
void dfs3(int x)
{
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(y==son[x]) continue;
        f[y]=id;id+=len[y];
        //因为不能合并,所以另外分配空间
        dfs3(y);
    }
    if(son[x]) f[son[x]]=f[x]+1,dfs3(son[x]);//长儿子的dp数组我们要继承,我们我们就人指针赋值为后移一idx

    /*
    因为我们的dp数组记录的是地址,并且长儿子就在x的下方,dp的下标代表深度
    因此只需要直接偏移即可。


    */
    f[x][0]=1;

    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(y==son[x]) continue;
        /*
        对于短儿子,我们直接暴力合并即可。

        */
        for(int j=1;j<=len[y];j++) f[x][j]+=f[y][j-1];
    }
    for(int i=h[x];i;i=q[i].nxt) ans[q[i].id]=f[x][q[i].to]-1;
}

void solve()
{
    n=rd,m=rd;
    for(int i=2;i<=n;i++) fa[i]=read(),addedge(fa[i],i);
    for(int i=1;i<=m;i++) u[i]=read(),k[i]=read(),addque(u[i],k[i],i);
    dfs1(1);

    memset(h,0,sizeof(h));
    cnt=0;
    for(int i=1;i<=m;i++)
        if(anc[i]) addque(anc[i],k[i],i);

    dfs2(1);

    f[1]=id; //给长儿子分配最前面的一段dp数组
    id+=len[1]+1; //id论据下一个链(短链)分配的dp数组的begin指针

    dfs3(1);
    for(int i=1;i<=m;i++) cout<<ans[i]<<' ';
    puts("");

}


signed main() {
    // freopen("kujou.in","r",stdin);
    // freopen("kujou.out","w",stdout);

    int T=1;
    while(T--){
        solve();
    }

    return 0;
}

例题 #3 [POI2014] HOT-Hotels

给定一棵树,在树上选 \(3\) 个点,要求两两距离相等,求方案数。

\(n≤10^5\)


#include<bits/stdc++.h>
using namespace std;
#define int long long
#define itn int 
#define pb push_back
#define rd read()
#define cdbg(x) cerr<<#x<<" : "<<x<<endl;
inline int read(){
    int x;
    cin>>x;
    return x;
}

/*
策略


考虑n^2的dp
f_{i,j}表示到i距离为j的点的数量
 合法的答案应该是一个三叉口形式,即除了f还有一条链
g_{i,j}表示满足a->l,b->l距离为d时i->l距离为d+j的(a,b)数量

答案可以哦组合得到


*/
const int INF=2e9;
const int N=5e4+5;

vector<int> e[N];

void add(int a,int b){
    e[a].pb(b);
    e[b].pb(a);
}

int *f[N],*g[N];
itn p[N<<2];
int *o=p;
int ans;
int len[N];
int son[N];

int n;

void dfs(int x,int fa){
    // dep[x]=dep[fa]+1;
    for(auto v:e[x]){
        if(v==fa)continue;
        dfs(v,x);
        if(len[v]>len[son[x]])son[x]=v;
    }
    len[x]=len[son[x]]+1;
}

void dfs2(int x,int fa){
    if(son[x]){
        f[son[x]]=f[x]+1;
        g[son[x]]=g[x]-1;
        dfs2(son[x],x);
    }

    f[x][0]=1;
    ans+=g[x][0];

    for(auto v:e[x]){
        if(v==fa||v==son[x])continue;
        f[v]=o;
        o+=len[v]*2;
        g[v]=o;
        o+=len[v]*2;

        dfs2(v,x);
        for(int i=0;i<=len[v];i++){
            if(i)ans+=f[x][i-1]*g[v][i];
            ans+=g[x][i+1]*f[v][i];
        }

        for(int i=0;i<len[v];i++){
            g[x][i+1]+=f[x][i+1]*f[v][i];
            if(i)g[x][i-1]+=g[v][i];
            f[x][i+1]+=f[v][i];
        }

    }
}

signed main(){
     n=rd;
    for(int i=1;i<n;i++){
        add(rd,rd);
    }


    dfs(1,0);
    f[1]=o;
    o+=len[1]*2;

    g[1]=o;
    o+=len[1]*2;

    dfs2(1,0);

    // cdbg(g[6][2]);
    // cdbg(g[5][1]);

    cout<<ans<<endl;


    return 0;
}

题单

树剖题单