import asyncio
import json
import logging
from asyncio import Future
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Never, NoReturn
from typing import cast

import mqttapi as mqtt
from pubnub.enums import PNReconnectionPolicy
from pysnoo import ActivityState
from pysnoo import Device
from pysnoo import SessionLevel
from pysnoo import Snoo
from pysnoo import SnooAuthSession
from pysnoo import SnooPubNub as SnooPubNubBase
from pysnoo.models import EventType

# Set global log level to debug
logging.getLogger().setLevel(logging.DEBUG)

# HACK: Avoid error on missing EventType
EventType._missing_ = lambda value: EventType.ACTIVITY

# TODO: Catch and handle this:
# Error in Snoo PubNub Listener of Category: 3
# Exception in subscribe loop: HTTP Client Error (403): {'message': 'Forbidden', 'payload': {'channels': ['ActivityState.7460194284235017']}, 'error': True, 'service': 'Access Manager', 'status': 403}


# HACK: Subclass to modify original pubnub policy
class SnooPubNub(SnooPubNubBase):
    """Subclass of the original Snoo pubnub to alter the reconnect policy."""

    @staticmethod
    def _setup_pnconfig(access_token, uuid):
        """Build pysnoo's PubNub config, switching it to exponential reconnects."""
        pnconfig = SnooPubNubBase._setup_pnconfig(access_token, uuid)
        pnconfig.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL
        return pnconfig

    async def await_disconnect(self):
        """Wait for the listener to disconnect; returns at once if not connected."""
        if self._listener.is_connected():
            await self._listener.wait_for_disconnect()

    def set_token(self, token):
        """Forward ``token`` to the wrapped PubNub client's ``set_token``."""
        return self._pubnub.set_token(token)


def _serialize_device_info(device: Device) -> dict[str, Any]:
    """Return the Home Assistant discovery ``device`` block for a Snoo device."""
    return {
        "identifiers": [
            device.serial_number,
        ],
        "name": device.baby,
    }


@dataclass
class SnooActivityEntity:
    """Represents an entity that should be updated via Activity messages."""

    # Attribute of the activity JSON this entity reports (also used in ids/topics).
    field: str
    # Home Assistant MQTT component, e.g. "binary_sensor", "sensor", "switch".
    component: str
    device_class: str | None = None
    # Display name; defaults to ``friendly_name`` when unset.
    name: str | None = None
    # Jinja template applied to the state payload; defaults to ``value_json.<field>``.
    value_template: str | None = None
    # Async ``(pubnub, decoded_command) -> None`` invoked for switch commands.
    command_callback: Callable | None = None

    @property
    def friendly_name(self) -> str:
        """Human-readable name derived from ``field`` ("left_safety_clip" -> "Left safety clip")."""
        return self.field.replace("_", " ").capitalize()

    @property
    def _object_id(self) -> str:
        """Discovery object id; "state" is renamed to avoid a bare generic id."""
        return self.field if self.field != "state" else "snoo_state"

    @property
    def config_topic(self) -> str:
        """Device-agnostic discovery topic (kept for compatibility).

        With more than one device every device would share this topic and overwrite
        each other's config; prefer ``device_config_topic``.
        """
        return f"homeassistant/{self.component}/{self._object_id}/config"

    def device_config_topic(self, device: Device) -> str:
        """Discovery topic scoped to ``device`` via the optional ``<node_id>`` level."""
        return (
            f"homeassistant/{self.component}/snoo_{device.serial_number}/"
            f"{self._object_id}/config"
        )

    def device_topic(self, device: Device) -> str:
        """Base topic for ``device``; used as the ``~`` abbreviation in discovery."""
        return f"homeassistant/snoo_{device.serial_number}"

    def command_topic(self, device: Device) -> str | None:
        """Topic Home Assistant publishes commands to, or None for non-switches."""
        if self.component != "switch":
            return None
        return f"{self.device_topic(device)}/{self.field}/set"

    def discovery_message(self, device: Device) -> dict[str, Any]:
        """Build the Home Assistant MQTT discovery payload for this entity on ``device``."""
        payload = {
            "~": self.device_topic(device),
            "name": self.name or self.friendly_name,
            "device_class": self.device_class,
            "state_topic": "~/state",
            "unique_id": f"{device.serial_number}_{self.field}",
            "device": _serialize_device_info(device),
            "value_template": self.value_template or "{{ value_json." + self.field + " }}",
        }

        if self.component in {"binary_sensor", "switch"}:
            # Use "True" and "False" for binary sensor payloads
            payload.update(
                {
                    # Safety classes are actually True if unsafe, so flipping these
                    "payload_on": "True" if self.device_class != "safety" else "False",
                    "payload_off": "False" if self.device_class != "safety" else "True",
                }
            )

        if command_topic := self.command_topic(device):
            payload.update(
                {
                    "command_topic": command_topic,
                }
            )

        return payload


