bzoj 3772 精神污染
内容
Description
兵库县位于日本列岛的中央位置,北临日本海,南面濑户内海直通太平洋,中央部位是森林和山地,与拥有关西机场的大阪府比邻而居,是关西地区面积最大的县,是集经济和文化于一体的一大地区,是日本西部门户,海陆空交通设施发达。濑户内海沿岸气候温暖,多晴天,有日本少见的贸易良港神户港所在的神户市和曾是豪族城邑“城下町”的姬路市等大城市,还有以疗养地而闻名的六甲山地等。
兵库县官方也大力发展旅游,为了方便,他们在县内的N个旅游景点上建立了n-1条观光道,构成了一棵图论中的树。同时他们推出了M条观光线路,每条线路由两个节点x和y指定,经过的旅游景点就是树上x到y的唯一路径上的点。保证一条路径只出现一次。
你和你的朋友打算前往兵库县旅游,但旅行社还没有告知你们最终选择的观光线路是哪一条(假设是线路A)。这时候你得到了一个消息:在兵库北有一群丧心病狂的香菜蜜,他们已经选定了一条观光线路(假设是线路B),对这条路线上的所有景点都释放了【精神污染】。这个计划还有可能影响其他的线路,比如有四个景点1-2-3-4,而【精神污染】的路径是1-4,那么1-3,2-4,1-2等路径也被视为被完全污染了。
现在你想知道的是,假设随便选择两条不同的路径A和B,存在一条路径使得如果这条路径被污染,另一条路径也被污染的概率。换句话说,一条路径被另一条路径包含的概率。
Input
第一行两个整数N,M
接下来N-1行,每行两个数a,b,表示A和B之间有一条观光道。
接下来M行,每行两个数x,y,表示一条旅游线路。
Output
所求的概率,以最简分数形式输出。
Sample Input
5 3
1 2
2 3
3 4
2 5
3 5
2 5
1 4
Sample Output
1/3
样例解释
可以选择的路径对有(1,2),(1,3),(2,3),只有路径1完全覆盖路径2。
HINT
100%的数据满足:N,M<=100000
题解
恶心的题。。转个学姐的blog
对于每一条链计算能完全覆盖它的有多少条。
处理出来dfs序了之后,可以发现边大概分为三种情况:x和y的lca不是x和y中的某一个,x和y的lca是x或y,还有就是一条路径就是一个点。对于第一种情况,能覆盖它的路径一定是一个端点在x的子树里,一个端点在y的子树里。对于第二种情况,能覆盖它的路径一定是一个端点在y的子树里,另外一个端点在x的子树外。对于第三种情况,除了计算和第二种情况相似的之外,还要计算有多少条路径的lca是这个点,也就是覆盖了这个点。
用权值线段树套权值线段树即可,内层用动态开点,那么空间和时间都是O(nlogn)O(nlogn)。外层表示路径端点较小的dfs序,内层表示路径端点较大的dfs序。
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 100010;
int in[N], out[N], last, n, m, x, y;
struct data {
int x, y, l, las;
}a[N];
bool comp(data a, data b) {
return in[a.x] == in[b.x] ? in[a.y] < in[a.x] : in[a.x] < in[b.x];
}
struct edge {
int to, next;
}e[N << 1];
int st[N], tot, h[N], f[N][25], ds, cnt[N];
void add(int x, int y) {
e[++tot].next = st[x];
e[tot].to = y, st[x] = tot;
}
void dfs(int x, int pre) {
in[x] = ++ds, h[x] = h[pre] + 1;
for (int i = 1; i < 20; ++i) {
if (h[x] - (1 << i) < 1) break;
f[x][i] = f[f[x][i - 1]][i - 1];
}
for (int i = st[x]; i; i = e[i].next)
if (e[i].to != pre)
f[e[i].to][0] = x, dfs(e[i].to, x);
out[x] = ds;
}
int lca(int x, int y) {
if (x == y) return last = x;
if (h[x] < h[y]) swap(x, y);
for (int i = 19; i >= 0; --i)
while (h[f[x][i]] > h[y])
x = f[x][i];
last = x;
if (h[x] > h[y]) x = f[x][0];
if (x == y) return x;
for (int i = 19; i >= 0; --i)
if (f[x][i] != f[y][i])
x = f[x][i], y = f[y][i];
return f[x][0];
}
int root[N];
struct segt {
int l, r;
long long sum;
}pri[N * 20];
int pot;
void inserts(int pos, int &now, int l, int r) {
pri[++pot] = pri[now], now = pot, pri[now].sum++;
if (l == r) return;
int mid = l + r >> 1;
if (pos <= mid) inserts(pos, pri[now].l, l, mid);
else inserts(pos, pri[now].r, mid + 1, r);
}
long long query(int now, int l, int r, int L, int R) {
if (L <= l && r <= R) return pri[now].sum;
int mid = l + r >> 1; long long ans = 0;
if (L <= mid) ans += query(pri[now].l, l, mid, L, R);
if (R > mid) ans += query(pri[now].r, mid + 1, r, L, R);
return ans;
}
long long gcd(long long a, long long b) { return !b ? a : gcd(b, a % b); }
long long ans;
main() {
scanf("%d%d", &n, &m);
for (int i = 1; i < n; ++i)
scanf("%d%d", &x, &y),
add(x, y), add(y, x);
dfs(1, 0);
for (int i = 1; i <= m; ++i) {
scanf("%d%d", &x, &y);
if (in[x] > in[y]) swap(x, y);
a[i].x = x, a[i].y = y, a[i].l = lca(x, y);
a[i].las = last, cnt[a[i].l]++;
}
sort(a + 1, a + m + 1, comp);
for (int i = 1; i <= m; ++i) {
if (i != 1 && in[a[i - 1].x] != in[a[i].x] - 1)
for (int j = in[a[i - 1].x] + 1; j < in[a[i].x]; ++j)
root[j] = root[j - 1];
root[in[a[i].x]] = root[in[a[i - 1].x]];
inserts(in[a[i].y], root[in[a[i].x]], 1, n);
}
for (int i = in[a[m].x] + 1; i <= n; ++i) root[i] = root[i - 1];
for (int i = 1; i <= m; ++i) {
int l = a[i].l; last = a[i].las;
if (l != a[i].x) {
int o = in[a[i].x], u = out[a[i].x], k = in[a[i].y], g = out[a[i].y];
ans += query(root[u], 1, n, k, g) - query(root[o - 1], 1, n, k, g);
} else {
if (a[i].x != a[i].y) {
int o = 1, u = in[last] - 1, k = in[a[i].y], g = out[a[i].y];
if (o <= u) ans += query(root[u], 1, n, k, g) - query(root[o - 1], 1, n, k, g);
o = in[a[i].y], u = out[a[i].y], k = out[last] + 1, g = n;
if (k <= g) ans += query(root[u], 1, n, k, g) - query(root[o - 1], 1, n, k, g);
} else {
int o = 1, u = in[last] - 1, k = in[last], g = out[last];
if (o <= u) ans += query(root[u], 1, n, k, g) - query(root[o - 1], 1, n, k, g);
o = in[last], u = out[last], k = out[last] + 1, g = n;
if (k <= g) ans += query(root[u], 1, n, k, g) - query(root[o - 1], 1, n, k, g);
ans += cnt[last];
}
}
}
ans -= m; long long po = (long long)m * ((long long)m - 1) / 2ll;
long long gc = gcd(ans, po); ans /= gc, po /= gc;
printf("%lld/%lld", ans, po);
}