// Tree value assignment.
//
// Input: n m, then n-1 tree edges, then the fixed values of vertices 1..m.
// Output: the minimum, over integer values for the remaining vertices, of the
// sum of |value(u) - value(v)| over all edges (u, v).
//
// Every degree-1 vertex is treated as fixed. The code assumes these are exactly
// vertices 1..m and that n itself is internal unless n == m == 2
// (TODO confirm against the problem statement).
//
// Method: root the tree at n. For a vertex v let g_v(x) be the cheapest cost of
// v's subtree when v holds x. Seen from the parent across one edge, g_c becomes
// min_y(|x - y| + g_c(y)). For convex piecewise-linear g_c with integer slopes,
// that equals min(g_c) + dist(x, [from[c], to[c]]), where [from[c], to[c]] is
// the set of minimisers of g_c. So an internal vertex only has to minimise a sum
// of distance-to-interval functions: the minimum is added to the answer and the
// minimising interval is passed up to the parent.
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <utility>
#include <vector>

using std::vector;
using std::sort;

using ll = long long;

const int maxN = 5e5 + 10;

// For a fixed vertex: its value (from == to). For a processed internal vertex:
// the interval [from, to] of values minimising the cost of its subtree.
int from[maxN];
int to[maxN];
vector<int> nbh[maxN];  // adjacency lists, vertices numbered 1..n
int n, m;

long long res;  // accumulated minimum cost

// Minimises G(x) = sum over c in `children` of dist(x, [from[c], to[c]]).
// Adds min G to `res` and stores the minimising interval in from[v], to[v].
// Precondition: `children` is non-empty and from[c] <= to[c] for every child.
static void mergeChildren(int v, const vector<int>& children) {
    // Breakpoints as (coordinate, kind) with kind 0 = left end, 1 = right end.
    // Pair ordering sorts by coordinate and puts left ends first on ties.
    // Unlike a +from / -to sign encoding, this stays correct for values <= 0.
    vector<std::pair<int, int>> pts;
    pts.reserve(2 * children.size());
    for (int c : children) {
        pts.emplace_back(from[c], 0);
        pts.emplace_back(to[c], 1);
    }
    sort(pts.begin(), pts.end());

    // G at the leftmost breakpoint: every interval lies at or to its right,
    // so each child contributes its distance to its left end.
    const ll x0 = pts[0].first;
    ll cur = 0;
    for (int c : children) cur += from[c] - x0;

    ll best = cur;
    int lo = pts[0].first;
    int hi = lo;
    // The slope of G between pts[i-1] and pts[i] is
    // (#right ends at indices < i) - (#left ends at indices >= i).
    int leftAhead = static_cast<int>(children.size());
    int rightBehind = 0;
    for (size_t i = 1; i < pts.size(); ++i) {
        if (pts[i - 1].second == 0) {
            --leftAhead;
        } else {
            ++rightBehind;
        }
        // long long gap: coordinates may span more than INT_MAX.
        cur += (static_cast<ll>(pts[i].first) - pts[i - 1].first) *
               (rightBehind - leftAhead);
        if (cur < best) {
            best = cur;
            lo = hi = pts[i].first;
        } else if (cur == best) {
            hi = pts[i].first;  // G is convex: equal values extend the interval
        }
    }

    res += best;
    from[v] = lo;
    to[v] = hi;
}

// Processes the whole tree hanging from `root` bottom-up, adding each internal
// vertex's minimum to `res`. Uses an explicit stack instead of recursion, so a
// deep tree (e.g. a path of 5e5 vertices) cannot overflow the call stack.
void dfs(int root) {
    if (nbh[root].size() == 1) return;  // a fixed leaf: nothing to merge

    vector<int> parent(n + 1, 0);  // 0 = no parent (vertices are 1-based)
    vector<int> order;             // parents always precede their children
    order.reserve(n);
    vector<int> pending(1, root);
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        if (nbh[v].size() == 1) continue;  // fixed leaf: do not descend
        for (int a : nbh[v]) {
            if (a != parent[v]) {
                parent[a] = v;
                pending.push_back(a);
            }
        }
    }

    // Reverse order visits every child before its parent.
    vector<int> children;
    for (size_t k = order.size(); k-- > 0;) {
        const int v = order[k];
        if (nbh[v].size() == 1) continue;  // leaf keeps its fixed value
        children.clear();
        for (int a : nbh[v]) {
            if (a != parent[v]) children.push_back(a);
        }
        if (children.empty()) continue;  // isolated root (n == 1): cost 0
        mergeChildren(v, children);
    }
}

int main() {
    if (scanf("%d%d", &n, &m) != 2) return 1;
    for (int i = 1; i < n; ++i) {
        int a, b;
        scanf("%d%d", &a, &b);
        nbh[a].push_back(b);
        nbh[b].push_back(a);
    }
    for (int i = 1; i <= m; ++i) {
        scanf("%d", &from[i]);
        to[i] = from[i];
    }

    if (n == m && m == 2) {
        // Two fixed vertices joined by one edge; both have degree 1, so the
        // generic traversal would treat root n as a leaf.
        res = std::llabs(static_cast<ll>(from[1]) - from[2]);
    } else {
        dfs(n);
    }
    printf("%lld\n", res);
    return 0;
}