BZOJ 4127 Abs

2018.03.03

题目大意

给你一颗树,请你维护以下操作:路径加上一个数(只+不-),求路径节点点权和。


这道题一看是路径上的就是树链剖分……之后重点是如何维护绝对值。在经过仔细思考之后我还是没能弄懂这个东西的性质。但是重点在这里——区间修改必须是加Add(一个非负整数)

考虑以下:我们一共有N个点,如果我们暴力重构每个点的话需要$O(N \log N)$。对于每个数我们只需要在其由负数转成正数时进行一次暴力重构,所以我们的期望时间复杂度就是$O(N \log N)$ 。

维护区间的负数个数,区间绝对值和,laz标记相关信息,树链剖分就很简单了。

注意这道题爆int,所以需要用long long,最重要的是初始值千万千万不能设成maxint那样的话会GG。还有别忘了更新值的同时只要不满足l==r就一定要更新laz值

#include <cstdio>
#include <algorithm>
using namespace std;
#define lson pos<<1
#define rson pos<<1|1
#define maxN 100010
int n,m,initval[maxN],inittmp[maxN],inx,iny,inz,num[maxN],fa[maxN],son[maxN],siz[maxN],top[maxN],depth[maxN],mode;
int to[maxN<<1],nex[maxN<<1],head[maxN],tot,now;
long long laz[maxN<<2],val[maxN<<2],minn[maxN<<2],cnt[maxN<<2];
void addedge(int tx,int ty) {to[++tot]=ty,nex[tot]=head[tx],head[tx]=tot;}
void dfs1(int pos,int pre)
{
    depth[pos]=depth[pre]+1;
    fa[pos]=pre;
    siz[pos]=1;
    for(int i=head[pos];i;i=nex[i])
        if(to[i]!=pre)
        {
            dfs1(to[i],pos);
            siz[pos]+=siz[to[i]];
            if(siz[son[pos]]<siz[to[i]]) son[pos]=to[i];
        }
    return;
}
void dfs2(int pos,int pre)
{
    num[pos]=++now;
    top[pos]=pre;
    inittmp[num[pos]]=initval[pos];
    if(son[pos]) dfs2(son[pos],pre);
    for(int i=head[pos];i;i=nex[i])
        if(to[i]!=fa[pos]&&to[i]!=son[pos]) dfs2(to[i],to[i]);
    return;
}
void pushup(int pos)
{
    minn[pos]=min(minn[lson],minn[rson]);
    val[pos]=val[lson]+val[rson];
    cnt[pos]=cnt[lson]+cnt[rson];
}
void pushdown(int pos,int l,int r)
{
    if(laz[pos])
    {
        int mid=(l+r)>>1;
        laz[lson]+=laz[pos],laz[rson]+=laz[pos];
        val[lson]+=laz[pos]*(mid-l+1-(cnt[lson]<<1));
        val[rson]+=laz[pos]*(r-mid-(cnt[rson]<<1));
        minn[lson]-=laz[pos],minn[rson]-=laz[pos];
        laz[pos]=0;
    }
}
void build(int pos,int l,int r)
{
    if(l==r)
    {
        minn[pos]=inittmp[l]<0?-inittmp[l]:0x3f3f3f3f3f3f3f3fll;
        val[pos]=abs(inittmp[l]);
        cnt[pos]=inittmp[l]<0;
        return;
    }
    int mid=(l+r)>>1;
    build(lson,l,mid);
    build(rson,mid+1,r);
    pushup(pos);
}
long long atot(int pos,int l,int r,int x,int y)
{
    if(x<=l&&r<=y) return val[pos];
    pushdown(pos,l,r);
    int mid=(l+r)>>1;
    long long ret=0;
    if(x<=mid) ret+=atot(lson,l,mid,x,y);
    if(y>mid) ret+=atot(rson,mid+1,r,x,y);
    return ret;
}
long long query(int tx,int ty)
{
    long long ret=0;
    while(top[tx]!=top[ty])
    {
        if(depth[top[tx]]<depth[top[ty]]) swap(tx,ty);
        ret+=atot(1,1,n,num[top[tx]],num[tx]);
        tx=fa[top[tx]];
    }
    if(depth[tx]>depth[ty]) swap(tx,ty);
    ret+=atot(1,1,n,num[tx],num[ty]);
    return ret;
}
void rebuild(int pos,int l,int r,long long tx)
{
    if(l==r)
    {
        minn[pos]=0x3f3f3f3f3f3f3f3fll;
        val[pos]=1ll*tx-val[pos];
        cnt[pos]=0;
        return;
    }
    pushdown(pos,l,r);
    int mid=(l+r)>>1;
    if(minn[lson]<tx) rebuild(lson,l,mid,tx);
    else
    {
        val[lson]+=tx*(mid-l+1-(cnt[lson]<<1));
        minn[lson]-=tx;
        laz[lson]+=tx;
    }
    if(minn[rson]<tx) rebuild(rson,mid+1,r,tx);
    else
    {
        val[rson]+=tx*(r-mid-(cnt[rson]<<1));
        minn[rson]-=tx;
        laz[rson]+=tx;
    }
    pushup(pos);
}
void qadd(int pos,int l,int r,int x,int y,long long z)
{
    if(x<=l&&r<=y)
    {
        if(minn[pos]<z) rebuild(pos,l,r,z);
        else
        {
            laz[pos]+=z;
            val[pos]+=z*(r-l+1-(cnt[pos]<<1));
            minn[pos]-=z;
        }
        return;
    }
    pushdown(pos,l,r);
    int mid=(l+r)>>1;
    if(x<=mid) qadd(lson,l,mid,x,y,z);
    if(y>mid) qadd(rson,mid+1,r,x,y,z);
    pushup(pos);
    return;
}
void add(int tx,int ty,int tz)
{
    while(top[tx]!=top[ty])
    {
        if(depth[top[tx]]<depth[top[ty]]) swap(tx,ty);
        qadd(1,1,n,num[top[tx]],num[tx],tz);
        tx=fa[top[tx]];
    }
    if(depth[tx]>depth[ty]) swap(tx,ty);
    qadd(1,1,n,num[tx],num[ty],tz);
    return;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i) scanf("%d",initval+i);
    for(int i=1;i<n;++i) scanf("%d%d",&inx,&iny),addedge(inx,iny),addedge(iny,inx);
    dfs1(1,0);
    dfs2(1,1);
    build(1,1,n);
    while(m--)
    {
        scanf("%d",&mode);
        switch(mode)
        {
            case 1:
            {
                scanf("%d%d%d",&inx,&iny,&inz);
                add(inx,iny,inz);
                break;
            }
            case 2:
            {
                scanf("%d%d",&inx,&iny);
                printf("%lld\n",query(inx,iny));
                break;
            }
        }
    }
}