题目链接
https://codeforces.com/problemset/problem/191/C
思路
一道比较模板化的 LCA + 树上差分题。
先预处理出这棵树的 LCA，之后对于每一对 $(a_i, b_i)$，设 $z = \operatorname{lca}(a_i, b_i)$，在树上做边差分：$d[a_i]$ 与 $d[b_i]$ 各加 $1$，$d[z]$ 减 $2$；最后用 DFS 对差分数组求子树和即可，此时 $d[x]$ 就是 $x$ 与其父节点之间那条边被经过的次数。
树上差分求和时记得从叶子向根节点累加（把子节点的值加到父节点上），不要弄反。
代码
#include <bits/stdc++.h>
using namespace std;
#define int long long
// Capacity of the global per-node / per-edge arrays (usable indices 1..N-1).
constexpr int N = 100005;
// Large "infinity" sentinel value.
constexpr int inf = 0x3f3f3f3f3f3f3f3f;

int n = 0;      // number of tree nodes (read by solve())
int k = 0;      // number of (a, b) path queries (read by solve())
int u[N] = {};  // u[i], v[i]: endpoints of the i-th input edge, 1-indexed
int v[N] = {};
int d[N] = {};  // tree difference array, indexed by node
/**
 * Binary-lifting lowest common ancestor for a tree on nodes 1..n.
 *
 * Node 0 is a sentinel: it stands for "no ancestor" (fa[root][k] == 0) and has
 * depth 0, while real nodes have depth >= 1, so any jump that would overshoot
 * the root is rejected by the depth comparisons in lca().
 *
 * Usage: construct with n, call add_edge() for every edge, bfs(root), then lca(a, b).
 */
struct LCA
{
    std::vector<std::vector<int>> mp;  // adjacency lists, indexed by node
    std::vector<int> depth;            // depth[root] == 1, depth[0] == 0 (sentinel)
    std::vector<std::vector<int>> fa;  // fa[v][k]: 2^k-th ancestor of v, or 0 if none
    int levels = 20;                   // number of lifting levels (columns of fa)

    LCA() {}
    LCA(int n) { init(n); }

    // Allocates storage for nodes 0..n. With `levels` levels, lca() can lift a
    // node by at most 2^levels - 1 steps, which must cover the largest possible
    // depth gap n - 1; 20 levels is kept as the minimum table width.
    void init(int n)
    {
        levels = 20;
        while ((1LL << levels) < static_cast<long long>(n)) levels++;
        mp.resize(n + 1);
        depth.resize(n + 1);
        fa.assign(n + 1, std::vector<int>(levels));
    }

    // Adds the undirected edge a-b.
    void add_edge(int a, int b)
    {
        mp[a].push_back(b);
        mp[b].push_back(a);
    }

    // Roots the tree at `root`: fills depth[] and fa[][] for every node reachable
    // from root. Nodes not reached keep depth == numeric_limits<int>::max().
    void bfs(int root)
    {
        std::fill(depth.begin(), depth.end(), std::numeric_limits<int>::max());
        depth[0] = 0, depth[root] = 1;
        // The root has no ancestors; clear any row left by an earlier bfs() call
        // that used a different root.
        std::fill(fa[root].begin(), fa[root].end(), 0);
        std::queue<int> q;
        q.push(root);
        while (!q.empty())
        {
            int u = q.front();
            q.pop();
            for (int j : mp[u])
            {
                if (depth[j] > depth[u] + 1)  // j has not been reached yet
                {
                    depth[j] = depth[u] + 1;
                    q.push(j);
                    fa[j][0] = u;
                    // Every ancestor of j was discovered (and its row filled)
                    // before j, so its row can be reused here.
                    for (int k = 1; k < levels; k++)
                        fa[j][k] = fa[fa[j][k - 1]][k - 1];
                }
            }
        }
    }

    // Returns the lowest common ancestor of a and b. Requires bfs() first.
    int lca(int a, int b)
    {
        if (depth[a] < depth[b]) std::swap(a, b);
        // Lift a up to b's depth.
        for (int k = levels - 1; k >= 0; k--)
            if (depth[fa[a][k]] >= depth[b]) a = fa[a][k];
        if (a == b) return a;
        // Lift both to the children of their LCA.
        for (int k = levels - 1; k >= 0; k--)
        {
            if (fa[a][k] != fa[b][k])
            {
                a = fa[a][k];
                b = fa[b][k];
            }
        }
        return fa[a][0];
    }
};
void solve()
{cin >> n;LCA tree(n);map<int, int>st;for (int i = 1; i < n; i++){cin >> u[i] >> v[i];tree.add_edge(u[i], v[i]);st[u[i] * n + v[i]] = i;st[v[i] * n + u[i]] = i;}tree.bfs(1);cin >> k;for (int i = 1, a, b; i <= k; i++){cin >> a >> b;int zu = tree.lca(a, b);if (zu != a && zu != b){d[a]++, d[b]++, d[zu] -= 2;}else{if (zu == a) d[b]++, d[a]--;else d[a]++, d[b]--;}}auto dfs1 = [&](auto dfs1, int u, int fu)->void {for (int j : tree.mp[u]){if (j == fu) continue;dfs1(dfs1, j, u);d[u] += d[j];}};dfs1(dfs1, 1, -1);//树上差分->前缀和vector<int>ans(n);auto dfs2 = [&](auto dfs2, int u, int fu)-> void {for (int j : tree.mp[u]){if (j == fu) continue;int idx = st[u * n + j];ans[idx] = d[j];dfs2(dfs2, j, u);}};dfs2(dfs2, 1, -1);for (int i = 1; i < n; i++){cout << ans[i] << " ";}cout << endl;
}
signed main()
{ios::sync_with_stdio(false);cin.tie(0), cout.tie(0);int test = 1;// cin >> test;for (int i = 1; i <= test; i++){solve();}return 0;
}