BZOJ 1103 [POI2007]大都市meg

2018.03.03

题目大意

给你一颗树,每条边初始长度=1,每次指定两个节点使其间路长变为0,每次求1号点到某个点的距离。


这道题裸树剖……

注意这里是边权不是点权,写错可是拿不到分的= =

#include <cstdio>
#include <algorithm>
using namespace std;
#define lson pos<<1
#define rson pos<<1|1
#define maxN 250010
int n,m,depth[maxN],fa[maxN],son[maxN],top[maxN],siz[maxN],num[maxN],inx,iny;
int laz[maxN<<2],val[maxN<<2];
int to[maxN<<1],nex[maxN<<1],head[maxN<<1],tot,cnt;
char mode[5];
void addedge(int tx,int ty) {to[++tot]=ty,nex[tot]=head[tx],head[tx]=tot;}
void pushup(int pos) {val[pos]=val[lson]+val[rson];}
void dfs1(int pos,int pre)
{
    fa[pos]=pre;
    depth[pos]=depth[pre]+1;
    siz[pos]=1;
    for(int i=head[pos];i;i=nex[i])
        if(to[i]!=pre)
        {
            dfs1(to[i],pos);
            siz[pos]+=siz[to[i]];
            if(siz[son[pos]]<siz[to[i]]) son[pos]=to[i];
        }
}
void dfs2(int pos,int pre)
{
    num[pos]=++cnt;
    top[pos]=pre;
    if(son[pos]) dfs2(son[pos],pre);
    for(int i=head[pos];i;i=nex[i]) if(to[i]!=fa[pos]&&to[i]!=son[pos]) dfs2(to[i],to[i]);
}
void build(int pos,int l,int r)
{
    if(l==r) {val[pos]=1;return;}
    int mid=(l+r)>>1;
    build(lson,l,mid);
    build(rson,mid+1,r);
    pushup(pos);
    return;
}
void upd(int pos,int l,int r,int x,int y)
{
    if(x<=l&&r<=y)
    {
        val[pos]=0;
        laz[pos]=true;
        return;
    }
    if(laz[pos]) return;
    int mid=(l+r)>>1;
    if(x<=mid) upd(lson,l,mid,x,y);
    if(y>mid) upd(rson,mid+1,r,x,y);
    pushup(pos);
    return;
}
void update(int tx,int ty)
{
    while(top[tx]!=top[ty])
    {
        if(depth[top[tx]]<depth[top[ty]]) swap(tx,ty);
        upd(1,1,n,num[top[tx]],num[tx]);
        tx=fa[top[tx]];
    }
    if(tx>ty) swap(tx,ty);
    if(tx!=ty) upd(1,1,n,num[tx]+1,num[ty]);
}
int getsum(int pos,int l,int r,int x,int y)
{
    if(x<=l&&r<=y) return val[pos];
    if(laz[pos]) return 0;
    int mid=(l+r)>>1,ret=0;
    if(x<=mid) ret+=getsum(lson,l,mid,x,y);
    if(y>mid) ret+=getsum(rson,mid+1,r,x,y);
    return ret;
}
int query(int tx,int ty)
{
    int ret=0;
    while(top[tx]!=top[ty])
    {
        if(depth[top[tx]]<depth[top[ty]]) swap(tx,ty);
        ret+=getsum(1,1,n,num[top[tx]],num[tx]);
        tx=fa[top[tx]];
    }
    if(tx>ty) swap(tx,ty);
    if(tx!=ty) ret+=getsum(1,1,n,num[tx]+1,num[ty]);
    return ret;
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<n;++i) scanf("%d%d",&inx,&iny),addedge(inx,iny),addedge(iny,inx);
    dfs1(1,0);
    dfs2(1,1);
    build(1,1,n);
    scanf("%d",&m);
    m+=n-1;
    while(m--)
    {
        scanf("%s",mode);
        switch(mode[0])
        {
            case 'A':
            {
                scanf("%d%d",&inx,&iny);
                update(inx,iny);
                break;
            }
            case 'W':
            {
                scanf("%d",&inx);
                printf("%d\n",query(1,inx));
                break;
            }
        }
    }
}