From 39157aa3681026cb9c0326781f31f6eb86d635f0 Mon Sep 17 00:00:00 2001 From: Ian Fijolek Date: Thu, 11 Jan 2024 19:48:42 -0800 Subject: [PATCH] WIP: Hackier overriding for connections --- apps/snoo_mqtt.py | 202 +++++++++++++++++++++++++++++++++------------- 1 file changed, 148 insertions(+), 54 deletions(-) diff --git a/apps/snoo_mqtt.py b/apps/snoo_mqtt.py index 37e3d90..7e1068c 100644 --- a/apps/snoo_mqtt.py +++ b/apps/snoo_mqtt.py @@ -1,27 +1,32 @@ import json import logging +from asyncio import Future from collections.abc import Callable from dataclasses import dataclass -from typing import Any +from typing import Any, Never, NoReturn +from typing import cast import mqttapi as mqtt from pubnub.enums import PNReconnectionPolicy from pysnoo import ActivityState from pysnoo import Device +from pysnoo import SessionLevel from pysnoo import Snoo from pysnoo import SnooAuthSession from pysnoo import SnooPubNub as SnooPubNubBase from pysnoo.models import EventType -# HACK: Avoid error on missing EventType -EventType._missing_ = lambda x: EventType.ACTIVITY - # Set global log level to debug logging.getLogger().setLevel(logging.DEBUG) -# Use polling or listening -POLLING = True + +# HACK: Avoid error on missing EventType +EventType._missing_ = lambda value: EventType.ACTIVITY + +# TODO: Catch and handle this: +# Error in Snoo PubNub Listener of Category: 3 +# Exception in subscribe loop: HTTP Client Error (403): {'message': 'Forbidden', 'payload': {'channels': ['ActivityState.7460194284235017']}, 'error': True, 'service': 'Access Manager', 'status': 403} # HACK: Subclass to modify original pubnub policy @@ -35,6 +40,14 @@ class SnooPubNub(SnooPubNubBase): pnconfig.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL return pnconfig + async def await_disconnect(self): + """Await disconnect""" + if self._listener.is_connected(): + await self._listener.wait_for_disconnect() + + def set_token(self, token): + return self._pubnub.set_token(token) + def _serialize_device_info(device: Device) -> dict[str, Any]: return { @@ -110,7 +123,7 @@ async def status_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]): if data["state"] == "ON": await pubnub.publish_start() else: - await pubnub.publish_goto_state("ONLINE") + await pubnub.publish_goto_state(SessionLevel.ONLINE) async def hold_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]): @@ -167,31 +180,39 @@ ENTITIES = [ ] -def _get_token(token_file: str): - """Read auth token from JSON file.""" - try: - with open(token_file) as infile: - return json.load(infile) - except FileNotFoundError: - pass - except ValueError: - pass +class Token: + def __init__(self, file_path: str) -> None: + self._path = file_path + def value(self) -> dict | None: + """Read auth token from JSON file.""" + try: + with open(self._path) as infile: + return json.load(infile) + except FileNotFoundError: + pass + except ValueError: + pass -def _get_token_updater(token_file: str): - def token_updater(token): - with open(token_file, "w") as outfile: - json.dump(token, outfile) + return None - return token_updater + @property + def updater(self) -> Callable[[dict], None]: + def token_updater(token: dict): + with open(self._path, "w") as outfile: + json.dump(token, outfile) + + return token_updater def _device_to_pubnub(snoo: Snoo, device: Device) -> SnooPubNub: + assert snoo.auth.access_token pubnub = SnooPubNub( snoo.auth.access_token, device.serial_number, f"pn-pysnoo-{device.serial_number}", ) + # TODO: Add some way of passing this config into pysnoo # Another temp option to subclass SnooPubNub # pubnub._pubnub.config.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL @@ -200,69 +221,114 @@ def _device_to_pubnub(snoo: Snoo, device: Device) -> SnooPubNub: class SnooMQTT(mqtt.Mqtt): def __init__(self, *args, **kwargs) -> None: - self._snoo: Device | None = None self._devices: list[Device] = [] self._pubnubs: dict[str, SnooPubNub] = {} + self._session: SnooAuthSession | None = None + self._snoo_subscription: Future | None = None + self._snoo_poll_timer_handle: str|None = None + super().__init__(*args, **kwargs) async def initialize(self) -> None: - token_path = self._token_path - token = _get_token(token_path) - token_updater = _get_token_updater(token_path) + self._session = await self.authorize_session() + snoo = Snoo(self._session) - async with SnooAuthSession(token, token_updater) as auth: - if not auth.authorized: - new_token = await auth.fetch_token(self._username, self._password) - token_updater(new_token) - self.log("got inital token") - - self._snoo = Snoo(auth) - - self._devices = await self._snoo.get_devices() + # Get devices to be monitored + self._devices = await snoo.get_devices() if not self._devices: raise ValueError("No Snoo devices connected to account") # Publish discovery information - self._publish_discovery(self._devices) - - # Subscribe to updates - for device in self._devices: - pubnub = _device_to_pubnub(self._snoo, device) - self._pubnubs[device.serial_number] = pubnub - - if POLLING: - await self._schedule_updates(device, pubnub) - else: - await self._subscribe_and_listen(device, pubnub) + await self._publish_discovery(self._devices) # Listen for home assistant status to republish discovery - self.listen_event( + self._birth_listener_handle = await self.listen_event( self.birth_callback, "MQTT_MESSAGE", topic="homeassistant/status", ) + await self.snoo_subscribe(snoo) + + async def terminate(self) -> None: + # Stop listening to birth events + if self._birth_listener_handle: + await self.cancel_listen_event(self._birth_listener_handle) + + # Stop polling for updates + if self._snoo_poll_timer_handle: + await self.cancel_timer(self._snoo_poll_timer_handle) + + # Stop listening task + if self._snoo_subscription: + self._snoo_subscription.cancel() + + # Close the session + if self._session: + await self._session.close() + + async def reinit(self) -> None: + """Reinitialize the app by terminating everything and starting up again.""" + await self.terminate() + await self.initialize() + + async def authorize_session(self) -> SnooAuthSession: + # Get token and updator + token = Token(self._token_path) + + # Create and pre-authorize session + session = SnooAuthSession(token.value() or {}, token.updater) + if not session.authorized: + new_token = await session.fetch_token(self._username, self._password) + token.updater(new_token) + self.log("got inital token") + + return session + async def birth_callback(self, event_name, data, *args) -> None: """Callback listening for hass status messages. Should republish discovery and initial status for every device.""" self.log("Read hass status event: %s, %s", event_name, data) - self._publish_discovery(self._devices) + await self._publish_discovery(self._devices) for device in self._devices: pubnub = self._pubnubs[device.serial_number] for activity_state in await pubnub.history(1): self._create_activity_callback(device)(activity_state) - async def _schedule_updates(self, device: Device, pubnub: SnooPubNub): + async def snoo_subscribe(self, snoo: Snoo) -> None: + """Creates AppDaemon subscription task to listen to Snoo updates.""" + # Subscribe and start listening to Snoo updates + self._snoo_subscription = self.create_task(self.snoo_real_subscribe(snoo)) + + async def snoo_real_subscribe(self, snoo: Snoo): + """Coroutine that subscribes to updates from Snoo. + + Never returns.""" + # This should only run after initialize, so this should not be None + # Subscribe to updates + for device in self._devices: + pubnub = _device_to_pubnub(snoo, device) + self._pubnubs[device.serial_number] = pubnub + + if self._polling: + self.log("Scheduling updates using polling") + self._snoo_poll_timer_handle = await self._schedule_updates(device) + else: + self.log("Scheduling updates listening") + await self._subscribe_and_listen(device, pubnub) + + async def _schedule_updates(self, device: Device) -> str: """Schedules from pubnub history periodic updates.""" cb = self._create_activity_callback(device) async def poll_history(*args): - for activity_state in await pubnub.history(1): + # TODO: Maybe try/except here for auth errors and then reconnect self.terminate() self.initialize() + for activity_state in await self._pubnubs[device.serial_number].history(1): cb(activity_state) - self.run_every(poll_history, "now", 30) + return await self.run_every(poll_history, "now", 30) async def _subscribe_and_listen(self, device: Device, pubnub: SnooPubNub): """Subscribes to pubnub activity and listens to new events.""" @@ -273,7 +339,29 @@ class SnooMQTT(mqtt.Mqtt): for activity_state in await pubnub.history(1): cb(activity_state) + # TODO: Maybe try/except here for auth errors and then reconnect self.terminate() self.initialize() await pubnub.subscribe_and_await_connect() + self.log("Done subscribing and listening...") + await pubnub.await_disconnect() + self.log("Disconnected! Maybe re-initialize here") + # self.create_task(self.reinit()) + + async def _disconnect_listener(self): + # Wait for everything to disconnect + for pubnub in self._pubnubs.values(): + await pubnub.await_disconnect() + + self.log("Disconnected! Maybe re-initialize here") + + # Refresh token + assert self._session and self._session.auto_refresh_url and self._session.token_updater + token = await self._session.refresh_token(self._session.auto_refresh_url) + self._session.token_updater(token) + + for pubnub in self._pubnubs.values(): + pubnub.set_token(token) + + # self.create_task(self.reinit()) def _create_activity_callback(self, device: Device): """Creates an activity callback for a given device.""" @@ -286,7 +374,7 @@ class SnooMQTT(mqtt.Mqtt): return activity_callback - def _publish_discovery(self, devices: list[Device]): + async def _publish_discovery(self, devices: list[Device]): """Publishes discovery messages for the provided devices.""" for device in devices: for entity in ENTITIES: @@ -321,20 +409,26 @@ class SnooMQTT(mqtt.Mqtt): self.log("topic: %s payload: %s", topic, payload) self.mqtt_publish(topic, payload=payload) + # App arguments + @property - def _username(self): + def _username(self) -> str: username = self.args.get("username") if not username: raise ValueError("Must provide a username") return username @property - def _password(self): + def _password(self) -> str: password = self.args.get("password") if not password: raise ValueError("Must provide a password") return password @property - def _token_path(self): + def _token_path(self) -> str: return self.args.get("token_path", "snoo_token.json") + + @property + def _polling(self) -> bool: + return self.args.get("polling", False)