97 lines
2.2 KiB
Python
Executable File
97 lines
2.2 KiB
Python
Executable File
#! /usr/bin/env python3
|
|
from typing import List
|
|
|
|
|
|
# Binary operators understood by the parser, keyed by their source symbol.
# Each entry maps (left, right) -> result. Note that "/" is true division,
# so it can produce a float even when both operands are ints.
operators = {
    "+": lambda left, right: left + right,
    "-": lambda left, right: left - right,
    "*": lambda left, right: left * right,
    "/": lambda left, right: left / right,
}
|
|
|
|
|
|
class SynNode:
    """Interior node of an expression tree.

    Holds ``n`` child nodes and the ``n - 1`` operator symbols that sit
    between them. ``evaluate`` applies the operators strictly left to right;
    any operator precedence must already be reflected in how the tree was
    built.
    """

    def __init__(self, nodes, ops):
        # Children: objects exposing an ``evaluate()`` method.
        self.nodes = nodes
        # Operator symbols; each must be a key of the module-level
        # ``operators`` table.
        self.ops = ops

    def evaluate(self) -> int:
        """Evaluate every child first, then fold the results left to right.

        The arity check is an ``assert``, so it only fires when assertions
        are enabled. The result can be a float if a "/" operator is involved.
        """
        assert len(self.nodes) == len(self.ops) + 1
        # All children are evaluated before any operator is applied.
        first, *others = [child.evaluate() for child in self.nodes]
        acc = first
        for symbol, operand in zip(self.ops, others):
            acc = operators[symbol](acc, operand)
        return acc
|
|
|
|
|
|
class ValueNode(SynNode):
    """Leaf node holding a literal number.

    It is initialised as a ``SynNode`` with no children and no operators, so
    code that walks ``.nodes`` / ``.ops`` over a tree sees an empty leaf
    instead of hitting an ``AttributeError``.
    """

    def __init__(self, val: int):
        # Initialise the base so the leaf carries the same attributes as any
        # other node (empty child and operator lists).
        super().__init__([], [])
        self.value = val

    def evaluate(self) -> int:
        """Return the stored literal unchanged."""
        return self.value
|
|
|
|
|
|
def find_high_level_ops(s: str) -> List[int]:
    """Return the indices of operator characters in *s* that are outside parentheses.

    An operator character is any key of the module-level ``operators`` table.
    Characters nested inside any pair of parentheses are ignored. Indices are
    returned in ascending order.
    """
    depth = 0
    positions = []
    for index, char in enumerate(s):
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
        elif depth == 0 and char in operators:
            positions.append(index)
    return positions
|
|
|
|
|
|
def parse(s: str) -> SynNode:
    """Parse an arithmetic expression into a tree of ``SynNode`` objects.

    Precedence rules implemented here: ``+`` binds tighter than every other
    operator. All remaining operators (``-``, ``*``, ``/``) are applied
    strictly left to right. Parentheses group as usual. Only non-negative
    integer literals are accepted; a leading ``-`` is treated as a binary
    operator with a missing left operand and is rejected.

    >>> parse("1 + (2 * 3) + (4 * (5 + 6))").evaluate()
    51
    >>> parse("5 + (8 * 3 + 9 + 3 * 4 * 3)").evaluate()
    1445

    :param s: expression text; surrounding whitespace is ignored.
    :returns: the root node; call ``evaluate()`` on it to get the value.
    :raises ValueError: if *s* or any sub-expression is empty, or is not a
        number, a top-level operator expression, or a parenthesised
        expression. Examples are a blank line, a dangling operator, or
        unbalanced parentheses.
    """
    s = s.strip()
    if not s:
        # Without this guard an empty operand recursed forever via s[1:-1].
        raise ValueError("empty expression")
    # isdecimal (not isnumeric) guarantees int() below will succeed.
    if s.isdecimal():
        return ValueNode(int(s))

    top_level_ops = find_high_level_ops(s)
    if not top_level_ops:
        # With no operator outside parentheses, the only valid shape left is
        # a fully parenthesised expression like "(1 + 2)": drop the outer
        # pair and parse the inside. Anything else is malformed.
        if not (s.startswith("(") and s.endswith(")")):
            raise ValueError(f"malformed expression: {s!r}")
        return parse(s[1:-1])

    ops = [s[i] for i in top_level_ops]
    nodes: List[SynNode] = []

    # Operands are the text spans between consecutive top-level operators.
    start = 0
    for i in top_level_ops:
        nodes.append(parse(s[start:i]))
        start = i + 1
    # The final operand runs from the last operator to the end of the string.
    nodes.append(parse(s[start:]))

    # "+" binds tighter than everything else: fold each addition (leftmost
    # first) into its own two-operand sub-node, so the outer node only has
    # the remaining operators to apply left to right.
    while '+' in ops:
        add_index = ops.index('+')
        op = ops.pop(add_index)
        right = nodes.pop(add_index + 1)
        left = nodes.pop(add_index)
        nodes.insert(add_index, SynNode([left, right], [op]))

    return SynNode(nodes, ops)
|
|
|
|
|
|
def part2(path: str = "input.txt"):
    """Evaluate every expression in *path* (one per line) and print their sum.

    Blank or whitespace-only lines, such as a trailing empty line, are
    skipped rather than handed to ``parse``. The total is printed as
    ``Total: <n>`` and also returned.

    Sample expressions and their values under this file's rules:
    "1 + (2 * 3) + (4 * (5 + 6))" -> 51,
    "5 + (8 * 3 + 9 + 3 * 4 * 3)" -> 1445.

    :param path: file to read; defaults to ``input.txt`` in the working
        directory.
    :returns: the sum of all evaluated expressions.
    """
    with open(path, encoding="utf-8") as f:
        total = sum(parse(line).evaluate() for line in f if line.strip())

    print(f"Total: {total}")
    return total
|
|
|
|
if __name__ == "__main__":
    # Run the part-2 solver only when executed as a script, not on import.
    part2()
|