#include <iostream>
#include<algorithm>
#include <vector>
using namespace std;
// A tree vertex. All vertices live in the global array `all` and are
// addressed by the same integer indices that appear in the input.
class Node
{
public:
    vector<int> incyd;  // indices of adjacent vertices (undirected adjacency list)
    int left;           // lower end of this vertex's candidate-value interval
    int right;          // upper end; left == right once a single value is chosen
    // left == right == 0 is the "not assigned / not computed yet" marker that
    // getRange() relies on. Initialise it explicitly so the marker also holds
    // for Node objects outside static storage (the global array is
    // zero-initialised anyway).
    Node() : left(0), right(0) {}
    void getRange(Node* parent);
    unsigned long long int setValues(Node* parent);
} all[500001];
// Computes, for every vertex in the subtree below `this` (the edge towards
// `parent` is treated as the way up and is not followed), the interval
// [left, right] of values minimising the total |difference| cost of that
// vertex's subtree when its own value is free.
//
// Vertices whose value is already fixed (left == right != 0, as main() sets
// for vertices 1..m) keep it. For any other vertex with k children whose
// optimal intervals are [l_c, r_c] (exact, by induction), the subtree cost as
// a function of the vertex value x is
//     sum_c dist(x, [l_c, r_c]) + const
//   = 1/2 * sum over all 2k endpoints e of |x - e| + const,
// so the minimisers are the k-th..(k+1)-th smallest of the 2k endpoints taken
// TOGETHER. Selecting from separately sorted left/right endpoints yields
// intervals containing non-optimal values.
//
// Iterative rather than recursive so that deep trees (the global array holds
// up to 500000 vertices) cannot overflow the call stack.
void Node::getRange(Node* parent)
{
    // Breadth-first listing of the subtree; order[i] was reached from from[i].
    // Each vertex appears after its parent, so scanning the list backwards
    // visits children before parents (a valid bottom-up order).
    vector<Node*> order;
    vector<Node*> from;
    order.push_back(this);
    from.push_back(parent);
    for (size_t i = 0; i < order.size(); ++i)
    {
        Node* cur = order[i];
        for (vector<int>::const_iterator it = cur->incyd.begin();
             it != cur->incyd.end(); ++it)
        {
            Node* next = &all[*it];
            if (next == from[i]) continue;
            order.push_back(next);
            from.push_back(cur);
        }
    }

    vector<int> ends; // reused buffer: both endpoints of every child interval
    for (size_t i = order.size(); i-- > 0; )
    {
        Node* cur = order[i];
        // Pre-assigned value: nothing to compute. NOTE(review): assumes fixed
        // values are never 0, because 0 also means "not computed yet".
        if (cur->left == cur->right && cur->left != 0) continue;

        ends.clear();
        for (vector<int>::const_iterator it = cur->incyd.begin();
             it != cur->incyd.end(); ++it)
        {
            const Node* child = &all[*it];
            if (child == from[i]) continue;
            ends.push_back(child->left);
            ends.push_back(child->right);
        }

        const size_t k = ends.size() / 2; // number of children
        if (k == 0)
        {
            // Unconstrained: every value is optimal. 1..500000 is presumably
            // the full legal value range — TODO confirm against the task.
            cur->left = 1;
            cur->right = 500000;
            continue;
        }
        sort(ends.begin(), ends.end());
        cur->left = ends[k - 1];
        cur->right = ends[k];
    }
}
// Chooses a concrete value for every vertex in the subtree below `this`
// (the edge towards `parent` is not followed) and returns the sum of
// |difference| over the edges visited, including the edge to `parent` when
// it is non-null.
//
// Expects getRange() to have filled [left, right] first. The root
// (parent == 0) takes its left endpoint; every other vertex takes its
// parent's chosen value clamped into its own [left, right]. On return each
// visited vertex has left == right == its chosen value.
//
// Iterative (breadth-first, so a parent is always fixed before its children
// read it) to avoid call-stack overflow on deep trees.
unsigned long long int Node::setValues(Node* parent)
{
    vector<Node*> order;
    vector<Node*> from;
    order.push_back(this);
    from.push_back(parent);

    unsigned long long int total = 0;
    for (size_t i = 0; i < order.size(); ++i)
    {
        Node* cur = order[i];
        const Node* up = from[i];
        if (up == 0)
        {
            cur->right = cur->left;
        }
        else
        {
            const int target = up->left; // parent's already-chosen value
            if (target < cur->left) cur->right = cur->left;
            else if (target > cur->right) cur->left = cur->right;
            else cur->left = cur->right = target;
            // Widen before subtracting so extreme values cannot overflow int.
            const long long diff = static_cast<long long>(target) - cur->left;
            total += static_cast<unsigned long long int>(diff < 0 ? -diff : diff);
        }
        for (vector<int>::const_iterator it = cur->incyd.begin();
             it != cur->incyd.end(); ++it)
        {
            Node* next = &all[*it];
            if (next == up) continue;
            order.push_back(next);
            from.push_back(cur);
        }
    }
    return total;
}
// Reads n, m, the n-1 edges and the fixed values of vertices 1..m, then
// prints the total |difference| over all edges after assigning the free
// vertices.
int main()
{
    // Roughly 10^6 integers are read: decouple iostreams from C stdio.
    ios_base::sync_with_stdio(false);
    cin.tie(0);

    int n, m;
    if (!(cin >> n >> m)) return 0;

    // Undirected edges -> adjacency lists.
    for (int i = 0; i < n - 1; i++)
    {
        int x, y;
        cin >> x >> y;
        all[x].incyd.push_back(y);
        all[y].incyd.push_back(x);
    }

    // Vertices 1..m have fixed values, stored as degenerate intervals.
    for (int i = 1; i <= m; i++)
    {
        int x;
        cin >> x;
        all[i].left = x;
        all[i].right = x;
    }

    // Root at the first free vertex m+1. If there is none (every vertex is
    // fixed, e.g. n == 2), all[m + 1] is not part of the tree — and may be
    // out of range — so root at vertex 1 instead.
    const int root = (m < n) ? m + 1 : 1;
    all[root].getRange(0);
    cout << all[root].setValues(0);
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | #include <iostream> #include<algorithm> #include <vector> using namespace std; class Node { public: vector<int> incyd; int left; int right; Node() { incyd.clear(); } void getRange(Node* parent); unsigned long long int setValues(Node* parent); } all[500001]; void Node::getRange(Node* parent) { if(left==right&&left!=0)return; vector<int> lewe,prawe; lewe.reserve(incyd.size()); prawe.reserve(incyd.size()); int size=incyd.size(); for(vector<int>::iterator it=incyd.begin(); it!=incyd.end(); it++) { if(&all[*it]==parent) { size--; continue; } all[*it].getRange(this); lewe.push_back(all[*it].left); prawe.push_back(all[*it].right); } sort(lewe.begin(),lewe.end()); sort(prawe.begin(),prawe.end()); if(size==0) { left=1; right=500000; } else if(size==1) { left=lewe[0]; right=prawe[0]; } else if(size%2==0) { left=lewe[size/2-1]; right=prawe[size/2]; } else { left=lewe[size/2-1]; right=prawe[size/2+1]; if(size==incyd.size()) { left=lewe[size/2]; right=prawe[size/2]; } } } unsigned long long int Node::setValues(Node* parent) { unsigned long long int aux=0; if(parent==0)right=left; else { if(parent->left<left)right=left; else if(parent->left>right)left=right; else left=right=parent->left; int x=parent->left-left; if(x<0)aux=-x; else aux=x; } for(vector<int>::iterator it=incyd.begin(); it!=incyd.end(); it++) { if(&all[(*it)]==parent)continue; aux+=all[*it].setValues(this); } return aux; } int main() { int n,m; cin>>n>>m; for(int i=0;i<n-1;i++) { int x,y; cin>>x>>y; all[x].incyd.push_back(y); all[y].incyd.push_back(x); } for(int i=0;i<m;i++) { int x; cin>>x; all[i+1].left=x; all[i+1].right=x; } 
all[m+1].getRange(0); cout<<all[m+1].setValues(0); return 0; } |
English