#! /usr/bin/env python3
from typing import List

# Binary operators understood by the parser, keyed by their source character.
# Note: "/" is true division, so an expression containing it evaluates to a float.
operators = {
    "+": lambda x, y: x + y,
    "-": lambda x, y: x - y,
    "*": lambda x, y: x * y,
    "/": lambda x, y: x / y,
}


class SynNode(object):
    """Interior syntax-tree node whose operands are combined left to right.

    ``nodes`` holds the operand sub-trees and ``ops`` the operator characters
    placed between them, so ``len(nodes) == len(ops) + 1``.
    """

    def __init__(self, nodes, ops):
        self.nodes = nodes
        self.ops = ops

    def evaluate(self) -> int:
        """Evaluate every operand, then fold the values left to right with ``ops``."""
        assert len(self.nodes) == len(self.ops) + 1
        # Eval sub nodes
        node_values = [node.evaluate() for node in self.nodes]
        total_value = node_values[0]
        for op, value in zip(self.ops, node_values[1:]):
            total_value = operators[op](total_value, value)
        return total_value


class ValueNode(SynNode):
    """Leaf node holding a literal integer."""

    def __init__(self, val: int):
        self.value = val

    def evaluate(self) -> int:
        return self.value


def find_high_level_ops(s: str) -> List[int]:
    """Return the indices of the operators in ``s`` that are not inside parentheses."""
    depth = 0
    indices = []
    for i, ch in enumerate(s):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch in operators and depth == 0:
            indices.append(i)
    return indices


def parse(s: str) -> SynNode:
    """Parse an expression in which '+' binds tighter than every other operator.

    Operators of equal precedence associate left to right. Only non-negative
    integer literals, parentheses and the characters in ``operators`` are
    understood; surrounding whitespace is ignored.

    >>> parse("1 + (2 * 3) + (4 * (5 + 6))").evaluate()
    51
    >>> parse("5 + (8 * 3 + 9 + 3 * 4 * 3)").evaluate()
    1445

    Raises:
        ValueError: if ``s`` (or any operand inside it) is empty or malformed.
    """
    s = s.strip()
    if not s:
        # Without this guard an empty operand would recurse forever via s[1:-1].
        raise ValueError("empty expression or operand")
    if s.isdecimal():
        # isdecimal (not isnumeric): exactly the characters int() accepts.
        return ValueNode(int(s))

    hl_op_indices = find_high_level_ops(s)
    if not hl_op_indices:
        # No top-level operator, so this must be a single parenthesised group
        # such as "(1 + 2)": strip the outer parens and parse what is inside.
        if not (s.startswith('(') and s.endswith(')')):
            raise ValueError(f"malformed expression: {s!r}")
        return parse(s[1:-1])

    ops = [s[i] for i in hl_op_indices]
    # Each operand is the slice between two consecutive top-level operators
    # (or the string edges).
    bounds = [-1] + hl_op_indices + [len(s)]
    nodes: List[SynNode] = [
        parse(s[start + 1:end]) for start, end in zip(bounds, bounds[1:])
    ]

    # '+' has the highest precedence: collapse each "left + right" pair into a
    # sub-node, scanning from the left so additions stay left-associative.
    while '+' in ops:
        add_index = ops.index('+')
        op = ops.pop(add_index)
        right = nodes.pop(add_index + 1)
        left = nodes.pop(add_index)
        nodes.insert(add_index, SynNode([left, right], [op]))

    # Remaining operators share the lower precedence level, evaluated left to right.
    return SynNode(nodes, ops)


def part2(path: str = "input.txt") -> int:
    """Evaluate every expression in ``path`` (one per line); print and return the sum.

    Blank lines are skipped instead of being parsed as (invalid) expressions.
    """
    total = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            total += parse(line).evaluate()
    print(f"Total: {total}")
    return total


if __name__ == "__main__":
    part2()