// For every color k in 1..n, counts the vertex pairs (u, v) with u <= v (so single-vertex paths
// are included) whose path in an n-vertex tree visits at least one vertex of color k.
//
// Input (stdin):   n, then the colors c_1..c_n, then n-1 edges "u v" (vertices are 1-based).
// Output (stdout): n lines; line k is the count for color k.
//
// Method: count(k) = n(n+1)/2 - (pairs whose path avoids color k). Deleting every color-k vertex
// splits the tree into components, and a component with m vertices contains m(m+1)/2 avoiding
// pairs, so only the component sizes are needed. One DFS collects them for all colors at once.
//
// Both traversals are iterative, so deep (path-like) trees cannot overflow the call stack.
#include <cstdio>
#include <vector>

namespace {

using ll = long long;

// Reads the next integer from stdin. Non-digit characters before the number are skipped; a '-'
// among them makes the result negative. Returns 0 if EOF is reached before any digit.
ll read_int() {
    int ch = std::getchar();  // int, not char, so that EOF can be recognised
    int sign = 1;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    ll value = 0;
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return sign * value;
}

// Returns a vector of size n + 1 whose entry k (1 <= k <= n) is the number of pairs u <= v whose
// tree path contains a vertex of color k; entry 0 is unused.
//   color: color[v] for v in 1..n; assumed to lie in [1, n] (color[0] is ignored).
//   adj:   adjacency lists of a tree on vertices 1..n; the tree is rooted at vertex 1 here.
std::vector<ll> count_paths_by_color(int n, const std::vector<int>& color,
                                     const std::vector<std::vector<int>>& adj) {
    std::vector<ll> result(n + 1, 0);
    if (n <= 0) return result;

    // Preorder from root 1: every vertex appears after its parent. parent[1] = 0 is a virtual
    // parent whose color is treated as the otherwise unused color 0.
    std::vector<int> parent(n + 1, 0);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> pending{1};
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        for (const int t : adj[x]) {
            if (t != parent[x]) {
                parent[t] = x;
                pending.push_back(t);
            }
        }
    }

    // Subtree sizes, accumulated children-first by walking the preorder backwards.
    std::vector<ll> sz(n + 1, 1);
    for (int i = static_cast<int>(order.size()) - 1; i > 0; --i) {
        sz[parent[order[i]]] += sz[order[i]];
    }

    const auto color_of = [&](int v) { return v == 0 ? 0 : color[v]; };

    // head[c]:    for the vertex being visited, the child (on its root path) of the deepest
    //             ancestor colored c, i.e. the root of the color-c-free component it lies in;
    //             0 if no ancestor has color c.
    // saved[v]:   value of head[color(parent(v))] before v overwrote it; restored when leaving v.
    // removed[v]: total size of subtrees cut out of the component rooted at v.
    // top_cut[c]: total size of the topmost color-c subtrees; the component containing the root
    //             therefore has n - top_cut[c] vertices.
    // below[c]:   avoiding pairs inside components hanging directly below a color-c vertex.
    std::vector<int> head(n + 1, 0), saved(n + 1, 0);
    std::vector<ll> removed(n + 1, 0), top_cut(n + 1, 0), below(n + 1, 0);

    // Explicit DFS: a positive entry v means "enter v", a negative entry -v means "leave v".
    std::vector<int> events{1};
    while (!events.empty()) {
        const int e = events.back();
        events.pop_back();
        if (e > 0) {
            const int x = e;
            const int pc = color_of(parent[x]);
            saved[x] = head[pc];
            head[pc] = x;  // x roots a color-pc-free component below its parent
            // x is deleted for its own color: detach its whole subtree from the component of the
            // nearest same-colored ancestor, or from the root component if there is none. When
            // color[x] == pc the owner is x itself, so x's own component ends up empty.
            const int owner = head[color[x]];
            if (owner != 0) {
                removed[owner] += sz[x];
            } else {
                top_cut[color[x]] += sz[x];
            }
            events.push_back(-x);
            for (const int t : adj[x]) {
                if (t != parent[x]) events.push_back(t);
            }
        } else {
            const int x = -e;
            const int pc = color_of(parent[x]);
            const ll m = sz[x] - removed[x];  // vertices in the component rooted at x
            below[pc] += m * (m + 1) / 2;
            head[pc] = saved[x];
        }
    }

    const ll all_pairs = static_cast<ll>(n) * (n + 1) / 2;
    for (int c = 1; c <= n; ++c) {
        const ll root_part = n - top_cut[c];
        result[c] = all_pairs - below[c] - root_part * (root_part + 1) / 2;
    }
    return result;
}

}  // namespace

int main() {
    const int n = static_cast<int>(read_int());
    if (n <= 0) return 0;

    std::vector<int> color(n + 1, 0);
    for (int v = 1; v <= n; ++v) color[v] = static_cast<int>(read_int());

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        const int u = static_cast<int>(read_int());
        const int v = static_cast<int>(read_int());
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    const std::vector<ll> answer = count_paths_by_color(n, color, adj);
    for (int c = 1; c <= n; ++c) std::printf("%lld\n", answer[c]);
    return 0;
}