import sys
from functools import lru_cache
def decode(line):
    """Expand one run-length encoded line into its bracket string.

    Tokens are whitespace separated:
    - an integer header, which is parsed (so a malformed header still raises) but not otherwise used;
    - the character used for the first run;
    - the run lengths.

    After each run the character flips: '(' becomes ')', and any other character becomes '('.
    """
    tokens = line.split()
    int(tokens[0])  # header token: validated as an integer, value unused
    symbol = tokens[1]
    next_symbol = {'(': ')'}
    chunks = []
    for length in map(int, tokens[2:]):
        chunks.append(symbol * length)
        symbol = next_symbol.get(symbol, '(')
    return "".join(chunks)
def can_interleave(s, r):
    """Return True if s and r can be interleaved into a regular bracket sequence.

    An interleaving keeps the relative order of characters within s and
    within r. '(' counts as +1 and every other character as -1. The merged
    string is regular when no prefix has a negative balance and the total
    balance is zero.

    The balance after taking s[:a] and r[:b] depends only on (a, b).
    State (a, b) is therefore reachable iff:
    - that balance is non-negative, and
    - (a-1, b) or (a, b-1) is reachable.

    The table is filled iteratively, one row at a time, so long inputs
    cannot exhaust Python's recursion limit. The previous memoised-recursion
    version recursed len(s) + len(r) levels deep.

    Runs in O(len(s) * len(r)) time with O(len(s) + len(r)) extra memory.
    """
    ls, lr = len(s), len(r)
    if (ls + lr) % 2 != 0:
        return False

    # Prefix balances: bal_s[a] is the balance of s[:a], bal_r[b] that of r[:b].
    bal_s = [0]
    for ch in s:
        bal_s.append(bal_s[-1] + (1 if ch == '(' else -1))
    bal_r = [0]
    for ch in r:
        bal_r.append(bal_r[-1] + (1 if ch == '(' else -1))

    # The final balance is fixed regardless of the interleaving chosen.
    if bal_s[ls] + bal_r[lr] != 0:
        return False

    # reach[b] tells whether state (a, b) is reachable for the current row a.
    # Row a == 0: only characters of r consumed so far.
    reach = [True] * (lr + 1)
    for b in range(1, lr + 1):
        reach[b] = reach[b - 1] and bal_r[b] >= 0

    for a in range(1, ls + 1):
        base = bal_s[a]
        reach[0] = reach[0] and base >= 0
        # In-place update: reach[b] still holds row a-1, reach[b-1] already holds row a.
        for b in range(1, lr + 1):
            reach[b] = (reach[b] or reach[b - 1]) and base + bal_r[b] >= 0
        if not any(reach):
            return False  # every state is dead; later rows stay dead too

    return reach[lr]
def _count_interleavable_substrings(s, t):
    """Count substrings t[i:j+1] that can be interleaved with s into a
    regular bracket sequence.

    The criterion is the same as can_interleave: '(' is +1 and any other
    character is -1. No prefix may go negative, and the total must be zero.

    For a fixed start i, let state (a, b) mean a chars of s and b chars of
    t[i:] have been consumed. Its reachability depends only on t[i:i+b], so
    the DP column for b is extended one character at a time. This avoids
    rebuilding a whole table for every end j, and it stops as soon as no
    state is reachable.
    """
    ls = len(s)
    # bal_s[a]: balance of s[:a].
    bal_s = [0] * (ls + 1)
    for a, ch in enumerate(s, 1):
        bal_s[a] = bal_s[a - 1] + (1 if ch == '(' else -1)
    # base[a]: state (a, 0) is reachable, i.e. no prefix of s[:a] dips below zero.
    base = [True] * (ls + 1)
    for a in range(1, ls + 1):
        base[a] = base[a - 1] and bal_s[a] >= 0
    need = -bal_s[ls]  # balance the t-substring must have for a zero total

    count = 0
    for i in range(len(t)):
        reach = base[:]  # column b == 0 is identical for every start i
        bal_r = 0
        for ch in t[i:]:
            bal_r += 1 if ch == '(' else -1
            # In-place update: reach[a] still holds column b-1, reach[a-1] already column b.
            reach[0] = reach[0] and bal_r >= 0
            for a in range(1, ls + 1):
                reach[a] = (reach[a] or reach[a - 1]) and bal_s[a] + bal_r >= 0
            if reach[ls] and bal_r == need:
                count += 1
            elif not any(reach):
                break  # no state survives, so longer substrings from i also fail
    return count


def solve():
    """Read two run-length encoded lines from stdin and print the answer.

    The first line decodes to s and the second to t. The printed value is the
    number of substrings of t that can be interleaved with s into a regular
    bracket sequence.
    """
    lines = sys.stdin.read().split('\n')
    s = decode(lines[0])
    t = decode(lines[1])
    print(_count_interleavable_substrings(s, t))
# Run only when executed as a script, so importing this module does not consume stdin.
if __name__ == "__main__":
    solve()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 | import sys from functools import lru_cache def decode(line): parts = line.split() n = int(parts[0]) c = parts[1] nums = list(map(int, parts[2:])) result = [] curr = c for a in nums: result.append(curr * a) curr = ')' if curr == '(' else '(' return ''.join(result) def can_interleave(s, r): ls, lr = len(s), len(r) if (ls + lr) % 2 != 0: return False @lru_cache(maxsize=None) def dp(a, b): if a == 0 and b == 0: return (0, 0) INF = float('inf') lo, hi = INF, -INF if a > 0: prev = dp(a-1, b) if prev is not None: d = 1 if s[a-1] == '(' else -1 new_lo, new_hi = prev[0] + d, prev[1] + d if new_hi >= 0: lo = min(lo, max(0, new_lo)) hi = max(hi, new_hi) if b > 0: prev = dp(a, b-1) if prev is not None: d = 1 if r[b-1] == '(' else -1 new_lo, new_hi = prev[0] + d, prev[1] + d if new_hi >= 0: lo = min(lo, max(0, new_lo)) hi = max(hi, new_hi) if hi == -INF: return None return (lo, hi) result = dp(ls, lr) dp.cache_clear() if result is None: return False return result[0] <= 0 <= result[1] def solve(): lines = sys.stdin.read().split('\n') s = decode(lines[0]) t = decode(lines[1]) count = 0 lt = len(t) for i in range(lt): for j in range(i, lt): if can_interleave(s, t[i:j+1]): count += 1 print(count) solve() |
English