题解 P4178 【Tree】

这道题,树上点对路径问题,再看数据范围,n40000n\leq 40000,十有八九是淀粉质。

那么这道题,我们就应该想想看怎么快速地求出一个节点uu的子树里有多少经过uu路径,满足长度不超过kk

首先,我们可以把uu的子树里所有节点的深度全都跑出来(淀粉质常规操作,窝就不多说了),存在一个数组ss里面,接着把它排好序(STLSTL sortsort足矣)。

于此同时我们需要标记出uu的子树里每个节点vv属于哪棵子树(即color[v])。对于uu的孩子kk,我们还需要统计一个cnt[k],表示kk的子树里有多少个节点。

接着我们拿两个指针llrr分别从左右两端向中间扫描ss。之前我们已经统计好的cnt[k]cnt[k],现在表示(l,r](l,r]里有多少kk的子孙。因为ss已经排好序了,所以当s[l]+s[r]ks[l]+s[r]\leq k时,对于任意ii满足l<irl<i\leq r,都有s[l]+s[i]ks[l]+s[i]\leq k,所以这时候我们给最终答案加上rlcnt[color[s[l]]]r-l-cnt[color[s[l]]],同时更新cnt,就行了。

放代码:

#include <bits/stdc++.h>
#define debug printf("Running %s on line %d...\n",__FUNCTION__,__LINE__)
#define in inline
#define re register
#define mid (l+r>>1)
using namespace std;
in int read()
{
    int ans=0,f=1;char c=getchar();
    for (;!isdigit(c);c=getchar()) if (c=='-') f=-1;
    for (;isdigit(c);c=getchar()) ans=(ans<<3)+(ans<<1)+(c^48);
    return ans*f;
}
int sum,n,k,ans;
int nex[100005],head[100005],tail[100005],weight[100005],tot;
in void addedge(int u,int v,int w)
{
    nex[++tot]=head[u];
    head[u]=tot;
    tail[tot]=v;
    weight[tot]=w;
}
int maxx[100005],sz[100005],root;
int vis[100005];
void findroot(int u,int fa)
{
    sz[u]=1,maxx[u]=0;
    for (int e=head[u];e;e=nex[e])
    {
        int v=tail[e];
        if (v==fa||vis[v]) continue;
        findroot(v,u);
        sz[u]+=sz[v];
        maxx[u]=max(maxx[u],sz[v]); 
    }
    maxx[u]=max(maxx[u],sum-sz[u]);
    if (maxx[u]<maxx[root]) root=u;
}
int dis[100005],color[100005],cnt[100005];
int s[100005],top,c;
void getdis(int u,int fa)
{
    s[++top]=u,cnt[color[u]=c]++;
    for (int e=head[u];e;e=nex[e])
    {
        int v=tail[e];
        if (vis[v]||v==fa) continue;
        dis[v]=dis[u]+weight[e];
        getdis(v,u);
    }
}
void calc(int u)
{
    top=0;
    for (int e=head[u];e;e=nex[e])
    {
        int v=tail[e];
        if (vis[v]) continue;
        c=v,dis[v]=weight[e];
        getdis(v,u);
    }
    sort(s+1,s+top+1,[](int x,int y){
        return dis[x]<dis[y];
    });
    int l=0,r=top;
    while (true)
    {
        while (r>l && dis[s[r]]+dis[s[l]]>k) cnt[color[s[r]]]--,r--;
        if (r==l) break;
        ans+=r-l-cnt[color[s[l]]];
        l++;
        if (l) cnt[color[s[l]]]--;
    }
}
void solve(int u)
{
    vis[u]=1;
    calc(u);
    for (int e=head[u];e;e=nex[e])
    {
        int v=tail[e];
        if (vis[v]) continue;
        maxx[root=0]=sum=sz[v];
        findroot(v,0);
        solve(root);
    }
}
int main()
{
    n=read();
    for (int i=1;i<n;i++)
    {
        int u=read(),v=read(),w=read();
        addedge(u,v,w);
        addedge(v,u,w);
    }
    k=read();
    maxx[root]=sum=n;
    findroot(1,0);
    solve(root);
    cout<<ans<<endl;
    return 0;
}