首页 > 题解 > bzoj 4530 [Bjoi2014]大融合

bzoj 4530 [Bjoi2014]大融合

Description

小强要在N个孤立的星球上建立起一套通信系统。这套通信系统就是连接N个点的一个树。
这个树的边是一条一条添加上去的。在某个时刻,一条边的负载就是它所在的当前能够
联通的树上路过它的简单路径的数量。

例如,在上图中,现在一共有了5条边。其中,(3,8)这条边的负载是6,因
为有六条简单路径2-3-8,2-3-8-7,3-8,3-8-7,4-3-8,4-3-8-7路过了(3,8)。
现在,你的任务就是随着边的添加,动态的回答小强对于某些边的负载的
询问。

Input

第一行包含两个整数N,Q,表示星球的数量和操作的数量。星球从1开始编号。
接下来的Q行,每行是如下两种格式之一:
A x y 表示在x和y之间连一条边。保证之前x和y是不联通的。
Q x y 表示询问(x,y)这条边上的负载。保证x和y之间有一条边。
1≤N,Q≤100000

Output

对每个查询操作,输出被查询的边的负载。

Sample Input

8 6

A 2 3

A 3 4

A 3 8

A 8 7

A 6 5

Q 3 8

Sample Output

6

题解

答案就是两边的子树大小乘积。或者子树大小乘父亲的联通块大小-子树大小

用并查集找到节点的祖先,维护子树size

这个东西可以用线段树合并来做,查询就是查询dfs序上的一段区间

#include <cstdio>
#include <algorithm>
#define mid (l + r >> 1)
#define lson t[rt].l, l, mid
#define rson t[rt].r, mid + 1, r
using namespace std;
const int N = 200010;
struct Ufs {
    int fa[N];
    void init(int n) { for (int i = 1; i <= n; ++i) fa[i] = i; }
    int finds(int x) { return fa[x] != x ? fa[x] = finds(fa[x]) : x; }
}ufs;
struct opts {
    int t, x, y;
}opt[N];
struct segt {
    int l, r, sum;
}t[N * 20];
struct edge {
    int to, next;
}e[N];
int st[N], tot, sz, dfs_id, in[N], out[N], root[N], h[N], fa[N], n, q;
char s[10];
void add(int x, int y) {
    e[++tot].next = st[x];
    e[tot].to = y, st[x] = tot;
}
void update(int rt) { t[rt].sum = t[t[rt].l].sum + t[t[rt].r].sum; }
void inserts(int pos, int &rt, int l, int r) {
    rt = ++sz;
    if (l == r) { t[rt].sum = 1; return; }
    if (pos <= mid) inserts(pos, lson);
    else            inserts(pos, rson);
    update(rt);
}
int merges(int x, int y) {
    if (!x) return y;
    if (!y) return x;
    t[x].l = merges(t[x].l, t[y].l);
    t[x].r = merges(t[x].r, t[y].r);
    update(x); return x;
}
void dfs(int x) {
    in[x] = ++dfs_id, inserts(in[x], root[x], 1, n);
    for (int i = st[x]; i; i = e[i].next)
        if (e[i].to != fa[x])
            fa[e[i].to] = x, h[e[i].to] = h[x] + 1,
            dfs(e[i].to);
    out[x] = dfs_id;
}
int query(int rt, int l, int r, int L, int R) {
    if (L <= l && R >= r) return t[rt].sum;
    int ans = 0;
    if (L <= mid) ans += query(lson, L, R);
    if (R >  mid) ans += query(rson, L, R);
    return ans;
}
main() {
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= q; ++i) {
        scanf("%s%d%d", s, &opt[i].x, &opt[i].y);
        if (s[0] == 'A') opt[i].t = 1;
        if (opt[i].t)
            add(opt[i].x, opt[i].y), add(opt[i].y, opt[i].x);
    }
    for (int i = 1; i <= n; ++i) if (!fa[i]) dfs(i);
    ufs.init(n);
    for (int i = 1; i <= q; ++i)
        if (opt[i].t) {
            int x = ufs.finds(opt[i].x), y = ufs.finds(opt[i].y);
            if (h[x] > h[y]) swap(x, y);
            root[x] = merges(root[x], root[y]);
            ufs.fa[y] = x;
        } else {
            int x = opt[i].x, y = opt[i].y;
            if (h[x] < h[y]) swap(x, y);
            int fx = ufs.finds(x);
            long long ans = query(root[fx], 1, n, in[x], out[x]);
            printf("%lld\n", (1ll * t[root[fx]].sum - ans) * ans);
        }
}