#include <iostream> #include <cstdint> #include <list> #include <vector> #include <unordered_set> #include <stack> #include <map> #include <algorithm> #include <string> using namespace std; struct Stacja { list<Stacja*> incoming; unordered_set<Stacja*> outgoing; uint32_t finalWidth; uint32_t id; bool uglyFlag; Stacja(uint32_t i): id(i), finalWidth(0), uglyFlag(false) {} }; uint32_t n, m; vector<Stacja*> stacje; uint64_t cost; void readInput() { cin >> n >> m; stacje.reserve(n); for (uint32_t i=0; i<n; i++) stacje.push_back(new Stacja(i+1)); for (uint32_t i=0; i<n-1; i++) { uint32_t u,v; cin >> u >> v; u--;v--; stacje[u]->outgoing.insert(stacje[v]); stacje[v]->outgoing.insert(stacje[u]); } for (uint32_t i=0; i<m; i++) { uint32_t maxRail; cin >> maxRail; stacje[i]->finalWidth = maxRail; } } void sortLeavesByRailDesc() { sort(stacje.begin(), stacje.begin()+m, [](const Stacja* a, const Stacja* b) { return a->finalWidth > b->finalWidth; }); } template< typename T > T absdiff( const T& lhs, const T& rhs ) { return lhs>rhs ? lhs-rhs : rhs-lhs; } //#define ERRPRINT int32_t canPullDown(Stacja* stacja, uint32_t width, uint32_t depth) { int32_t result = 0; if (stacja->finalWidth > 0) { if (stacja->finalWidth <= width) result++; else if (stacja->finalWidth > width) result--; } else { stacja->uglyFlag = true; for (auto it = stacja->incoming.begin(); it != stacja->incoming.end(); ++it) if (!(*it)->uglyFlag) { #ifdef ERRPRINT cerr << string(depth, ' ') << "incoming: " << endl; #endif result += canPullDown(*it, width, depth+1); } uint32_t remaining = stacja->outgoing.size(); for (auto it = stacja->outgoing.begin(); (it != stacja->outgoing.end()) && (result <= remaining); ++it, remaining--) if (!(*it)->uglyFlag) { #ifdef ERRPRINT cerr << string(depth, ' ') << "outgoing: " << endl; #endif result += canPullDown(*it, width, depth+1); } #ifdef ERRPRINT cerr << string(depth, ' ') << "skipped " << remaining << endl; #endif stacja->uglyFlag = false; } #ifdef ERRPRINT cerr << string(depth, ' ') << "canPullDown(" << stacja->id << ", " << width << ") = " << result << endl; #endif return result == 0 ? 0 : result/abs(result); } void propagateRails() { stack<Stacja*> stos; for (auto it = stacje.begin(); it != stacje.begin()+m; ++it) stos.push(*it); while (!stos.empty()) { Stacja* curr = stos.top(); stos.pop(); #ifdef ERRPRINT cerr << "Strarting for stacja " << curr->id << ", finalWidth = " << curr->finalWidth << endl; #endif // policz bazarki dla incoming for (auto it = curr->incoming.begin(); it != curr->incoming.end(); ++it) { uint32_t bazarek = absdiff(curr->finalWidth, (*it)->finalWidth); // cerr << "Bazarek do stacji " << (*it)->id << " = " << bazarek << endl; cost += bazarek; } // dla wszystkich outgoing for (auto it = curr->outgoing.begin(); it != curr->outgoing.end(); ++it) { Stacja* dest = *it; // przenies sie do incoming dest->outgoing.erase(curr); dest->incoming.push_back(curr); // sprawdz czy nie ma finalWidth if (dest->finalWidth == 0) { // sprawdz czy mozesz policzyc finalWidth if (canPullDown(dest, curr->finalWidth, 0) >= 0) { dest->finalWidth = curr->finalWidth; stos.push(dest); } } } } } int main() { readInput(); sortLeavesByRailDesc(); propagateRails(); cout << cost; return 0; }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | #include <iostream> #include <cstdint> #include <list> #include <vector> #include <unordered_set> #include <stack> #include <map> #include <algorithm> #include <string> using namespace std; struct Stacja { list<Stacja*> incoming; unordered_set<Stacja*> outgoing; uint32_t finalWidth; uint32_t id; bool uglyFlag; Stacja(uint32_t i): id(i), finalWidth(0), uglyFlag(false) {} }; uint32_t n, m; vector<Stacja*> stacje; uint64_t cost; void readInput() { cin >> n >> m; stacje.reserve(n); for (uint32_t i=0; i<n; i++) stacje.push_back(new Stacja(i+1)); for (uint32_t i=0; i<n-1; i++) { uint32_t u,v; cin >> u >> v; u--;v--; stacje[u]->outgoing.insert(stacje[v]); stacje[v]->outgoing.insert(stacje[u]); } for (uint32_t i=0; i<m; i++) { uint32_t maxRail; cin >> maxRail; stacje[i]->finalWidth = maxRail; } } void sortLeavesByRailDesc() { sort(stacje.begin(), stacje.begin()+m, [](const Stacja* a, const Stacja* b) { return a->finalWidth > b->finalWidth; }); } template< typename T > T absdiff( const T& lhs, const T& rhs ) { return lhs>rhs ? lhs-rhs : rhs-lhs; } //#define ERRPRINT int32_t canPullDown(Stacja* stacja, uint32_t width, uint32_t depth) { int32_t result = 0; if (stacja->finalWidth > 0) { if (stacja->finalWidth <= width) result++; else if (stacja->finalWidth > width) result--; } else { stacja->uglyFlag = true; for (auto it = stacja->incoming.begin(); it != stacja->incoming.end(); ++it) if (!(*it)->uglyFlag) { #ifdef ERRPRINT cerr << string(depth, ' ') << "incoming: " << endl; #endif result += canPullDown(*it, width, depth+1); } uint32_t remaining = stacja->outgoing.size(); for (auto it = stacja->outgoing.begin(); (it != stacja->outgoing.end()) && (result <= remaining); ++it, remaining--) if (!(*it)->uglyFlag) { #ifdef ERRPRINT cerr << string(depth, ' ') << "outgoing: " << endl; #endif result += canPullDown(*it, width, depth+1); } #ifdef ERRPRINT cerr << string(depth, ' ') << "skipped " << remaining << endl; #endif stacja->uglyFlag = false; } #ifdef ERRPRINT cerr << string(depth, ' ') << "canPullDown(" << stacja->id << ", " << width << ") = " << result << endl; #endif return result == 0 ? 0 : result/abs(result); } void propagateRails() { stack<Stacja*> stos; for (auto it = stacje.begin(); it != stacje.begin()+m; ++it) stos.push(*it); while (!stos.empty()) { Stacja* curr = stos.top(); stos.pop(); #ifdef ERRPRINT cerr << "Strarting for stacja " << curr->id << ", finalWidth = " << curr->finalWidth << endl; #endif // policz bazarki dla incoming for (auto it = curr->incoming.begin(); it != curr->incoming.end(); ++it) { uint32_t bazarek = absdiff(curr->finalWidth, (*it)->finalWidth); // cerr << "Bazarek do stacji " << (*it)->id << " = " << bazarek << endl; cost += bazarek; } // dla wszystkich outgoing for (auto it = curr->outgoing.begin(); it != curr->outgoing.end(); ++it) { Stacja* dest = *it; // przenies sie do incoming dest->outgoing.erase(curr); dest->incoming.push_back(curr); // sprawdz czy nie ma finalWidth if (dest->finalWidth == 0) { // sprawdz czy mozesz policzyc finalWidth if (canPullDown(dest, curr->finalWidth, 0) >= 0) { dest->finalWidth = curr->finalWidth; stos.push(dest); } } } } } int main() { readInput(); sortLeavesByRailDesc(); propagateRails(); cout << cost; return 0; } |