BZOJ 3697 采药人的路径

2018.03.06

题目大意

给你一棵树,每条边的边权为1/-1,求这棵树上有多少路径满足它可以分成两段路径,这两段路径分别的边权和均为0。


看到是处理所有可行路径立马想到可以使用点分治来解决问题。

但是两段路径边权和都为0这个条件有点苛刻= =

设$f_{i,0/1}$表示$x$扫过子树中无/有休息站的,边权和为$i$的路径的个数,$g_{i,0/1}$表示$x$正在扫的子树中无/有休息站的,边权和为$i$的路径的个数。之后用$g$更新$f$就好了。注意起点为$x$的情况。

(80%文本摘自CQZhangyu博客

#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 100010
inline char nc()
{
	static char buf[100000],*p1,*p2;
	return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int ri()
{
	int ret=0;char tmp=nc();
	while(!isdigit(tmp)) tmp=nc();
	while(isdigit(tmp)) ret=ret*10+(tmp^'0'),tmp=nc();
	return ret;
}
int n,nn,rt,siz[N],msi[N],dis[N],f[N<<1][2],q[N],ers[N],cnt[N<<1];
long long ans;
bool avai[N],vis[N];
int to[N<<1],nxt[N<<1],val[N<<1],head[N],tot;
inline void ae(int x,int y,int z){to[++tot]=y,nxt[tot]=head[x],head[x]=tot,val[tot]=z;}
void getrt(int pos,int pre)
{
	siz[pos]=1,msi[pos]=0;
	for(int i=head[pos];i;i=nxt[i])
		if(!vis[to[i]]&&to[i]!=pre)
		{
			getrt(to[i],pos);
			siz[pos]+=siz[to[i]];
			msi[pos]=max(msi[pos],siz[to[i]]);
		}
	msi[pos]=max(msi[pos],nn-siz[pos]);
	if(msi[pos]<msi[rt]) rt=pos;
}
void dfs(int pos,int pre)
{
	cnt[n+dis[pos]]++;
	if(cnt[n+dis[pos]]>1) avai[pos]=true;
	q[++q[0]]=pos;
	for(int i=head[pos];i;i=nxt[i])
		if(!vis[to[i]]&&to[i]!=pre)
		{
			avai[to[i]]=false;
			dis[to[i]]=dis[pos]+val[i],dfs(to[i],pos);
		}
	cnt[n+dis[pos]]--;
}
void solve(int pos)
{
	ers[0]=0;
	vis[pos]=true;
	for(int i=head[pos];i;i=nxt[i])
		if(!vis[to[i]])
		{
			q[0]=0,dis[to[i]]=val[i],avai[to[i]]=false,dfs(to[i],pos);
			for(int j=1;j<=q[0];++j)
			{
				if(avai[q[j]]) ans+=f[n-dis[q[j]]][0];
				ans+=f[n-dis[q[j]]][1];
				if(avai[q[j]]&&!dis[q[j]]) ans++;
				if((!avai[q[j]])&&(!dis[q[j]])) avai[q[j]]=true;
			}
			for(int j=1;j<=q[0];++j)
				f[ers[++ers[0]]=n+dis[q[j]]][avai[q[j]]]++;
		}
	for(int i=1;i<=ers[0];++i) f[ers[i]][0]=f[ers[i]][1]=0;
	for(int i=head[pos];i;i=nxt[i])
		if(!vis[to[i]])
			nn=siz[to[i]],rt=0,getrt(to[i],0),solve(rt);
}
int main()
{
	n=ri();
	for(int i=2,tx,ty,tz;i<=n;++i)
	{
		tx=ri(),ty=ri(),tz=ri();
		if(!tz) tz=-1;
		ae(tx,ty,tz),ae(ty,tx,tz);
	}
	msi[0]=0x3f3f3f3f,nn=n,rt=0;getrt(1,0),solve(rt);
	printf("%lld\n",ans);
	return 0;
}