首页 > 题解 > bzoj 3772 精神污染

bzoj 3772 精神污染

Description

兵库县位于日本列岛的中央位置,北临日本海,南面濑户内海直通太平洋,中央部位是森林和山地,与拥有关西机场的大阪府比邻而居,是关西地区面积最大的县,是集经济和文化于一体的一大地区,是日本西部门户,海陆空交通设施发达。濑户内海沿岸气候温暖,多晴天,有日本少见的贸易良港神户港所在的神户市和曾是豪族城邑“城下町”的姬路市等大城市,还有以疗养地而闻名的六甲山地等。
兵库县官方也大力发展旅游,为了方便,他们在县内的N个旅游景点上建立了n-1条观光道,构成了一棵图论中的树。同时他们推出了M条观光线路,每条线路由两个节点x和y指定,经过的旅游景点就是树上x到y的唯一路径上的点。保证一条路径只出现一次。
你和你的朋友打算前往兵库县旅游,但旅行社还没有告知你们最终选择的观光线路是哪一条(假设是线路A)。这时候你得到了一个消息:在兵库北有一群丧心病狂的香菜蜜,他们已经选定了一条观光线路(假设是线路B),对这条路线上的所有景点都释放了【精神污染】。这个计划还有可能影响其他的线路,比如有四个景点1-2-3-4,而【精神污染】的路径是1-4,那么1-3,2-4,1-2等路径也被视为被完全污染了。
现在你想知道的是,假设随便选择两条不同的路径A和B,存在一条路径使得如果这条路径被污染,另一条路径也被污染的概率。换句话说,一条路径被另一条路径包含的概率。

Input

第一行两个整数N,M
接下来N-1行,每行两个数a,b,表示A和B之间有一条观光道。
接下来M行,每行两个数x,y,表示一条旅游线路。

Output

所求的概率,以最简分数形式输出。

Sample Input

5 3

1 2

2 3

3 4

2 5

3 5

2 5

1 4

Sample Output

1/3

样例解释

可以选择的路径对有(1,2),(1,3),(2,3),只有路径1完全覆盖路径2。

HINT

100%的数据满足:N,M<=100000

题解

恶心的题。。转个学姐的blog

对于每一条链计算能完全覆盖它的有多少条。
处理出来dfs序了之后,可以发现边大概分为三种情况:x和y的lca不是x和y中的某一个,x和y的lca是x或y,还有就是一条路径就是一个点。对于第一种情况,能覆盖它的路径一定是一个端点在x的子树里,一个端点在y的子树里。对于第二种情况,能覆盖它的路径一定是一个端点在y的子树里,另外一个端点在x的子树外。对于第三种情况,除了计算和第二种情况相似的之外,还要计算有多少条路径的lca是这个点,也就是覆盖了这个点。
用权值线段树套权值线段树即可,内层用动态开点,那么空间和时间都是O(nlogn)O(nlogn)。外层表示路径端点较小的dfs序,内层表示路径端点较大的dfs序。

