【Luogu P3178】[HAOI2015]树上操作

P3178 [HAOI2015]树上操作

题目描述

有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个操作,分为三种:

  • 操作 1 :把某个节点 x 的点权增加 a 。
  • 操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
  • 操作 3 :询问某个节点 x 到根的路径中所有点的点权和。

输入格式

第一行包含两个整数 N, M 。表示点数和操作数。
接下来一行 N 个整数,表示树中节点的初始权值。
接下来 N-1 行每行两个正整数 from, to , 表示该树中存在一条边 (from, to) 。
再接下来 M 行,每行分别表示一次操作。其中第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。

输出格式

对于每个询问操作,输出该询问的答案。答案之间用换行隔开。

输入样例

5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3

输出样例

6
9
13

说明/提示

对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。

Solution

模板题,树链剖分+线段树

AC Code

不开 long long 见祖宗

#include<cstdio>
#include<vector>
typedef long long llint;
const bool DEBUG_ENABLE = 1;
const llint MAXN = 1e6 + 5;
const llint MOD = 19260827;

llint ti, n, m;
struct TREE {
    llint we[MAXN], siz[MAXN], dep[MAXN], tim[MAXN], son[MAXN], fa[MAXN], top[MAXN], wei[MAXN];
    std::vector<llint> grp[MAXN];
    inline void findSize(llint pos, llint fat) {
        fa[pos] = fat;
        dep[pos] = dep[fat] + 1;
        siz[pos] = 1;
        llint maxsize = -1;
        for(llint i = 0; i < grp[pos].size(); ++i) {
            llint v = grp[pos][i];
            if(v == fat) continue;
            findSize(v, pos);
            siz[pos] += siz[v];
            if(siz[v] > maxsize) maxsize = siz[v], son[pos] = v;
        }
    }
    inline void partition(llint pos, llint tp) {
        tim[pos] = ++ti;
        top[pos] = tp;
        wei[ti] = we[pos];
        if(!son[pos]) return;
        partition(son[pos], tp);
        for(llint i = 0; i < grp[pos].size(); ++i) {
            llint v = grp[pos][i];
            if(v == son[pos] || v == fa[pos]) continue;
            partition(v, v);
        }
    }
    inline void addEdge(llint u, llint v) {
        grp[u].push_back(v);
    }
    llint ar[MAXN], laz[MAXN];
    inline void build(llint rt, llint l, llint r) {
        if(l == r) {
            ar[rt] = wei[l];
            return;
        }
        llint mid = (l + r) >> 1;
        build(rt * 2, l, mid);
        build(rt * 2 + 1, mid + 1, r);
        ar[rt] = ar[rt * 2] + ar[rt * 2 + 1];
    }
    inline void pushdown(llint l, llint r, llint s, llint t, llint pos) {
        laz[pos * 2] += laz[pos];
        laz[pos * 2 + 1] += laz[pos];
        llint mid = (s + t) >> 1;
        ar[pos * 2] += (mid - s + 1) * laz[pos];
        ar[pos * 2 + 1] += (t - mid) * laz[pos];
        laz[pos] = 0;
    }
    inline llint getSum(llint l, llint r, llint s, llint t, llint pos) {
        if(l <= s && t <= r) return ar[pos];
        llint mid = (s + t) >> 1;
        llint ans = 0;
        pushdown(l, r, s, t, pos);
        if(l <= mid) ans += getSum(l, r, s, mid, pos * 2);
        if(r > mid) ans += getSum(l, r, mid + 1, t, pos * 2 + 1);
        return ans;
    }
    inline void update(llint l, llint r, llint val, llint s, llint t, llint pos) {
        if(l <= s && t <= r) {
            ar[pos] += (t - s + 1) * val;
            laz[pos] += val;
            return;
        }
        llint mid = (s + t) >> 1;
        pushdown(l, r, s, t, pos);
        if(l <= mid) update(l, r, val, s, mid, pos * 2);
        if(r > mid) update(l, r, val, mid + 1, t, pos * 2 + 1);
        ar[pos] = ar[pos * 2] + ar[pos * 2 + 1];
    }
    inline void multiUpdate(llint pos, llint val) {
        update(tim[pos], tim[pos] + siz[pos] - 1, val, 1, ti, 1);
    }
    inline llint query(llint pos) {
        llint ans = 0;
        llint x = pos, y = 1;
        while(top[x] != top[y]) {
            if(dep[top[x]] < dep[top[y]]) std::swap(x, y);
            ans += getSum(tim[top[x]], tim[x], 1, ti, 1);
            x = fa[top[x]];
        }
        if(dep[x] > dep[y]) std::swap(x, y);
        ans += getSum(tim[x], tim[y], 1, ti, 1);
        return ans;
    }
} a;

int main() {
/*==================================*/
#ifdef LOCAL_JUDGE
    freopen("\\Codes\\in.in", "r", stdin);
    freopen("\\Codes\\out.out", "w", stdout);
#endif
/*==================================*/
    scanf("%lld%lld", &n, &m);
    for(llint i = 1; i <= n; ++i) {
        scanf("%lld", &a.we[i]);
    }
    llint u, v;
    for(llint i = 1; i < n; ++i) {
        scanf("%lld%lld", &u, &v);
        a.addEdge(u, v);
        a.addEdge(v, u);
    }
    a.findSize(1, 0);
    a.partition(1, 1);
    a.build(1, 1, ti);
    llint opt;
    for(llint i = 1; i <= m; ++i) {
        scanf("%lld%lld", &opt, &u);
        if(opt == 1) {
            scanf("%lld", &v);
            a.update(a.tim[u], a.tim[u], v, 1, ti, 1);
        } else if(opt == 2) {
            scanf("%lld", &v);
            a.multiUpdate(u, v);
        } else {
            printf("%lld\n", a.query(u));
        }
    }
    return 0;
}
点赞

发表评论

电子邮件地址不会被公开。必填项已用 * 标注