题目描述
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个操作,分为三种:
- 操作 1 :把某个节点 x 的点权增加 a 。
- 操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
- 操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
输入格式
第一行包含两个整数 N, M 。表示点数和操作数。
接下来一行 N 个整数,表示树中节点的初始权值。
接下来 N-1 行每行两个正整数 from, to , 表示该树中存在一条边 (from, to) 。
再接下来 M 行,每行分别表示一次操作。其中第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。
输出格式
对于每个询问操作,输出该询问的答案。答案之间用换行隔开。
输入样例
5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
输出样例
6
9
13
说明/提示
对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。
Solution
模板题,树链剖分+线段树
AC Code
不开 long long 见祖宗
#include<cstdio>
#include<vector>
typedef long long llint;
const bool DEBUG_ENABLE = 1;
const llint MAXN = 1e6 + 5;
const llint MOD = 19260827;
llint ti, n, m;
struct TREE {
llint we[MAXN], siz[MAXN], dep[MAXN], tim[MAXN], son[MAXN], fa[MAXN], top[MAXN], wei[MAXN];
std::vector<llint> grp[MAXN];
inline void findSize(llint pos, llint fat) {
fa[pos] = fat;
dep[pos] = dep[fat] + 1;
siz[pos] = 1;
llint maxsize = -1;
for(llint i = 0; i < grp[pos].size(); ++i) {
llint v = grp[pos][i];
if(v == fat) continue;
findSize(v, pos);
siz[pos] += siz[v];
if(siz[v] > maxsize) maxsize = siz[v], son[pos] = v;
}
}
inline void partition(llint pos, llint tp) {
tim[pos] = ++ti;
top[pos] = tp;
wei[ti] = we[pos];
if(!son[pos]) return;
partition(son[pos], tp);
for(llint i = 0; i < grp[pos].size(); ++i) {
llint v = grp[pos][i];
if(v == son[pos] || v == fa[pos]) continue;
partition(v, v);
}
}
inline void addEdge(llint u, llint v) {
grp[u].push_back(v);
}
llint ar[MAXN], laz[MAXN];
inline void build(llint rt, llint l, llint r) {
if(l == r) {
ar[rt] = wei[l];
return;
}
llint mid = (l + r) >> 1;
build(rt * 2, l, mid);
build(rt * 2 + 1, mid + 1, r);
ar[rt] = ar[rt * 2] + ar[rt * 2 + 1];
}
inline void pushdown(llint l, llint r, llint s, llint t, llint pos) {
laz[pos * 2] += laz[pos];
laz[pos * 2 + 1] += laz[pos];
llint mid = (s + t) >> 1;
ar[pos * 2] += (mid - s + 1) * laz[pos];
ar[pos * 2 + 1] += (t - mid) * laz[pos];
laz[pos] = 0;
}
inline llint getSum(llint l, llint r, llint s, llint t, llint pos) {
if(l <= s && t <= r) return ar[pos];
llint mid = (s + t) >> 1;
llint ans = 0;
pushdown(l, r, s, t, pos);
if(l <= mid) ans += getSum(l, r, s, mid, pos * 2);
if(r > mid) ans += getSum(l, r, mid + 1, t, pos * 2 + 1);
return ans;
}
inline void update(llint l, llint r, llint val, llint s, llint t, llint pos) {
if(l <= s && t <= r) {
ar[pos] += (t - s + 1) * val;
laz[pos] += val;
return;
}
llint mid = (s + t) >> 1;
pushdown(l, r, s, t, pos);
if(l <= mid) update(l, r, val, s, mid, pos * 2);
if(r > mid) update(l, r, val, mid + 1, t, pos * 2 + 1);
ar[pos] = ar[pos * 2] + ar[pos * 2 + 1];
}
inline void multiUpdate(llint pos, llint val) {
update(tim[pos], tim[pos] + siz[pos] - 1, val, 1, ti, 1);
}
inline llint query(llint pos) {
llint ans = 0;
llint x = pos, y = 1;
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) std::swap(x, y);
ans += getSum(tim[top[x]], tim[x], 1, ti, 1);
x = fa[top[x]];
}
if(dep[x] > dep[y]) std::swap(x, y);
ans += getSum(tim[x], tim[y], 1, ti, 1);
return ans;
}
} a;
int main() {
/*==================================*/
#ifdef LOCAL_JUDGE
freopen("\\Codes\\in.in", "r", stdin);
freopen("\\Codes\\out.out", "w", stdout);
#endif
/*==================================*/
scanf("%lld%lld", &n, &m);
for(llint i = 1; i <= n; ++i) {
scanf("%lld", &a.we[i]);
}
llint u, v;
for(llint i = 1; i < n; ++i) {
scanf("%lld%lld", &u, &v);
a.addEdge(u, v);
a.addEdge(v, u);
}
a.findSize(1, 0);
a.partition(1, 1);
a.build(1, 1, ti);
llint opt;
for(llint i = 1; i <= m; ++i) {
scanf("%lld%lld", &opt, &u);
if(opt == 1) {
scanf("%lld", &v);
a.update(a.tim[u], a.tim[u], v, 1, ti, 1);
} else if(opt == 2) {
scanf("%lld", &v);
a.multiUpdate(u, v);
} else {
printf("%lld\n", a.query(u));
}
}
return 0;
}