#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 100010;
int in[N], out[N], last, n, m, x, y;
struct data {
    int x, y, l, las;
}a[N];
bool comp(data a, data b) {
    return in[a.x] == in[b.x] ? in[a.y] < in[a.x] : in[a.x] < in[b.x];
}
struct edge {
    int to, next;
}e[N << 1];
int st[N], tot, h[N], f[N][25], ds, cnt[N];
void add(int x, int y) {
    e[++tot].next = st[x];
    e[tot].to = y, st[x] = tot;
}
void dfs(int x, int pre) {
    in[x] = ++ds, h[x] = h[pre] + 1;
    for (int i = 1; i < 20; ++i) {
        if (h[x] - (1 << i) < 1) break;
        f[x][i] = f[f[x][i - 1]][i - 1];
    }
    for (int i = st[x]; i; i = e[i].next)
        if (e[i].to != pre)
            f[e[i].to][0] = x, dfs(e[i].to, x);
    out[x] = ds;
}
int lca(int x, int y) {
    if (x == y) return last = x;
    if (h[x] < h[y]) swap(x, y);
    for (int i = 19; i >= 0; --i)
        while (h[f[x][i]] > h[y])
            x = f[x][i];
    last = x;
    if (h[x] > h[y]) x = f[x][0];
    if (x == y) return x;
    for (int i = 19; i >= 0; --i)
        if (f[x][i] != f[y][i])
            x = f[x][i], y = f[y][i];
    return f[x][0];
}
int root[N];
struct segt {
    int l, r;
    long long sum;
}pri[N * 20];
int pot;
void inserts(int pos, int &now, int l, int r) {
    pri[++pot] = pri[now], now = pot, pri[now].sum++;
    if (l == r) return;
    int mid = l + r >> 1;
    if (pos <= mid) inserts(pos, pri[now].l, l, mid);
    else inserts(pos, pri[now].r, mid + 1, r);
}
long long query(int now, int l, int r, int L, int R) {
    if (L <= l && r <= R) return pri[now].sum;
    int mid = l + r >> 1; long long ans = 0;
    if (L <= mid) ans += query(pri[now].l, l, mid, L, R);
    if (R >  mid) ans += query(pri[now].r, mid + 1, r, L, R);
    return ans;
}
long long gcd(long long a, long long b) { return !b ? a : gcd(b, a % b); }
long long ans;
main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i < n; ++i)
        scanf("%d%d", &x, &y),
        add(x, y), add(y, x);
    dfs(1, 0);
    for (int i = 1; i <= m; ++i) {
        scanf("%d%d", &x, &y);
        if (in[x] > in[y]) swap(x, y);
        a[i].x = x, a[i].y = y, a[i].l = lca(x, y);
        a[i].las = last, cnt[a[i].l]++;
    }
    sort(a + 1, a + m + 1, comp);
    for (int i = 1; i <= m; ++i) {
        if (i != 1 && in[a[i - 1].x] != in[a[i].x] - 1)
            for (int j = in[a[i - 1].x] + 1; j < in[a[i].x]; ++j)
                root[j] = root[j - 1];
        root[in[a[i].x]] = root[in[a[i - 1].x]];
        inserts(in[a[i].y], root[in[a[i].x]], 1, n);
    }
    for (int i = in[a[m].x] + 1; i <= n; ++i) root[i] = root[i - 1];
    for (int i = 1; i <= m; ++i) {
        int l = a[i].l; last = a[i].las;
        if (l != a[i].x) {
            int o = in[a[i].x], u = out[a[i].x], k = in[a[i].y], g = out[a[i].y];
            ans += query(root[u], 1, n, k, g) - query(root[o - 1], 1, n, k, g);
        } else {
            if (a[i].x != a[i].y) {
                int o = 1, u = in[last] - 1, k = in[a[i].y], g = out[a[i].y];
                if (o <= u) ans += query(root[u], 1, n, k, g) - query(root[o - 1], 1, n, k, g);
                o = in[a[i].y], u = out[a[i].y], k = out[last] + 1, g = n;
                if (k <= g) ans += query(root[u], 1, n, k, g) - query(root[o - 1], 1, n, k, g);
            } else {
                int o = 1, u = in[last] - 1, k = in[last], g = out[last];
                if (o <= u) ans += query(root[u], 1, n, k, g) - query(root[o - 1], 1, n, k, g);
                o = in[last], u = out[last], k = out[last] + 1, g = n;
                if (k <= g) ans += query(root[u], 1, n, k, g) - query(root[o - 1], 1, n, k, g);
                ans += cnt[last];
            }
        }
    }
    ans -= m; long long po = (long long)m * ((long long)m - 1) / 2ll;
    long long gc = gcd(ans, po); ans /= gc, po /= gc;
    printf("%lld/%lld", ans, po);
}