1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import sys
from functools import lru_cache

def decode(line):
    """Expand one run-length encoded line into a string.

    The line is split on whitespace into fields:
      * field 0 -- an integer; it is converted with ``int`` (so a malformed
        value raises ValueError) but is not otherwise used;
      * field 1 -- the character used for the first run;
      * fields 2.. -- integer run lengths.

    After each run the character toggles: '(' becomes ')', and anything
    else becomes '('.
    """
    fields = line.split()
    _declared = int(fields[0])  # parsed only for validation; value unused
    symbol = fields[1]
    run_lengths = [int(tok) for tok in fields[2:]]

    chunks = []
    for length in run_lengths:
        chunks.append(symbol * length)
        symbol = '(' if symbol != '(' else ')'
    return ''.join(chunks)

def can_interleave(s, r):
    """Return True if *s* and *r* can be merged into a balanced bracket string.

    A merge (interleaving) keeps the relative order of the characters taken
    from each input. '(' counts as +1 and every other character as -1; the
    merged string is balanced when no prefix has a negative balance and the
    total balance is zero.

    The balance after consuming s[:a] and r[:b] is the same for every merge
    order, so the problem reduces to grid reachability over (a, b). It is
    solved bottom-up with a single rolling row. This avoids recursion, so
    long inputs cannot hit Python's recursion limit.
    """
    ls, lr = len(s), len(r)
    if (ls + lr) % 2 != 0:
        return False

    def prefix_balances(text):
        # out[k] is the balance of text[:k].
        out = [0]
        for ch in text:
            out.append(out[-1] + (1 if ch == '(' else -1))
        return out

    ps = prefix_balances(s)
    pr = prefix_balances(r)
    # A non-zero final balance can never be balanced, whatever the order.
    if ps[ls] + pr[lr] != 0:
        return False

    # reach[b]: s[:a] and r[:b] can be merged without a negative prefix,
    # for the row a currently being processed. Row a = 0 first.
    reach = [False] * (lr + 1)
    reach[0] = True
    for b in range(1, lr + 1):
        reach[b] = reach[b - 1] and pr[b] >= 0

    for a in range(1, ls + 1):
        row_base = ps[a]
        reach[0] = reach[0] and row_base >= 0
        for b in range(1, lr + 1):
            # reach[b] still holds row a-1; reach[b-1] already holds row a.
            reach[b] = (reach[b] or reach[b - 1]) and row_base + pr[b] >= 0
        if not any(reach):
            # No cell in this row is reachable, so no later row can be either.
            return False
    return reach[lr]

def solve():
    lines = sys.stdin.read().split('\n')
    s = decode(lines[0])
    t = decode(lines[1])
    count = 0
    lt = len(t)
    for i in range(lt):
        for j in range(i, lt):
            if can_interleave(s, t[i:j+1]):
                count += 1
    print(count)

# Run only when executed as a script, so importing the module does not read stdin.
if __name__ == "__main__":
    solve()