BZOJ 1036 [ZJOI2008]树的统计Count

2018.03.03

题目大意

请你维护一颗点权树,支持:查询A到B路径上的点权和,最大点权,单点修改。


直接树链剖分一下就好了…… 注意一下更新顺序,相比之下算比较简单的题了= =

#include <cstdio>
#include <algorithm>
using namespace std;
#define lson pos<<1
#define rson pos<<1|1
int n,inx,iny,q,initval[30010],fa[30010],depth[30010],siz[30010],son[30010],num[30010],top[30010],val[30010],head[30010],nex[60010],to[60010],tot,cnt;
int sum[2000010],maxn[2000010];
char mode[10];
void addedge(int tx,int ty) {to[++tot]=ty,nex[tot]=head[tx],head[tx]=tot;}
void dfs1(int pos,int pre)
{
    fa[pos]=pre;
    depth[pos]=depth[pre]+1;
    siz[pos]=1;
    for(int i=head[pos];i;i=nex[i])
        if(to[i]!=pre)
        {
            dfs1(to[i],pos);
            siz[pos]+=siz[to[i]];
            if(siz[son[pos]]<siz[to[i]]) son[pos]=to[i];
        }
    return;
}
void dfs2(int pos,int pre)
{
    num[pos]=++cnt;
    top[pos]=pre;
    val[num[pos]]=initval[pos];
    if(son[pos]) dfs2(son[pos],pre);
    for(int i=head[pos];i;i=nex[i])
        if(to[i]!=fa[pos]&&to[i]!=son[pos])
            dfs2(to[i],to[i]);
    return;
}
void build(int pos,int l,int r)
{
    if(l==r)
    {
        sum[pos]=maxn[pos]=val[l];
        return;
    }
    int mid=(l+r)>>1;
    build(lson,l,mid);
    build(rson,mid+1,r);
    sum[pos]=sum[lson]+sum[rson];
    maxn[pos]=max(maxn[lson],maxn[rson]);
    return;
}
void update(int pos,int l,int r,int tx,int ty)
{
    if(l==r)
    {
        sum[pos]=maxn[pos]=ty;
        return;
    }
    int mid=(l+r)>>1;
    if(tx<=mid) update(lson,l,mid,tx,ty);
    else update(rson,mid+1,r,tx,ty);
    sum[pos]=sum[lson]+sum[rson];
    maxn[pos]=max(maxn[lson],maxn[rson]);
}
int amax(int pos,int l,int r,int x,int y)
{
    if(x==0||y==0) return -0x3f3f3f3f;
    if(x<=l&&r<=y) return maxn[pos];
    int mid=(l+r)>>1,ret=-0x3f3f3f3f;
    if(x<=mid) ret=max(ret,amax(lson,l,mid,x,y));
    if(y>mid) ret=max(ret,amax(rson,mid+1,r,x,y));
    return ret;
}
int qmax(int tx,int ty)
{
    int ret=-0x3f3f3f3f;
    while(top[tx]!=top[ty])
    {
        if(depth[top[tx]]<depth[top[ty]]) swap(tx,ty);
        ret=max(ret,amax(1,1,n,num[top[tx]],num[tx]));
        tx=fa[top[tx]];
    }
    if(depth[tx]>depth[ty]) swap(tx,ty);
    ret=max(ret,amax(1,1,n,num[tx],num[ty]));
    return ret;
}
int asum(int pos,int l,int r,int x,int y)
{
    if(x==0||y==0) return 0;
    if(x<=l&&r<=y) return sum[pos];
    int mid=(l+r)>>1,ret=0;
    if(x<=mid) ret+=asum(lson,l,mid,x,y);
    if(y>mid) ret+=asum(rson,mid+1,r,x,y);
    return ret;
}
int qsum(int tx,int ty)
{
    int ret=0;
    while(top[tx]!=top[ty])
    {
        if(depth[top[tx]]<depth[top[ty]]) swap(tx,ty);
        ret+=asum(1,1,n,num[top[tx]],num[tx]);
        tx=fa[top[tx]];
    }
    if(depth[tx]>depth[ty]) swap(tx,ty);
    ret+=asum(1,1,n,num[tx],num[ty]);
    return ret;
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<n;++i) scanf("%d%d",&inx,&iny),addedge(inx,iny),addedge(iny,inx);
    for(int i=1;i<=n;++i) scanf("%d",initval+i);
    dfs1(1,0);
    dfs2(1,1);
    build(1,1,n);
    scanf("%d",&q);
    while(q--)
    {
        scanf("%s",mode);
        switch(mode[1])
        {
            case 'H':
            {
                scanf("%d%d",&inx,&iny);
                update(1,1,n,num[inx],iny);
                break;
            }
            case 'M':
            {
                scanf("%d%d",&inx,&iny);
                printf("%d\n",qmax(inx,iny));
                break;
            }
            case 'S':
            {
                scanf("%d%d",&inx,&iny);
                printf("%d\n",qsum(inx,iny));
                break;
            }
        }
    }
    return 0;
}