def _decode_command_payload(data: Any) -> dict[str, Any]:
    """Normalize an MQTT command event into the ``{"state": ...}`` dict callbacks read.

    - A dict or JSON-object payload is returned as-is, matching the original
      ``json.loads`` pass-through.
    - Plain-text payloads are mapped to ``"ON"``/``"OFF"``. Discovery advertises
      ``"True"``/``"False"`` as the switch payloads.

    NOTE(review): AppDaemon's MQTT_MESSAGE event data is presumably a dict carrying the
    raw message under ``"payload"`` -- confirm against the plugin in use.
    """
    payload = data.get("payload", data) if isinstance(data, dict) else data
    if isinstance(payload, (bytes, bytearray)):
        payload = payload.decode("utf-8", errors="replace")
    if isinstance(payload, dict):
        return payload
    try:
        decoded = json.loads(payload)
    except (TypeError, ValueError):
        decoded = payload
    if isinstance(decoded, dict):
        return decoded
    is_on = str(decoded).strip().lower() in {"true", "on", "1"}
    return {"state": "ON" if is_on else "OFF"}


async def status_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]):
    """Start a session on "ON"; otherwise return the Snoo to the ONLINE level."""
    if data["state"] == "ON":
        await pubnub.publish_start()
    else:
        await pubnub.publish_goto_state(SessionLevel.ONLINE)


async def hold_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]):
    """Toggle "hold" on the current level; ignored unless that level is active."""
    history = await pubnub.history(1)
    if not history:
        # No known state to re-issue the hold against.
        return
    current_state = history[0].state_machine.state

    if not current_state.is_active_level():
        return

    await pubnub.publish_goto_state(current_state, data["state"] == "ON")


ENTITIES = [
    SnooActivityEntity(
        "left_safety_clip",
        "binary_sensor",
        device_class="safety",
    ),
    SnooActivityEntity(
        "right_safety_clip",
        "binary_sensor",
        device_class="safety",
    ),
    SnooActivityEntity(
        "weaning",
        "binary_sensor",
        value_template="{{ value_json.state_machine.weaning }}",
    ),
    SnooActivityEntity(
        "is_active_session",
        "binary_sensor",
        device_class="running",
        name="Active Session",
        value_template="{{ value_json.state_machine.is_active_session }}",
    ),
    SnooActivityEntity(
        "state",
        "sensor",
        value_template="{{ value_json.state_machine.state }}",
    ),
    SnooActivityEntity(
        "status",
        "switch",
        value_template="{{ value_json.state_machine.state != 'ONLINE' }}",
        command_callback=status_switch_callback,
    ),
    SnooActivityEntity(
        "hold",
        "switch",
        value_template="{{ value_json.state_machine.hold }}",
        command_callback=hold_switch_callback,
    ),
]


class Token:
    """Persists the OAuth token as JSON at ``file_path``."""

    def __init__(self, file_path: str) -> None:
        self._path = file_path

    def value(self) -> dict | None:
        """Read auth token from JSON file; None if missing or unparsable."""
        try:
            with open(self._path, encoding="utf-8") as infile:
                return json.load(infile)
        except (FileNotFoundError, ValueError):
            # No token yet, or a corrupt file: fall back to a fresh login.
            return None

    @property
    def updater(self) -> Callable[[dict], None]:
        """Callback that overwrites the token file with a new token."""

        def token_updater(token: dict):
            with open(self._path, "w", encoding="utf-8") as outfile:
                json.dump(token, outfile)

        return token_updater


def _device_to_pubnub(snoo: Snoo, device: Device) -> SnooPubNub:
    """Create a (reconnect-tuned) PubNub client for ``device`` using the session token."""
    assert snoo.auth.access_token
    pubnub = SnooPubNub(
        snoo.auth.access_token,
        device.serial_number,
        f"pn-pysnoo-{device.serial_number}",
    )
    # TODO: Add some way of passing this config into pysnoo. Until then the
    # SnooPubNub subclass above forces PNReconnectionPolicy.EXPONENTIAL.
    return pubnub


