1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

/// Kind of sweep-line event: an interval endpoint that starts (Opening) or
/// ends (Closing) an interval.
enum class PointType {
    Opening, Closing
};

// NOTE(review): adding new overloads to namespace std is formally undefined
// behaviour ([namespace.std]); it is kept here only so that existing
// `std::to_string(type)` spellings keep compiling. A free `to_string` next to
// PointType (found via ADL) would be the conforming alternative.
namespace std {
    /// Human-readable tag for a PointType ("OPEN" / "CLOSE"), for debug output.
    /// Written as a single expression so that every control path returns a
    /// value; a bare switch without a trailing return flows off the end of a
    /// non-void function for any value it does not list.
    std::string const to_string(PointType type) {
        return type == PointType::Opening ? "OPEN" : "CLOSE";
    }
}

struct Point {
    PointType type;
    unsigned coord;

    Point(PointType type, unsigned coord): type(type), coord(coord) {}

    bool operator < (Point const& rhs) const {
        return coord < rhs.coord;
    }

    static Point const make_open(unsigned coord) {
        return {PointType::Opening, coord};
    }

    static Point const make_close(unsigned coord) {
        return {PointType::Closing, coord};
    }
};

/// Closed range [left, right] of unsigned values.
struct Interval {
    unsigned left;   ///< lower end (inclusive)
    unsigned right;  ///< upper end (inclusive)

    Interval(unsigned left, unsigned right)
        : left{left}
        , right{right}
    {}

    /// Single-value range [number, number]. Intentionally not `explicit`, so a
    /// plain `unsigned` converts implicitly to an Interval.
    Interval(unsigned number)
        : left{number}
        , right{number}
    {}

    /// Writes the range as "(left, right)".
    friend std::ostream & operator << (std::ostream & out, Interval const& in) {
        out << "(" << in.left << ", " << in.right << ")";
        return out;
    }
};

/// Picks the best value for a vertex given the ranges of its children.
///
/// Placing the vertex at x costs f(x) = sum over `intervals` of the distance
/// from x to [left, right] (0 when x lies inside). f is convex and piecewise
/// linear with breakpoints only at interval endpoints, so one sweep over the
/// sorted endpoints finds its minimum and the range where it is attained.
///
/// @param intervals ranges of the children; may be empty.
/// @param result    accumulator; the minimal value of f is added to it.
/// @return the range of endpoint coordinates at which f is minimal. For an
///         empty `intervals` nothing constrains the vertex: the full unsigned
///         range is returned and `result` is left unchanged.
Interval const vertex_result(std::vector<Interval> const& intervals, unsigned long long & result) {
    if (intervals.empty()) {
        // Every value costs 0; also avoids calling front() on an empty vector.
        return Interval(0, std::numeric_limits<unsigned>::max());
    }

    std::vector<Point> points;
    points.reserve(intervals.size() * 2);
    for (auto const& interval : intervals) {
        points.emplace_back(PointType::Opening, interval.left);
        points.emplace_back(PointType::Closing, interval.right);
    }
    std::sort(std::begin(points), std::end(points));

    // Cost at the smallest coordinate: every interval starts at or after it,
    // so each contributes (left - lowest). Points are sorted, so no negatives.
    unsigned const lowest = points.front().coord;
    unsigned long long cost = 0;
    for (auto const& interval : intervals) cost += interval.left - lowest;

    // Slope of f right of the current coordinate is closed - not_opened.
    long long closed = 0;
    long long not_opened = static_cast<long long>(intervals.size());

    unsigned long long best_cost = std::numeric_limits<unsigned long long>::max();
    Interval best(0);
    unsigned previous = lowest;

    for (auto const& p : points) {
        // The slope may be negative. The product is taken modulo 2^64; since
        // f itself is never negative, the accumulated unsigned cost is exact.
        unsigned long long const step = p.coord - previous;
        cost += static_cast<unsigned long long>(closed - not_opened) * step;
        //std::cerr << "Sweep line: " << p.coord << " " << std::to_string(p.type) << " " << cost << std::endl;

        if (cost < best_cost) {
            best_cost = cost;
            best = Interval(p.coord);
        } else if (cost == best_cost) {
            // f is convex, so equal minima form one contiguous range.
            best.right = p.coord;
        }

        if (p.type == PointType::Opening) {
            --not_opened;
        } else {
            ++closed;
        }
        previous = p.coord;
    }

    result += best_cost;
    return best;
}

