#! /usr/bin/env python3
"""Work out which food items carry which alergens.

Every non-empty input line has the form::

    item item ... (contains alergen, alergen, ...)

The script prints the number of times an item that cannot carry any alergen
appears across all foods, followed by the list of alergen-carrying items
ordered by the name of the alergen each one carries.
"""
import re
from functools import reduce
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Set

# One food per line: space-separated items, then the parenthesised alergens.
food_parser = re.compile(r"([a-z ]+) \(contains ([a-z, ]+)\)")


@dataclass
class Food:
    """One parsed food line."""

    # Item names listed before the parenthesised part of the line.
    items: Set[str]
    # Alergen names listed after "contains".
    alergens: Set[str]


def read_input(filename: str) -> Iterable[Food]:
    """Lazily parse *filename* into ``Food`` records.

    Blank lines are skipped.

    Args:
        filename: Path of the text file to read.

    Yields:
        One ``Food`` per non-empty line, in file order.

    Raises:
        ValueError: If a non-empty line does not match ``food_parser``.
    """
    with open(filename, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            match = food_parser.match(line)
            if match is None:
                raise ValueError(f"cannot parse line {line}")
            # split() with no argument tolerates repeated spaces between items.
            items = set(match.group(1).split())
            # Split on the comma alone so a missing space after it is harmless.
            alergens = {
                name.strip()
                for name in match.group(2).split(",")
                if name.strip()
            }
            yield Food(items, alergens)


def extract_alergen_info(foods: Iterable[Food]) -> Dict[str, Set[str]]:
    """Compute the candidate items for every alergen.

    An item stays a candidate for an alergen only if it is present in every
    food that lists that alergen.

    Args:
        foods: The foods to analyse.

    Returns:
        A mapping from alergen name to its set of candidate items. The sets
        are independent copies; the ``Food`` objects are not modified.
    """
    alergen_defs: Dict[str, Set[str]] = {}
    for food in foods:
        for alergen in food.alergens:
            if alergen in alergen_defs:
                alergen_defs[alergen] &= food.items
            else:
                # Copy so the in-place intersections never touch the food.
                alergen_defs[alergen] = food.items.copy()
    return alergen_defs


def reduce_union(items: Iterable[Set[Any]]) -> Set[Any]:
    """Return a new set holding the union of all sets in *items*.

    The result never aliases an input set, and an empty iterable gives an
    empty set.
    """
    return set().union(*items)


def reduce_intersecion(items: Iterable[Set[Any]]) -> Set[Any]:
    """Return a new set holding the intersection of all sets in *items*.

    The result never aliases an input set.

    Raises:
        TypeError: If *items* is empty; the intersection of no sets is
            undefined.
    """
    iterator = iter(items)
    try:
        # Seed with a copy so a single-element input is not returned as-is.
        seed = set(next(iterator))
    except StopIteration:
        raise TypeError(
            "reduce_intersecion() of empty iterable with no initial value"
        ) from None
    return reduce(lambda acc, other: acc & other, iterator, seed)


def narrow_results(alergen_defs: Dict[str, Set[str]]) -> None:
    """Eliminate candidates in place until no further progress is possible.

    Whenever an alergen has exactly one candidate item, that item is removed
    from the candidates of every other alergen. This is repeated until no
    newly resolved alergen is left, so a single call is enough.

    Args:
        alergen_defs: Mapping from alergen to candidate items; the sets are
            mutated in place.
    """
    propagated: Set[str] = set()
    while True:
        newly_resolved = [
            alergen
            for alergen, candidates in alergen_defs.items()
            if len(candidates) == 1 and alergen not in propagated
        ]
        if not newly_resolved:
            return
        for alergen in newly_resolved:
            # Each alergen is propagated once, which bounds the outer loop.
            propagated.add(alergen)
            resolved_item = alergen_defs[alergen]
            for other, candidates in alergen_defs.items():
                if other != alergen:
                    candidates.difference_update(resolved_item)


def _canonical_list(alergen_defs: Dict[str, Set[str]]) -> str:
    """Join each alergen's single item, ordered by alergen name, with commas.

    Raises:
        ValueError: If some alergen does not have exactly one candidate.
    """
    ordered_items = []
    for alergen in sorted(alergen_defs):
        candidates = alergen_defs[alergen]
        if len(candidates) != 1:
            raise ValueError(
                f"alergen {alergen!r} is not resolved to a single item: "
                f"{sorted(candidates)}"
            )
        # Unpack instead of pop() so the mapping is left untouched.
        (item,) = candidates
        ordered_items.append(item)
    return ",".join(ordered_items)


def part1(filename: str = "input.txt") -> None:
    """Solve both parts for the foods in *filename* and print the results.

    Despite its name, this function also prints the "Part 2" output.

    Args:
        filename: Path of the input file; defaults to ``input.txt``.

    Raises:
        ValueError: If a line cannot be parsed, or if an alergen cannot be
            narrowed down to exactly one item.
    """
    all_foods = list(read_input(filename))
    alergen_defs = extract_alergen_info(all_foods)
    print("Alergen defs", alergen_defs)

    all_possible_alergens = reduce_union(alergen_defs.values())
    print("All possible alergens", all_possible_alergens)

    # reduce_union returns a fresh set, so this never edits a food's items.
    safe_items = reduce_union(food.items for food in all_foods)
    safe_items -= all_possible_alergens
    print("All safe items", safe_items)

    count_safe = sum(len(food.items & safe_items) for food in all_foods)
    print("Count of safe items appeared is", count_safe)

    print("Part 2")
    narrow_results(alergen_defs)
    print("Narrowed alergen list", alergen_defs)
    canon_list = _canonical_list(alergen_defs)
    print("Canon list of ingredients", canon_list)


if __name__ == "__main__":
    part1()