// Tree value assignment.
//
// Nodes 1..m carry fixed integer values read from the input; nodes m+1..n are
// free. The program prints the minimum, over all choices of values for the
// free nodes, of the sum of |x_a - x_b| over all tree edges (a, b).
//
// Input:  n m, then n-1 edges "a b" (1-based node ids), then m integers, the
//         fixed values of nodes 1..m.
// Output: the minimum total cost as a 64-bit integer.
//
// Method: process the tree bottom-up, rooted at node m+1 (or node 1 when every
// node is fixed). For a processed subtree rooted at v, the cheapest cost of
// that subtree plus the edge to its parent, as a function of the parent's
// value x, is c_v + dist(x, I_v), where I_v = [lo_v, hi_v] is the set of
// optimal values for v. This holds because every subtree cost function is
// convex piecewise-linear with integer slopes, so clipping its slope to
// [-1, 1] (the effect of the parent edge) turns it into exactly
// min + dist(x, argmin). For a free node with k children,
// sum_v dist(x, I_v) is minimised exactly between the k-th and (k+1)-th
// smallest of the 2k child endpoints, because
// dist(x, [a, b]) = (|x - a| + |x - b| - (b - a)) / 2.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <utility>
#include <vector>

namespace {

// Closed interval [first, second] of optimal values for a subtree root.
using Interval = std::pair<long long, long long>;

// Stand-in for +/- infinity. It is the optimal-value interval of a free node
// with no children (such a node is unconstrained). It is kept well below
// LLONG_MAX; distance_to never subtracts it from a finite value.
const long long kUnbounded = std::numeric_limits<long long>::max() / 4;

// Distance from x to the closed interval iv; 0 when x lies inside it.
long long distance_to(long long x, const Interval& iv) {
    if (x < iv.first) {
        return iv.first - x;
    }
    if (x > iv.second) {
        return x - iv.second;
    }
    return 0;
}

// Minimum, over assignments of the free nodes, of sum |x_a - x_b| over
// edge_list.
//   n           - number of nodes, ids 1..n; edge_list is expected to be a
//                 tree on them.
//   m           - nodes 1..m are fixed to fixed_value[1..m]; fixed_value must
//                 have at least m+1 entries. Nodes m+1..n are free.
// Uses an iterative traversal (BFS order, processed in reverse) so that deep
// trees, such as a path of ~500000 nodes, cannot overflow the call stack.
long long min_total_edge_cost(int n, int m,
                              const std::vector<std::pair<int, int>>& edge_list,
                              const std::vector<long long>& fixed_value) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const std::pair<int, int>& e : edge_list) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }

    // Root at the first free node; if every node is fixed, any node will do.
    const int root = (m < n) ? m + 1 : 1;

    // parent[v] == 0 means "not reached yet"; the root is marked with -1.
    std::vector<int> parent(n + 1, 0);
    std::vector<int> order;  // BFS order: every node appears after its parent.
    order.reserve(n);
    parent[root] = -1;
    order.push_back(root);
    for (std::size_t i = 0; i < order.size(); ++i) {
        const int u = order[i];
        for (int v : adj[u]) {
            if (parent[v] == 0) {
                parent[v] = u;
                order.push_back(v);
            }
        }
    }

    std::vector<Interval> best(n + 1);
    std::vector<long long> endpoints;  // Scratch buffer, reused per node.
    long long cost = 0;
    // Reverse BFS order visits every child before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        if (u <= m) {
            // Fixed node: its only admissible value is its input value.
            best[u] = Interval(fixed_value[u], fixed_value[u]);
        } else {
            endpoints.clear();
            for (int v : adj[u]) {
                if (parent[v] == u) {
                    endpoints.push_back(best[v].first);
                    endpoints.push_back(best[v].second);
                }
            }
            if (endpoints.empty()) {
                // Free node with no children: any value is optimal, no cost.
                best[u] = Interval(-kUnbounded, kUnbounded);
                continue;
            }
            // Median interval of the 2k endpoints of the k child intervals.
            std::sort(endpoints.begin(), endpoints.end());
            const std::size_t k = endpoints.size() / 2;
            best[u] = Interval(endpoints[k - 1], endpoints[k]);
        }
        // Charge each child edge with u placed at the low end of its optimal
        // interval; every point of that interval yields the same total.
        const long long anchor = best[u].first;
        for (int v : adj[u]) {
            if (parent[v] == u) {
                cost += distance_to(anchor, best[v]);
            }
        }
    }
    return cost;
}

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d %d", &n, &m) != 2 || n < 1 || m < 0) {
        return 1;
    }

    std::vector<std::pair<int, int>> edge_list(n - 1);
    for (std::pair<int, int>& e : edge_list) {
        if (std::scanf("%d %d", &e.first, &e.second) != 2) {
            return 1;
        }
        // Reject ids outside 1..n instead of indexing out of bounds.
        if (e.first < 1 || e.first > n || e.second < 1 || e.second > n) {
            return 1;
        }
    }

    std::vector<long long> fixed_value(m + 1, 0);
    for (int i = 1; i <= m; ++i) {
        int r = 0;
        if (std::scanf("%d", &r) != 1) {
            return 1;
        }
        fixed_value[i] = r;
    }

    std::printf("%lld\n", min_total_edge_cost(n, m, edge_list, fixed_value));
    return 0;
}