fogflea
fogflea
发布于 2025-07-18 / 10 阅读
0
0

[ZJOI2019] 语言

From zxy的思维技巧 而来。

传送门

这真是一道套路的好题啊。

首先要计算二元组 (a,b) 的个数,看数据范围肯定枚举一个 a ,然后计算所有满足的 b ,最后求和后除以 2

然后具体的考虑怎么计算所谓 b 的个数,考虑一个点 u ,这实际上是一个树上连通块的大小,而连通块的端点就是由其中经过 u 的给出路径的端点产生,然后又是一个套路,树上连通块的大小,边权和之类的东西都可以用端点的 \texttt{dfs} 序刻画,比如这道题里,假设端点是 v_1,v_2,v_3 \dotsb v_k ,那么这个连通块的点集大小 |V|=de_{v_1}-de_{lca(v_1,v_2,v_3 \dotsb v_k)}+\sum^{k}_{i=2}(de_{v_i}-de_{lca(v_i,v_{i-1})}) ,这个答案是线段树易于维护的,那么接下来要考虑的只有如何在枚举 u 的时候动态维护端点集合,而一条路径上的两个端点只会对这条路径上的点造成影响,自然考虑树上差分端点的影响,而树上差分需要合并子树的影响,所以上线段树合并就好了。

#include<bits/stdc++.h>
using namespace std;
#define fio(x) freopen(x".in","r",stdin),freopen(x".out","w",stdout)
#define tio() freopen("in.txt","r",stdin),freopen("out.txt","w",stdout)
using ll=long long;
const int N=1e5+10;
int n,m;
vector<int> g[N];
int fa[N][19],de[N],dfn[N],rnk[N],cnt;
void dfs(int u,int fath){
    fa[u][0]=fath;
    de[u]=de[fath]+1;
    dfn[u]=++cnt;
    rnk[cnt]=u;
    for(int i=1;i<19;++i)
        fa[u][i]=fa[fa[u][i-1]][i-1];
    for(int v:g[u]){
        if(v==fath)continue;
        dfs(v,u);
    }
}
int lca(int x,int y){
    if(de[x]<de[y])swap(x,y);
    int c=de[x]-de[y];
    for(int i=0;c;c>>=1,++i)
        if(c&1)x=fa[x][i];
    if(x==y)return x;
    for(int i=18;i>=0;--i){
        if(fa[x][i]==fa[y][i])continue;
        x=fa[x][i],y=fa[y][i];
    }
    return fa[x][0];
}
vector<int>opt[N];
ll ans;
struct st{
    struct node{
        int lc,rc;
        int v;
        int lk,rk;
        int sum;
    }d[N<<7|1];
    int tot;
    int nn(node x){
        d[++tot]=x;
        return tot;
    }
    void pushup(int p){
        int l=d[p].lc,r=d[p].rc;
        d[p].lk=d[l].lk?d[l].lk:d[r].lk;
        d[p].rk=d[r].rk?d[r].rk:d[l].rk;
        d[p].sum=0;
        if(l)d[p].sum+=d[l].sum;
        if(r)d[p].sum+=d[r].sum;
        if(l&&r)d[p].sum-=de[lca(d[r].lk,d[l].rk)];
    }
    void add(int p,int x,int k,int l=1,int r=n){
        if(l==r){
            d[p].v+=k;
            if(d[p].v) d[p].lk=d[p].rk=rnk[l],d[p].sum=de[rnk[l]];
            else d[p].sum=d[p].lk=d[p].rk=0;
            return ;
        }
        int mid=l+r>>1;
        if(x<=mid){
            if(!d[p].lc)d[p].lc=nn({});
            add(d[p].lc,x,k,l,mid);
        }
        else{
            if(!d[p].rc)d[p].rc=nn({});
            add(d[p].rc,x,k,mid+1,r);
        }
        pushup(p);
    }
    int merge(int p,int q,int l=1,int r=n){
        if(!p)return q;
        if(!q)return p;
        if(l==r){
            d[p].v+=d[q].v;
            if(d[p].v) d[p].lk=d[p].rk=rnk[l],d[p].sum=de[rnk[l]];
            else d[p].sum=d[p].lk=d[p].rk=0;
            return p;
        }
        int mid=l+r>>1;
        d[p].lc=merge(d[p].lc,d[q].lc,l,mid);
        d[p].rc=merge(d[p].rc,d[q].rc,mid+1,r);
        pushup(p);
        return p;
    }
    void print(int p,int l=1,int r=n){
        if(l==r){
            printf("(%d %d) ",rnk[l],d[p].v);
            return ;
        }
        int mid=l+r>>1;
        if(d[p].lc)print(d[p].lc,l,mid);
        if(d[p].rc)print(d[p].rc,mid+1,r);
    }
}t;
void dfs2(int u){
    for(int v:g[u]){
        if(v==fa[u][0])continue;
        dfs2(v);
        int tmp=t.merge(u,v);
    }
    for(int x:opt[u])
        t.add(u,x,-1);
    ans+=t.d[u].sum-de[lca(t.d[u].lk,t.d[u].rk)];
    // printf("%d %d\n",u,t.d[u].sum-de[lca(t.d[u].lk,t.d[u].rk)]);
}
int main(){
    scanf("%d%d",&n,&m);
    t.tot=n;
    for(int i=1;i<n;++i){
        int u,v;
        scanf("%d%d",&u,&v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1,0);
    for(int i=1;i<=m;++i){
        int x,y;
        scanf("%d%d",&x,&y);
        int w=lca(x,y);
        t.add(x,dfn[x],1),t.add(y,dfn[x],1);
        t.add(x,dfn[y],1),t.add(y,dfn[y],1);
        opt[w].push_back(dfn[x]),opt[w].push_back(dfn[y]);
        opt[fa[w][0]].push_back(dfn[x]),opt[fa[w][0]].push_back(dfn[y]);
    }
    dfs2(1);
    printf("%lld\n",ans>>1);
    // t.print(1);
    // printf("\n");
    // t.print(3);
    return 0;
}

评论