From zxy的思维技巧 而来。
这真是一道套路的好题啊。
首先要计算二元组 (a,b) 的个数,看数据范围肯定枚举一个 a ,然后计算所有满足的 b ,最后求和后除以 2 。
然后具体的考虑怎么计算所谓 b 的个数,考虑一个点 u ,这实际上是一个树上连通块的大小,而连通块的端点就是由其中经过 u 的给出路径的端点产生,然后又是一个套路,树上连通块的大小,边权和之类的东西都可以用端点的 \texttt{dfs} 序刻画,比如这道题里,假设端点是 v_1,v_2,v_3 \dotsb v_k ,那么这个连通块的点集大小 |V|=de_{v_1}-de_{lca(v_1,v_2,v_3 \dotsb v_k)}+\sum^{k}_{i=2}(de_{v_i}-de_{lca(v_i,v_{i-1})}) ,这个答案是线段树易于维护的,那么接下来要考虑的只有如何在枚举 u 的时候动态维护端点集合,而一条路径上的两个端点只会对这条路径上的点造成影响,自然考虑树上差分端点的影响,而树上差分需要合并子树的影响,所以上线段树合并就好了。
#include<bits/stdc++.h>
using namespace std;
#define fio(x) freopen(x".in","r",stdin),freopen(x".out","w",stdout)
#define tio() freopen("in.txt","r",stdin),freopen("out.txt","w",stdout)
using ll=long long;
const int N=1e5+10;
int n,m;
vector<int> g[N];
int fa[N][19],de[N],dfn[N],rnk[N],cnt;
void dfs(int u,int fath){
fa[u][0]=fath;
de[u]=de[fath]+1;
dfn[u]=++cnt;
rnk[cnt]=u;
for(int i=1;i<19;++i)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int v:g[u]){
if(v==fath)continue;
dfs(v,u);
}
}
int lca(int x,int y){
if(de[x]<de[y])swap(x,y);
int c=de[x]-de[y];
for(int i=0;c;c>>=1,++i)
if(c&1)x=fa[x][i];
if(x==y)return x;
for(int i=18;i>=0;--i){
if(fa[x][i]==fa[y][i])continue;
x=fa[x][i],y=fa[y][i];
}
return fa[x][0];
}
vector<int>opt[N];
ll ans;
struct st{
struct node{
int lc,rc;
int v;
int lk,rk;
int sum;
}d[N<<7|1];
int tot;
int nn(node x){
d[++tot]=x;
return tot;
}
void pushup(int p){
int l=d[p].lc,r=d[p].rc;
d[p].lk=d[l].lk?d[l].lk:d[r].lk;
d[p].rk=d[r].rk?d[r].rk:d[l].rk;
d[p].sum=0;
if(l)d[p].sum+=d[l].sum;
if(r)d[p].sum+=d[r].sum;
if(l&&r)d[p].sum-=de[lca(d[r].lk,d[l].rk)];
}
void add(int p,int x,int k,int l=1,int r=n){
if(l==r){
d[p].v+=k;
if(d[p].v) d[p].lk=d[p].rk=rnk[l],d[p].sum=de[rnk[l]];
else d[p].sum=d[p].lk=d[p].rk=0;
return ;
}
int mid=l+r>>1;
if(x<=mid){
if(!d[p].lc)d[p].lc=nn({});
add(d[p].lc,x,k,l,mid);
}
else{
if(!d[p].rc)d[p].rc=nn({});
add(d[p].rc,x,k,mid+1,r);
}
pushup(p);
}
int merge(int p,int q,int l=1,int r=n){
if(!p)return q;
if(!q)return p;
if(l==r){
d[p].v+=d[q].v;
if(d[p].v) d[p].lk=d[p].rk=rnk[l],d[p].sum=de[rnk[l]];
else d[p].sum=d[p].lk=d[p].rk=0;
return p;
}
int mid=l+r>>1;
d[p].lc=merge(d[p].lc,d[q].lc,l,mid);
d[p].rc=merge(d[p].rc,d[q].rc,mid+1,r);
pushup(p);
return p;
}
void print(int p,int l=1,int r=n){
if(l==r){
printf("(%d %d) ",rnk[l],d[p].v);
return ;
}
int mid=l+r>>1;
if(d[p].lc)print(d[p].lc,l,mid);
if(d[p].rc)print(d[p].rc,mid+1,r);
}
}t;
void dfs2(int u){
for(int v:g[u]){
if(v==fa[u][0])continue;
dfs2(v);
int tmp=t.merge(u,v);
}
for(int x:opt[u])
t.add(u,x,-1);
ans+=t.d[u].sum-de[lca(t.d[u].lk,t.d[u].rk)];
// printf("%d %d\n",u,t.d[u].sum-de[lca(t.d[u].lk,t.d[u].rk)]);
}
int main(){
scanf("%d%d",&n,&m);
t.tot=n;
for(int i=1;i<n;++i){
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1,0);
for(int i=1;i<=m;++i){
int x,y;
scanf("%d%d",&x,&y);
int w=lca(x,y);
t.add(x,dfn[x],1),t.add(y,dfn[x],1);
t.add(x,dfn[y],1),t.add(y,dfn[y],1);
opt[w].push_back(dfn[x]),opt[w].push_back(dfn[y]);
opt[fa[w][0]].push_back(dfn[x]),opt[fa[w][0]].push_back(dfn[y]);
}
dfs2(1);
printf("%lld\n",ans>>1);
// t.print(1);
// printf("\n");
// t.print(3);
return 0;
}