class SnooMQTT(mqtt.Mqtt):
    """AppDaemon app publishing Snoo activity to Home Assistant over MQTT discovery."""

    def __init__(self, *args, **kwargs) -> None:
        self._devices: list[Device] = []
        self._pubnubs: dict[str, SnooPubNub] = {}
        self._session: SnooAuthSession | None = None
        self._snoo_subscription: Future | None = None
        # One timer per device when polling; a single slot would leak all but the last.
        self._snoo_poll_timer_handles: list[str] = []
        # Set here so terminate() is safe even if initialize() failed early.
        self._birth_listener_handle: str | None = None
        # Command listener handles keyed by topic, so republishing discovery on every
        # Home Assistant birth message does not register duplicate listeners.
        self._command_listener_handles: dict[str, Any] = {}
        super().__init__(*args, **kwargs)

    async def initialize(self) -> None:
        """Authorize, discover devices, publish discovery and start the Snoo subscription.

        Raises:
            ValueError: if the account has no Snoo devices.
        """
        self._session = await self.authorize_session()
        snoo = Snoo(self._session)

        # Get devices to be monitored
        self._devices = await snoo.get_devices()
        if not self._devices:
            raise ValueError("No Snoo devices connected to account")

        # Publish discovery information
        await self._publish_discovery(self._devices)

        # Listen for home assistant status to republish discovery
        self._birth_listener_handle = await self.listen_event(
            self.birth_callback,
            "MQTT_MESSAGE",
            topic="homeassistant/status",
        )

        await self.snoo_subscribe(snoo)

    async def terminate(self) -> None:
        """Cancel listeners, timers and the subscription task, then close the session."""
        # Stop listening to birth events
        if self._birth_listener_handle:
            await self.cancel_listen_event(self._birth_listener_handle)
            self._birth_listener_handle = None

        # Stop listening to command topics
        for handle in self._command_listener_handles.values():
            if handle:
                await self.cancel_listen_event(handle)
        self._command_listener_handles.clear()

        # Stop polling for updates
        for handle in self._snoo_poll_timer_handles:
            if handle:
                await self.cancel_timer(handle)
        self._snoo_poll_timer_handles.clear()

        # Stop listening task
        if self._snoo_subscription:
            self._snoo_subscription.cancel()
            self._snoo_subscription = None

        # Close the session
        if self._session:
            await self._session.close()
            self._session = None

    async def reinit(self) -> None:
        """Reinitialize the app by terminating everything and starting up again."""
        await self.terminate()
        await self.initialize()

    async def authorize_session(self) -> SnooAuthSession:
        """Build a session from the stored token, logging in if it is not authorized."""
        # Get token and updater
        token = Token(self._token_path)

        # Create and pre-authorize session
        session = SnooAuthSession(token.value() or {}, token.updater)
        if not session.authorized:
            new_token = await session.fetch_token(self._username, self._password)
            token.updater(new_token)
            self.log("got initial token")

        return session

    async def birth_callback(self, event_name, data, *args) -> None:
        """Callback listening for hass status messages.

        Should republish discovery and initial status for every device.
        """
        self.log("Read hass status event: %s, %s", event_name, data)
        await self._publish_discovery(self._devices)
        for device in self._devices:
            pubnub = self._pubnubs.get(device.serial_number)
            if pubnub is None:
                # The subscription task has not created this client yet; it publishes
                # the initial state itself once it does.
                continue
            for activity_state in await pubnub.history(1):
                self._create_activity_callback(device)(activity_state)

    async def snoo_subscribe(self, snoo: Snoo) -> None:
        """Creates AppDaemon subscription task to listen to Snoo updates."""
        # Subscribe and start listening to Snoo updates
        self._snoo_subscription = self.create_task(self.snoo_real_subscribe(snoo))

    async def snoo_real_subscribe(self, snoo: Snoo):
        """Create a PubNub client per device and start updates.

        In polling mode this schedules one timer per device and returns. Otherwise it
        listens to all devices concurrently and returns once every listener disconnects.
        """
        # This should only run after initialize, so devices are populated
        listeners = []
        for device in self._devices:
            pubnub = _device_to_pubnub(snoo, device)
            self._pubnubs[device.serial_number] = pubnub
            if self._polling:
                self.log("Scheduling updates using polling")
                self._snoo_poll_timer_handles.append(await self._schedule_updates(device))
            else:
                self.log("Scheduling updates listening")
                listeners.append(self._subscribe_and_listen(device, pubnub))

        # Each listener only returns on disconnect, so awaiting them one by one would
        # leave every device after the first unsubscribed.
        if listeners:
            await asyncio.gather(*listeners)

    async def _schedule_updates(self, device: Device) -> str:
        """Schedules periodic (every 30s) updates from pubnub history; returns the timer handle."""
        cb = self._create_activity_callback(device)

        async def poll_history(*args):
            # TODO: Maybe try/except here for auth errors and then reconnect
            for activity_state in await self._pubnubs[device.serial_number].history(1):
                cb(activity_state)

        return await self.run_every(poll_history, "now", 30)

    async def _subscribe_and_listen(self, device: Device, pubnub: SnooPubNub):
        """Subscribes to pubnub activity and listens to new events until disconnect."""
        cb = self._create_activity_callback(device)
        pubnub.add_listener(cb)
        # Publish first status message
        for activity_state in await pubnub.history(1):
            cb(activity_state)

        # TODO: Maybe try/except here for auth errors and then reconnect
        await pubnub.subscribe_and_await_connect()
        self.log("Done subscribing and listening...")
        await pubnub.await_disconnect()
        self.log("Disconnected! Maybe re-initialize here")
        # self.create_task(self.reinit())

    async def _disconnect_listener(self):
        """Wait for all clients to disconnect, then refresh the token and hand it to each client.

        NOTE(review): not referenced elsewhere in this file; also confirm that PubNub's
        ``set_token`` accepts the OAuth token dict returned by ``refresh_token``.
        """
        # Wait for everything to disconnect
        for pubnub in self._pubnubs.values():
            await pubnub.await_disconnect()

        self.log("Disconnected! Maybe re-initialize here")
        # Refresh token
        assert self._session and self._session.auto_refresh_url and self._session.token_updater
        token = await self._session.refresh_token(self._session.auto_refresh_url)
        self._session.token_updater(token)
        for pubnub in self._pubnubs.values():
            pubnub.set_token(token)
        # self.create_task(self.reinit())

    def _create_activity_callback(self, device: Device):
        """Creates an activity callback for a given device."""

        def activity_callback(activity_state: ActivityState):
            self._publish_object(
                f"homeassistant/snoo_{device.serial_number}/state",
                activity_state,
            )

        return activity_callback

    def _create_command_callback(self, device: Device, entity: SnooActivityEntity):
        """Create an MQTT_MESSAGE callback forwarding commands for ``entity`` on ``device``.

        Built by a factory so each callback binds its own device/entity, rather than the
        last values of the discovery loop.
        """
        command_callback = entity.command_callback
        assert command_callback is not None

        async def command_listener(event_name, data, *args):
            pubnub = self._pubnubs.get(device.serial_number)
            if pubnub is None:
                self.log("Ignoring command for %s: not subscribed yet", device.serial_number)
                return
            await command_callback(pubnub, _decode_command_payload(data))

        return command_listener

    async def _publish_discovery(self, devices: list[Device]):
        """Publishes discovery messages and registers command listeners (once per topic)."""
        for device in devices:
            for entity in ENTITIES:
                self._publish_object(
                    entity.device_config_topic(device),
                    entity.discovery_message(device),
                )

                # See if we need to listen to commands
                command_topic = entity.command_topic(device)
                if entity.command_callback is None or command_topic is None:
                    continue
                if command_topic in self._command_listener_handles:
                    # Already listening; discovery is republished on every birth message.
                    continue
                self._command_listener_handles[command_topic] = await self.listen_event(
                    self._create_command_callback(device, entity),
                    "MQTT_MESSAGE",
                    topic=command_topic,
                )

    def _publish_object(self, topic, payload_object):
        """Publishes an object, serializing its payload (via ``to_dict`` if present) as JSON."""
        to_dict = getattr(payload_object, "to_dict", None)
        payload = json.dumps(to_dict() if callable(to_dict) else payload_object)
        self.log("topic: %s payload: %s", topic, payload)
        self.mqtt_publish(topic, payload=payload)

    # App arguments
    @property
    def _username(self) -> str:
        username = self.args.get("username")
        if not username:
            raise ValueError("Must provide a username")
        return username

    @property
    def _password(self) -> str:
        password = self.args.get("password")
        if not password:
            raise ValueError("Must provide a password")
        return password

    @property
    def _token_path(self) -> str:
        return self.args.get("token_path", "snoo_token.json")

    @property
    def _polling(self) -> bool:
        return self.args.get("polling", False)