bzoj 4530 [Bjoi2014]大融合
内容
Description
小强要在N个孤立的星球上建立起一套通信系统。这套通信系统就是连接N个点的一个树。
这个树的边是一条一条添加上去的。在某个时刻,一条边的负载就是它所在的当前能够
联通的树上路过它的简单路径的数量。
例如,在上图中,现在一共有了5条边。其中,(3,8)这条边的负载是6,因
为有六条简单路径2-3-8,2-3-8-7,3-8,3-8-7,4-3-8,4-3-8-7路过了(3,8)。
现在,你的任务就是随着边的添加,动态的回答小强对于某些边的负载的
询问。
Input
第一行包含两个整数N,Q,表示星球的数量和操作的数量。星球从1开始编号。
接下来的Q行,每行是如下两种格式之一:
A x y 表示在x和y之间连一条边。保证之前x和y是不联通的。
Q x y 表示询问(x,y)这条边上的负载。保证x和y之间有一条边。
1≤N,Q≤100000
Output
对每个查询操作,输出被查询的边的负载。
Sample Input
8 6
A 2 3
A 3 4
A 3 8
A 8 7
A 6 5
Q 3 8
Sample Output
6
题解
答案就是两边的子树大小乘积。或者子树大小乘父亲的联通块大小-子树大小
用并查集找到节点的祖先,维护子树size
这个东西可以用线段树合并来做,查询就是查询dfs序上的一段区间
#include <cstdio>
#include <algorithm>
#define mid (l + r >> 1)
#define lson t[rt].l, l, mid
#define rson t[rt].r, mid + 1, r
using namespace std;
const int N = 200010;
struct Ufs {
int fa[N];
void init(int n) { for (int i = 1; i <= n; ++i) fa[i] = i; }
int finds(int x) { return fa[x] != x ? fa[x] = finds(fa[x]) : x; }
}ufs;
struct opts {
int t, x, y;
}opt[N];
struct segt {
int l, r, sum;
}t[N * 20];
struct edge {
int to, next;
}e[N];
int st[N], tot, sz, dfs_id, in[N], out[N], root[N], h[N], fa[N], n, q;
char s[10];
void add(int x, int y) {
e[++tot].next = st[x];
e[tot].to = y, st[x] = tot;
}
void update(int rt) { t[rt].sum = t[t[rt].l].sum + t[t[rt].r].sum; }
void inserts(int pos, int &rt, int l, int r) {
rt = ++sz;
if (l == r) { t[rt].sum = 1; return; }
if (pos <= mid) inserts(pos, lson);
else inserts(pos, rson);
update(rt);
}
int merges(int x, int y) {
if (!x) return y;
if (!y) return x;
t[x].l = merges(t[x].l, t[y].l);
t[x].r = merges(t[x].r, t[y].r);
update(x); return x;
}
void dfs(int x) {
in[x] = ++dfs_id, inserts(in[x], root[x], 1, n);
for (int i = st[x]; i; i = e[i].next)
if (e[i].to != fa[x])
fa[e[i].to] = x, h[e[i].to] = h[x] + 1,
dfs(e[i].to);
out[x] = dfs_id;
}
int query(int rt, int l, int r, int L, int R) {
if (L <= l && R >= r) return t[rt].sum;
int ans = 0;
if (L <= mid) ans += query(lson, L, R);
if (R > mid) ans += query(rson, L, R);
return ans;
}
main() {
scanf("%d%d", &n, &q);
for (int i = 1; i <= q; ++i) {
scanf("%s%d%d", s, &opt[i].x, &opt[i].y);
if (s[0] == 'A') opt[i].t = 1;
if (opt[i].t)
add(opt[i].x, opt[i].y), add(opt[i].y, opt[i].x);
}
for (int i = 1; i <= n; ++i) if (!fa[i]) dfs(i);
ufs.init(n);
for (int i = 1; i <= q; ++i)
if (opt[i].t) {
int x = ufs.finds(opt[i].x), y = ufs.finds(opt[i].y);
if (h[x] > h[y]) swap(x, y);
root[x] = merges(root[x], root[y]);
ufs.fa[y] = x;
} else {
int x = opt[i].x, y = opt[i].y;
if (h[x] < h[y]) swap(x, y);
int fx = ufs.finds(x);
long long ans = query(root[fx], 1, n, in[x], out[x]);
printf("%lld\n", (1ll * t[root[fx]].sum - ans) * ans);
}
}