aoc-2020/d21/main.py

106 lines
2.8 KiB
Python
Executable File

#! /usr/bin/env python3
import re
from functools import reduce
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Set
# Pattern for one food line of the form
#   "<ingredients> (contains <allergens>)"
# Group 1 captures the ingredient text (lowercase letters and spaces);
# group 2 captures the allergen text (lowercase letters, commas and spaces).
food_parser = re.compile(r"([a-z ]+) \(contains ([a-z, ]+)\)")
@dataclass
class Food:
    """A single food entry: its ingredient names and its listed allergens."""

    # Names of the ingredients that make up this food.
    items: Set[str]
    # Names of the allergens this food is declared to contain.
    alergens: Set[str]
def read_input(filename: str) -> Iterable[Food]:
    """Lazily parse the food list stored in *filename*.

    Every non-blank line must match ``food_parser``, i.e. look like
    ``ingredient ingredient ... (contains allergen, allergen)``.
    Blank lines are skipped.

    Args:
        filename: Path of the text file to read.

    Yields:
        One ``Food`` per non-blank line, in file order.

    Raises:
        ValueError: If a non-blank line does not match ``food_parser``.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(filename, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            match = food_parser.match(line)
            if match is None:
                raise ValueError(f"cannot parse line {line}")
            # split() without an argument collapses runs of spaces, so a
            # doubled space never produces an empty-string ingredient.
            items = set(match.group(1).split())
            # Split on the comma alone and trim each name, so "a,b" and
            # "a,  b" are handled the same way as "a, b".
            alergens = {name.strip() for name in match.group(2).split(",")}
            alergens.discard("")
            yield Food(items, alergens)
def extract_alergen_info(foods: Iterable[Food]) -> Dict[str, Set[str]]:
    """Work out, per allergen, which ingredients could carry it.

    For every allergen the result holds the intersection of the ingredient
    sets of all foods that list that allergen. The input foods are not
    modified.

    Args:
        foods: Foods to analyse; each must expose ``items`` and ``alergens``.

    Returns:
        Mapping from allergen name to its set of candidate ingredients.
    """
    candidates_by_alergen: Dict[str, Set[str]] = {}
    for entry in foods:
        for name in entry.alergens:
            known = candidates_by_alergen.get(name)
            if known is None:
                # First sighting: start from a private copy of the ingredients.
                candidates_by_alergen[name] = set(entry.items)
            else:
                # Later sightings can only shrink the candidate set.
                candidates_by_alergen[name] = known & entry.items
    return candidates_by_alergen
def reduce_union(items: Iterable[Set[Any]]) -> Set[Any]:
    """Return the union of all sets in *items* as a new set.

    The result is always a fresh set, so callers may mutate it without
    touching any of the inputs. An empty *items* yields an empty set.

    Args:
        items: The sets to combine.

    Returns:
        A new set holding every element found in any of the inputs.
    """
    # set().union(*...) covers the empty and single-element cases that a bare
    # reduce() gets wrong (TypeError / returning the caller's own set).
    return set().union(*items)
def reduce_intersecion(items: Iterable[Set[Any]]) -> Set[Any]:
    """Return the intersection of all sets in *items* as a new set.

    The result is always a fresh set, so callers may mutate it without
    touching any of the inputs.

    Args:
        items: The sets to intersect; must contain at least one set.

    Returns:
        A new set holding the elements common to every input.

    Raises:
        TypeError: If *items* is empty (an empty intersection is undefined).
    """
    remaining = iter(items)
    try:
        first = next(remaining)
    except StopIteration:
        # Same exception type a bare reduce() raises on empty input.
        raise TypeError(
            "reduce_intersecion() of empty iterable with no initial value"
        ) from None
    # Copy the first set so a single-element input is never returned as-is.
    return set(first).intersection(*remaining)
def narrow_results(alergen_defs: Dict[str, Set[str]]):
    """Eliminate identified ingredients from the other allergens' candidates.

    Whenever an allergen has exactly one candidate ingredient, that
    ingredient is removed from the candidate set of every other allergen.
    This is repeated until a full sweep changes nothing, so one call is
    enough to reach the final state and further calls are no-ops.

    Args:
        alergen_defs: Mapping from allergen name to its candidate
            ingredients. The sets are modified in place.
    """
    changed = True
    while changed:
        changed = False
        for alergen, candidates in alergen_defs.items():
            if len(candidates) != 1:
                continue
            for other, other_candidates in alergen_defs.items():
                if other == alergen or candidates.isdisjoint(other_candidates):
                    continue
                # A removal may turn `other` into a singleton, which has to
                # be propagated too; hence the sweep until nothing changes.
                other_candidates.difference_update(candidates)
                changed = True
def part1(filename: str = "input.txt"):
    """Solve both puzzle parts for the food list in *filename* and print them.

    Part 1 counts how often ingredients that cannot carry any allergen appear
    across all foods. Part 2 narrows every allergen down to its ingredient
    and prints those ingredients joined by commas, ordered by allergen name.

    Args:
        filename: Path of the puzzle input; defaults to ``input.txt``.
    """
    all_foods = list(read_input(filename))
    alergen_defs = extract_alergen_info(all_foods)
    print("Alergen defs", alergen_defs)

    all_possible_alergens = reduce_union(alergen_defs.values())
    print("All possible alergens", all_possible_alergens)

    # Build a new set with "-" rather than "-=", so the sets owned by the
    # foods are never modified through the union result.
    safe_items = (
        reduce_union(food.items for food in all_foods) - all_possible_alergens
    )
    print("All safe items", safe_items)

    count_safe = sum(len(food.items & safe_items) for food in all_foods)
    print("Count of safe items appeared is", count_safe)

    print("Part 2")
    # Keep narrowing until a call removes nothing more, instead of relying on
    # a fixed number of passes being enough for this particular input.
    while True:
        remaining = sum(len(names) for names in alergen_defs.values())
        narrow_results(alergen_defs)
        if sum(len(names) for names in alergen_defs.values()) == remaining:
            break
    print("Narrowed alergen list", alergen_defs)

    # NOTE: pop() assumes every allergen was narrowed to exactly one
    # ingredient; it raises KeyError if a candidate set ended up empty.
    canon_list = ",".join(
        alergen_defs[alergen].pop() for alergen in sorted(alergen_defs)
    )
    print("Canon list of ingredients", canon_list)
# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    part1()