WIP: Hackier overriding for connections

This commit is contained in:
IamTheFij 2024-01-11 19:48:42 -08:00
parent 4af9ffc0f5
commit 39157aa368

View File

@ -1,27 +1,32 @@
import json
import logging
from asyncio import Future
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from typing import Any, Never, NoReturn
from typing import cast
import mqttapi as mqtt
from pubnub.enums import PNReconnectionPolicy
from pysnoo import ActivityState
from pysnoo import Device
from pysnoo import SessionLevel
from pysnoo import Snoo
from pysnoo import SnooAuthSession
from pysnoo import SnooPubNub as SnooPubNubBase
from pysnoo.models import EventType
# HACK: Avoid error on missing EventType
EventType._missing_ = lambda x: EventType.ACTIVITY
# Set global log level to debug
logging.getLogger().setLevel(logging.DEBUG)
# Use polling or listening
POLLING = True
# HACK: Avoid error on missing EventType
EventType._missing_ = lambda value: EventType.ACTIVITY
# TODO: Catch and handle this:
# Error in Snoo PubNub Listener of Category: 3
# Exception in subscribe loop: HTTP Client Error (403): {'message': 'Forbidden', 'payload': {'channels': ['ActivityState.7460194284235017']}, 'error': True, 'service': 'Access Manager', 'status': 403}
# HACK: Subclass to modify original pubnub policy
@ -35,6 +40,14 @@ class SnooPubNub(SnooPubNubBase):
pnconfig.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL
return pnconfig
async def await_disconnect(self):
"""Await disconnect"""
if self._listener.is_connected():
await self._listener.wait_for_disconnect()
def set_token(self, token):
return self._pubnub.set_token(token)
def _serialize_device_info(device: Device) -> dict[str, Any]:
return {
@ -110,7 +123,7 @@ async def status_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]):
if data["state"] == "ON":
await pubnub.publish_start()
else:
await pubnub.publish_goto_state("ONLINE")
await pubnub.publish_goto_state(SessionLevel.ONLINE)
async def hold_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]):
@ -167,31 +180,39 @@ ENTITIES = [
]
def _get_token(token_file: str):
"""Read auth token from JSON file."""
try:
with open(token_file) as infile:
return json.load(infile)
except FileNotFoundError:
pass
except ValueError:
pass
class Token:
def __init__(self, file_path: str) -> None:
self._path = file_path
def value(self) -> dict | None:
"""Read auth token from JSON file."""
try:
with open(self._path) as infile:
return json.load(infile)
except FileNotFoundError:
pass
except ValueError:
pass
def _get_token_updater(token_file: str):
def token_updater(token):
with open(token_file, "w") as outfile:
json.dump(token, outfile)
return None
return token_updater
@property
def updater(self) -> Callable[[dict], None]:
def token_updater(token: dict):
with open(self._path, "w") as outfile:
json.dump(token, outfile)
return token_updater
def _device_to_pubnub(snoo: Snoo, device: Device) -> SnooPubNub:
assert snoo.auth.access_token
pubnub = SnooPubNub(
snoo.auth.access_token,
device.serial_number,
f"pn-pysnoo-{device.serial_number}",
)
# TODO: Add some way of passing this config into pysnoo
# Another temp option to subclass SnooPubNub
# pubnub._pubnub.config.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL
@ -200,69 +221,114 @@ def _device_to_pubnub(snoo: Snoo, device: Device) -> SnooPubNub:
class SnooMQTT(mqtt.Mqtt):
def __init__(self, *args, **kwargs) -> None:
self._snoo: Device | None = None
self._devices: list[Device] = []
self._pubnubs: dict[str, SnooPubNub] = {}
self._session: SnooAuthSession | None = None
self._snoo_subscription: Future | None = None
self._snoo_poll_timer_handle: str|None = None
super().__init__(*args, **kwargs)
async def initialize(self) -> None:
token_path = self._token_path
token = _get_token(token_path)
token_updater = _get_token_updater(token_path)
self._session = await self.authorize_session()
snoo = Snoo(self._session)
async with SnooAuthSession(token, token_updater) as auth:
if not auth.authorized:
new_token = await auth.fetch_token(self._username, self._password)
token_updater(new_token)
self.log("got inital token")
self._snoo = Snoo(auth)
self._devices = await self._snoo.get_devices()
# Get devices to be monitored
self._devices = await snoo.get_devices()
if not self._devices:
raise ValueError("No Snoo devices connected to account")
# Publish discovery information
self._publish_discovery(self._devices)
# Subscribe to updates
for device in self._devices:
pubnub = _device_to_pubnub(self._snoo, device)
self._pubnubs[device.serial_number] = pubnub
if POLLING:
await self._schedule_updates(device, pubnub)
else:
await self._subscribe_and_listen(device, pubnub)
await self._publish_discovery(self._devices)
# Listen for home assistant status to republish discovery
self.listen_event(
self._birth_listener_handle = await self.listen_event(
self.birth_callback,
"MQTT_MESSAGE",
topic="homeassistant/status",
)
await self.snoo_subscribe(snoo)
async def terminate(self) -> None:
# Stop listening to birth events
if self._birth_listener_handle:
await self.cancel_listen_event(self._birth_listener_handle)
# Stop polling for updates
if self._snoo_poll_timer_handle:
await self.cancel_timer(self._snoo_poll_timer_handle)
# Stop listening task
if self._snoo_subscription:
self._snoo_subscription.cancel()
# Close the session
if self._session:
await self._session.close()
async def reinit(self) -> None:
"""Reinitialize the app by terminating everything and starting up again."""
await self.terminate()
await self.initialize()
async def authorize_session(self) -> SnooAuthSession:
# Get token and updator
token = Token(self._token_path)
# Create and pre-authorize session
session = SnooAuthSession(token.value() or {}, token.updater)
if not session.authorized:
new_token = await session.fetch_token(self._username, self._password)
token.updater(new_token)
self.log("got inital token")
return session
async def birth_callback(self, event_name, data, *args) -> None:
"""Callback listening for hass status messages.
Should republish discovery and initial status for every device."""
self.log("Read hass status event: %s, %s", event_name, data)
self._publish_discovery(self._devices)
await self._publish_discovery(self._devices)
for device in self._devices:
pubnub = self._pubnubs[device.serial_number]
for activity_state in await pubnub.history(1):
self._create_activity_callback(device)(activity_state)
async def _schedule_updates(self, device: Device, pubnub: SnooPubNub):
async def snoo_subscribe(self, snoo: Snoo) -> None:
"""Creates AppDaemon subscription task to listen to Snoo updates."""
# Subscribe and start listening to Snoo updates
self._snoo_subscription = self.create_task(self.snoo_real_subscribe(snoo))
async def snoo_real_subscribe(self, snoo: Snoo):
"""Coroutine that subscribes to updates from Snoo.
Never returns."""
# This should only run after initialize, so this should not be None
# Subscribe to updates
for device in self._devices:
pubnub = _device_to_pubnub(snoo, device)
self._pubnubs[device.serial_number] = pubnub
if self._polling:
self.log("Scheduling updates using polling")
self._snoo_poll_timer_handle = await self._schedule_updates(device)
else:
self.log("Scheduling updates listening")
await self._subscribe_and_listen(device, pubnub)
async def _schedule_updates(self, device: Device) -> str:
"""Schedules from pubnub history periodic updates."""
cb = self._create_activity_callback(device)
async def poll_history(*args):
for activity_state in await pubnub.history(1):
# TODO: Maybe try/except here for auth errors and then reconnect self.terminate() self.initialize()
for activity_state in await self._pubnubs[device.serial_number].history(1):
cb(activity_state)
self.run_every(poll_history, "now", 30)
return await self.run_every(poll_history, "now", 30)
async def _subscribe_and_listen(self, device: Device, pubnub: SnooPubNub):
"""Subscribes to pubnub activity and listens to new events."""
@ -273,7 +339,29 @@ class SnooMQTT(mqtt.Mqtt):
for activity_state in await pubnub.history(1):
cb(activity_state)
# TODO: Maybe try/except here for auth errors and then reconnect self.terminate() self.initialize()
await pubnub.subscribe_and_await_connect()
self.log("Done subscribing and listening...")
await pubnub.await_disconnect()
self.log("Disconnected! Maybe re-initialize here")
# self.create_task(self.reinit())
async def _disconnect_listener(self):
# Wait for everything to disconnect
for pubnub in self._pubnubs.values():
await pubnub.await_disconnect()
self.log("Disconnected! Maybe re-initialize here")
# Refresh token
assert self._session and self._session.auto_refresh_url and self._session.token_updater
token = await self._session.refresh_token(self._session.auto_refresh_url)
self._session.token_updater(token)
for pubnub in self._pubnubs.values():
pubnub.set_token(token)
# self.create_task(self.reinit())
def _create_activity_callback(self, device: Device):
"""Creates an activity callback for a given device."""
@ -286,7 +374,7 @@ class SnooMQTT(mqtt.Mqtt):
return activity_callback
def _publish_discovery(self, devices: list[Device]):
async def _publish_discovery(self, devices: list[Device]):
"""Publishes discovery messages for the provided devices."""
for device in devices:
for entity in ENTITIES:
@ -321,20 +409,26 @@ class SnooMQTT(mqtt.Mqtt):
self.log("topic: %s payload: %s", topic, payload)
self.mqtt_publish(topic, payload=payload)
# App arguments
@property
def _username(self):
def _username(self) -> str:
username = self.args.get("username")
if not username:
raise ValueError("Must provide a username")
return username
@property
def _password(self):
def _password(self) -> str:
password = self.args.get("password")
if not password:
raise ValueError("Must provide a password")
return password
@property
def _token_path(self):
def _token_path(self) -> str:
return self.args.get("token_path", "snoo_token.json")
@property
def _polling(self) -> bool:
return self.args.get("polling", False)