#include <iostream>
#include <cstdio>
#include <string>
#include <vector>
#include <set>
#include <map>
#include <queue>
#include <cmath>
#include <algorithm>
#include <sstream>
#include <stack>
#include <cstring>
#include <iomanip>
#include <ctime>
#include <cassert>

// Tree labelling.
//
// Input: n m, then n-1 edges (1-based vertices), then the fixed integer
// labels of vertices 1..m. The program assigns labels to the remaining
// vertices and prints the sum over all edges of |label(u) - label(v)|
// obtained by the two passes below (the task presumably asks for the
// minimum of that sum — confirm against the statement).
//
// The tree is rooted at vertex n.
//  1. Bottom-up: every vertex gets an interval [lo, hi] of optimal labels for
//     its subtree. A labelled vertex's interval is the single point
//     [label, label]. For an unlabelled vertex with k children, each child c
//     contributes dist(y, [lo_c, hi_c]) + const to the cost of label y, so the
//     sum is minimised between the k-th and (k+1)-th smallest of the 2k child
//     endpoints.
//  2. Top-down: the root takes lo(root); every other vertex takes the point of
//     its interval nearest to its parent's label, and that edge's cost is
//     added to the answer.
//
// Both passes are iterative so that deep (path-shaped) trees cannot exhaust
// the call stack.

namespace {

// |a - b| computed in 64 bits so that large labels cannot overflow int.
long long abs_diff(int a, int b) {
    const long long d = static_cast<long long>(a) - b;
    return d < 0 ? -d : d;
}

// Returns the vertices reachable from `root` in an order where every vertex
// appears after its parent, and records each vertex's parent in `parent`
// (parent[root] = -1). `parent` must be sized to adj.size().
std::vector<int> parent_first_order(const std::vector<std::vector<int> >& adj,
                                    int root, std::vector<int>& parent) {
    std::vector<int> order;
    order.reserve(adj.size());
    std::vector<int> pending(1, root);
    parent[root] = -1;
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        for (int x : adj[v]) {
            if (x != parent[v]) {
                parent[x] = v;
                pending.push_back(x);  // popped later, hence after v in `order`
            }
        }
    }
    return order;
}

}  // namespace

int main() {
    int n = 0, m = 0;
    if (std::scanf("%d %d", &n, &m) != 2 || n < 1 || m < 0 || m > n) {
        return 1;  // malformed input
    }

    std::vector<std::vector<int> > adj(n + 1);
    for (int e = 0; e < n - 1; ++e) {
        int a = 0, b = 0;
        if (std::scanf("%d %d", &a, &b) != 2) return 1;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // lo[v], hi[v]: interval of optimal labels for v given its subtree.
    // Unlabelled vertices without children keep the default [0, 0].
    std::vector<int> lo(n + 1, 0), hi(n + 1, 0);
    for (int v = 1; v <= m; ++v) {
        int x = 0;
        if (std::scanf("%d", &x) != 1) return 1;
        lo[v] = hi[v] = x;
    }

    const int root = n;
    std::vector<int> parent(n + 1, -1);
    const std::vector<int> order = parent_first_order(adj, root, parent);

    // Pass 1, children before parents: median interval of child endpoints.
    std::vector<int> endpoints;  // reused buffer, capacity kept across vertices
    for (std::vector<int>::const_reverse_iterator it = order.rbegin();
         it != order.rend(); ++it) {
        const int v = *it;
        if (v <= m) continue;  // labelled: interval stays [label, label]
        endpoints.clear();
        for (int c : adj[v]) {
            if (c != parent[v]) {
                endpoints.push_back(lo[c]);
                endpoints.push_back(hi[c]);
            }
        }
        if (endpoints.empty()) continue;
        std::sort(endpoints.begin(), endpoints.end());
        const std::size_t k = endpoints.size() / 2;  // number of children
        lo[v] = endpoints[k - 1];
        hi[v] = endpoints[k];
    }

    // Pass 2, parents before children: clamp the parent's label into the
    // child's interval (lo <= hi always holds) and accumulate the edge cost.
    std::vector<int> label(n + 1, 0);
    long long total = 0;
    for (int v : order) {
        if (v == root) {
            label[v] = lo[v];
            continue;
        }
        const int up = label[parent[v]];
        label[v] = std::min(std::max(up, lo[v]), hi[v]);
        total += abs_diff(label[v], up);
    }

    std::printf("%lld\n", total);
    return 0;
}