/// Solves the subtree of `v` (rooted away from `par`).
///
/// Returns the range of optimal values for `v` and adds the minimal edge cost
/// inside that subtree to `result`.
///
/// A vertex whose hint is non-zero is fixed to that value and its subtree is
/// not explored. NOTE(review): presumably hinted vertices are leaves; confirm
/// against the problem statement. Other vertices combine the ranges of all
/// non-parent neighbours with vertex_result, in post-order.
///
/// The traversal uses an explicit stack instead of recursion, so deep
/// (path-like) trees cannot overflow the call stack.
Interval const solve(unsigned v, unsigned par, std::vector<std::vector<unsigned>> const& G, std::vector<unsigned> const& hints, unsigned long long & result) {
    if (hints[v] != 0) return hints[v];

    // One vertex whose neighbours are still being processed.
    struct Frame {
        unsigned vertex;
        unsigned parent;
        std::size_t next_neighbour;      // next index into G[vertex] to visit
        std::vector<Interval> children;  // ranges of already-solved children
    };

    std::vector<Frame> stack;
    stack.push_back(Frame{v, par, 0, {}});
    stack.back().children.reserve(G[v].size());

    for (;;) {
        Frame& top = stack.back();
        std::vector<unsigned> const& neighbours = G[top.vertex];

        bool descended = false;
        while (top.next_neighbour < neighbours.size()) {
            unsigned const u = neighbours[top.next_neighbour++];
            if (u == top.parent) continue;
            if (hints[u] != 0) {
                // Fixed vertex: its range is just the hinted value.
                top.children.emplace_back(hints[u]);
                continue;
            }
            unsigned const parent = top.vertex;  // `top` may dangle after push_back
            stack.push_back(Frame{u, parent, 0, {}});
            stack.back().children.reserve(G[u].size());
            descended = true;
            break;
        }
        if (descended) continue;

        // Every neighbour of `top` is done: combine their ranges.
        Interval const solved = vertex_result(std::move(top.children), result);
        //std::cerr << "Solve for " << top.vertex << ": " << solved << std::endl;
        stack.pop_back();
        if (stack.empty()) return solved;
        stack.back().children.push_back(solved);
    }
}

/// Reads the tree and the hints from stdin and prints the minimal total cost.
///
/// Input, as read below:
/// - the vertex count and m;
/// - (vertex count - 1) edges "a b";
/// - the hints of vertices 1..m.
///
/// The highest-numbered vertex is used as the root.
/// NOTE(review): assumes m is smaller than the vertex count + 1 (otherwise
/// hints[i] is out of bounds) and that hints are non-zero; confirm with the
/// problem statement.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    unsigned n, m;
    std::cin >> n >> m;
    n ++;  // 1-based vertices; index 0 is unused (it serves as the root's "parent")

    std::vector<std::vector<unsigned>> G(n);

    for (unsigned i = 2; i < n; i++) {  // (vertex count - 1) edges
        unsigned a, b;
        std::cin >> a >> b;

        G[a].push_back(b);
        G[b].push_back(a);
    }

    std::vector<unsigned> hints(n);

    for (unsigned i = 1; i <= m; i++) {
        std::cin >> hints[i];
    }

    unsigned long long result = 0;
    if (n > 1) {  // no vertices at all: nothing to solve
        solve(n - 1, 0, G, hints, result);
    }

    // solve() stops at hinted vertices, so edges joining two hinted vertices
    // are never costed there; add them here, each edge once (i < v).
    for (unsigned i = 1; i < n; i++) {
        for (auto v : G[i]) {
            if (i < v && hints[i] != 0 && hints[v] != 0) {
                // Subtract in the right order on unsigned values: no int cast,
                // so hints above INT_MAX cannot overflow.
                result += hints[i] > hints[v] ? hints[i] - hints[v] : hints[v] - hints[i];
            }
        }
    }

    std::cout << result << '\n';
}