1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <limits>
#include <utility>
#include <vector>

// Compile-time switch for diagnostic tracing on std::cerr: true only when the
// translation unit is built with the DEBUG macro defined, false otherwise.
static constexpr bool debug =
#ifdef DEBUG
    true;
#else
    false;
#endif

// Zero-based vertex index. Signed because the solver passes -1 as the
// "no parent" marker for the root.
using vertexid_t = ssize_t;
// Value stored at a vertex.
using gauge_t = int32_t;
// Accumulated cost; wider than gauge_t so that sums of gauge differences fit.
using cost_t = int64_t;

// Undirected graph read from a text stream.
//
// Expected input layout (whitespace separated):
//   n m            - number of vertices, number of vertices with a given gauge
//   a b  (n-1x)    - edges, endpoints are 1-based in the input
//   g    (m x)     - gauges of vertices 0 .. m-1, in that order
struct graph_t {
    // Vertex count and count of vertices whose gauge is given in the input.
    // Zero-initialised so a default-constructed graph is a valid empty graph.
    ssize_t n = 0, m = 0;
    struct vertex_t {
        std::vector<vertexid_t> adj;  // zero-based ids of adjacent vertices
        gauge_t gauge;                // -1 until a gauge is read for this vertex

        vertex_t() : gauge{-1} { }
    };
    std::vector<vertex_t> vertices;

    // Reads a graph in the layout described above.
    //
    // On success `input` is replaced by the parsed graph. If the stream ends
    // early, holds a non-numeric token, or describes an impossible graph
    // (negative counts, m > n, edge endpoint outside 1..n), failbit is set on
    // `str` and `input` is left exactly as it was.
    friend std::istream& operator>>(std::istream &str, graph_t &input) {
        // Parse into a scratch graph so that malformed input never leaves
        // `input` half-filled.
        graph_t parsed;

        if(!(str >> parsed.n >> parsed.m)
                || parsed.n < 0 || parsed.m < 0 || parsed.m > parsed.n) {
            str.setstate(std::ios_base::failbit);
            return str;
        }
        parsed.vertices.resize(parsed.n);

        // n - 1 edges follow the header.
        for(ssize_t i = 1; i < parsed.n; ++i) {
            vertexid_t a = 0, b = 0;
            if(!(str >> a >> b)
                    || a < 1 || a > parsed.n || b < 1 || b > parsed.n) {
                str.setstate(std::ios_base::failbit);
                return str;
            }
            --a, --b;  // the input is 1-based, storage is 0-based
            parsed.vertices[a].adj.push_back(b);
            parsed.vertices[b].adj.push_back(a);
        }

        // Gauges are given for the first m vertices only (m <= n was checked).
        for(ssize_t i = 0; i < parsed.m; ++i) {
            if(!(str >> parsed.vertices[i].gauge)) {
                return str;  // extraction already set failbit
            }
        }

        input = std::move(parsed);
        return str;
    }
};

class solver_t {
    public:
    solver_t(const solver_t&) = default;
    solver_t(solver_t&&) = default;

    template<class Graph>
    solver_t(Graph &&graph) : graph(std::forward<Graph>(graph)) { }

    cost_t operator()() const && {
        if(graph.n == graph.m) {
            assert(graph.n == 2);
            return std::abs(graph.vertices[0].gauge - graph.vertices[1].gauge);
        }

        return dfs(graph.n - 1, -1).cost;
    }

    private:
    using gauge_pair_t = std::pair<gauge_t, gauge_t>;

    struct result_t {
        cost_t cost;
        gauge_pair_t gauges;
    };

    result_t dfs(vertexid_t vertexid, vertexid_t parentid) const {
        if(vertexid < graph.m) {
            auto gauge = graph.vertices[vertexid].gauge;
            return {0, {gauge, gauge}};
        }

        cost_t cost = 0;
        std::vector<gauge_pair_t> results;

        for(const auto neighbourid: graph.vertices[vertexid].adj) {
            if(neighbourid != parentid) {
                const auto r = dfs(neighbourid, vertexid);
                cost += r.cost;
                results.push_back(std::move(r.gauges));
            }
        }

        auto result = compute(std::move(results));
        result.cost += cost;
        
        if(debug) {
            std::cerr << vertexid << ": " << result.cost << " -> " << result.gauges.first << ":" << result.gauges.second << std::endl;
        }

        return result;
    }

    static gauge_pair_t merge(const std::vector<gauge_pair_t> &gauge_pairs) {
        if(debug) {
            std::cerr << "Merging: ";
            for(const auto &gauge_pair: gauge_pairs)
                std::cerr << gauge_pair.first << ":" << gauge_pair.second << " ";
            std::cerr << std::endl;
        }

        std::vector<gauge_t> gauges;

        for(const auto &gauge_pair: gauge_pairs) {
            gauges.push_back(gauge_pair.first);
            gauges.push_back(gauge_pair.second);
        } 

        std::sort(gauges.begin(), gauges.end());

        auto idx = gauges.size() / 2;

        return {gauges[idx - 1], gauges[idx]};
    }

    static cost_t costfor(const gauge_t gauge, const std::vector<gauge_pair_t> &gauge_pairs) {
        cost_t cost = 0;
        for(const auto &gauge_pair: gauge_pairs) {
            if(gauge < gauge_pair.first)
                cost += gauge_pair.first - gauge;
            else if(gauge > gauge_pair.second)
                cost += gauge - gauge_pair.second;
        }
        return cost;
    }

    static result_t compute(const std::vector<gauge_pair_t> &gauge_pairs) {
        const auto gauge_pair = merge(gauge_pairs);
        const auto cost = costfor(gauge_pair.first, gauge_pairs);
        return {cost, gauge_pair};
    }

    graph_t graph;
};

// Reads a graph from standard input, runs the solver on it and prints the
// resulting cost followed by a newline. Exits with status 1, without running
// the solver, when the input could not be read.
int main() {
    // Plain iostream usage only, so the C stdio synchronisation and the
    // cin/cout tie are switched off.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    graph_t graph;
    if(!(std::cin >> graph)) {
        // Do not hand an incompletely read graph to the solver.
        std::cerr << "Invalid input" << std::endl;
        return 1;
    }

    // '\n' instead of std::endl: the stream is flushed at exit anyway.
    std::cout << solver_t{std::move(graph)}() << '\n';
}