WIP: Hackier overriding for connections
This commit is contained in:
parent
4af9ffc0f5
commit
39157aa368
@ -1,27 +1,32 @@
|
||||
import json
|
||||
import logging
|
||||
from asyncio import Future
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, Never, NoReturn
|
||||
from typing import cast
|
||||
|
||||
import mqttapi as mqtt
|
||||
from pubnub.enums import PNReconnectionPolicy
|
||||
from pysnoo import ActivityState
|
||||
from pysnoo import Device
|
||||
from pysnoo import SessionLevel
|
||||
from pysnoo import Snoo
|
||||
from pysnoo import SnooAuthSession
|
||||
from pysnoo import SnooPubNub as SnooPubNubBase
|
||||
from pysnoo.models import EventType
|
||||
|
||||
|
||||
# HACK: Avoid error on missing EventType
|
||||
EventType._missing_ = lambda x: EventType.ACTIVITY
|
||||
|
||||
# Set global log level to debug
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
# Use polling or listening
|
||||
POLLING = True
|
||||
|
||||
# HACK: Avoid error on missing EventType
|
||||
EventType._missing_ = lambda value: EventType.ACTIVITY
|
||||
|
||||
# TODO: Catch and handle this:
|
||||
# Error in Snoo PubNub Listener of Category: 3
|
||||
# Exception in subscribe loop: HTTP Client Error (403): {'message': 'Forbidden', 'payload': {'channels': ['ActivityState.7460194284235017']}, 'error': True, 'service': 'Access Manager', 'status': 403}
|
||||
|
||||
|
||||
# HACK: Subclass to modify original pubnub policy
|
||||
@ -35,6 +40,14 @@ class SnooPubNub(SnooPubNubBase):
|
||||
pnconfig.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL
|
||||
return pnconfig
|
||||
|
||||
async def await_disconnect(self):
|
||||
"""Await disconnect"""
|
||||
if self._listener.is_connected():
|
||||
await self._listener.wait_for_disconnect()
|
||||
|
||||
def set_token(self, token):
|
||||
return self._pubnub.set_token(token)
|
||||
|
||||
|
||||
def _serialize_device_info(device: Device) -> dict[str, Any]:
|
||||
return {
|
||||
@ -110,7 +123,7 @@ async def status_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]):
|
||||
if data["state"] == "ON":
|
||||
await pubnub.publish_start()
|
||||
else:
|
||||
await pubnub.publish_goto_state("ONLINE")
|
||||
await pubnub.publish_goto_state(SessionLevel.ONLINE)
|
||||
|
||||
|
||||
async def hold_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]):
|
||||
@ -167,31 +180,39 @@ ENTITIES = [
|
||||
]
|
||||
|
||||
|
||||
def _get_token(token_file: str):
|
||||
"""Read auth token from JSON file."""
|
||||
try:
|
||||
with open(token_file) as infile:
|
||||
return json.load(infile)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except ValueError:
|
||||
pass
|
||||
class Token:
|
||||
def __init__(self, file_path: str) -> None:
|
||||
self._path = file_path
|
||||
|
||||
def value(self) -> dict | None:
|
||||
"""Read auth token from JSON file."""
|
||||
try:
|
||||
with open(self._path) as infile:
|
||||
return json.load(infile)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def _get_token_updater(token_file: str):
|
||||
def token_updater(token):
|
||||
with open(token_file, "w") as outfile:
|
||||
json.dump(token, outfile)
|
||||
return None
|
||||
|
||||
return token_updater
|
||||
@property
|
||||
def updater(self) -> Callable[[dict], None]:
|
||||
def token_updater(token: dict):
|
||||
with open(self._path, "w") as outfile:
|
||||
json.dump(token, outfile)
|
||||
|
||||
return token_updater
|
||||
|
||||
|
||||
def _device_to_pubnub(snoo: Snoo, device: Device) -> SnooPubNub:
|
||||
assert snoo.auth.access_token
|
||||
pubnub = SnooPubNub(
|
||||
snoo.auth.access_token,
|
||||
device.serial_number,
|
||||
f"pn-pysnoo-{device.serial_number}",
|
||||
)
|
||||
|
||||
# TODO: Add some way of passing this config into pysnoo
|
||||
# Another temp option to subclass SnooPubNub
|
||||
# pubnub._pubnub.config.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL
|
||||
@ -200,69 +221,114 @@ def _device_to_pubnub(snoo: Snoo, device: Device) -> SnooPubNub:
|
||||
|
||||
class SnooMQTT(mqtt.Mqtt):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self._snoo: Device | None = None
|
||||
self._devices: list[Device] = []
|
||||
self._pubnubs: dict[str, SnooPubNub] = {}
|
||||
self._session: SnooAuthSession | None = None
|
||||
self._snoo_subscription: Future | None = None
|
||||
self._snoo_poll_timer_handle: str|None = None
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
token_path = self._token_path
|
||||
token = _get_token(token_path)
|
||||
token_updater = _get_token_updater(token_path)
|
||||
self._session = await self.authorize_session()
|
||||
snoo = Snoo(self._session)
|
||||
|
||||
async with SnooAuthSession(token, token_updater) as auth:
|
||||
if not auth.authorized:
|
||||
new_token = await auth.fetch_token(self._username, self._password)
|
||||
token_updater(new_token)
|
||||
self.log("got inital token")
|
||||
|
||||
self._snoo = Snoo(auth)
|
||||
|
||||
self._devices = await self._snoo.get_devices()
|
||||
# Get devices to be monitored
|
||||
self._devices = await snoo.get_devices()
|
||||
if not self._devices:
|
||||
raise ValueError("No Snoo devices connected to account")
|
||||
|
||||
# Publish discovery information
|
||||
self._publish_discovery(self._devices)
|
||||
|
||||
# Subscribe to updates
|
||||
for device in self._devices:
|
||||
pubnub = _device_to_pubnub(self._snoo, device)
|
||||
self._pubnubs[device.serial_number] = pubnub
|
||||
|
||||
if POLLING:
|
||||
await self._schedule_updates(device, pubnub)
|
||||
else:
|
||||
await self._subscribe_and_listen(device, pubnub)
|
||||
await self._publish_discovery(self._devices)
|
||||
|
||||
# Listen for home assistant status to republish discovery
|
||||
self.listen_event(
|
||||
self._birth_listener_handle = await self.listen_event(
|
||||
self.birth_callback,
|
||||
"MQTT_MESSAGE",
|
||||
topic="homeassistant/status",
|
||||
)
|
||||
|
||||
await self.snoo_subscribe(snoo)
|
||||
|
||||
async def terminate(self) -> None:
|
||||
# Stop listening to birth events
|
||||
if self._birth_listener_handle:
|
||||
await self.cancel_listen_event(self._birth_listener_handle)
|
||||
|
||||
# Stop polling for updates
|
||||
if self._snoo_poll_timer_handle:
|
||||
await self.cancel_timer(self._snoo_poll_timer_handle)
|
||||
|
||||
# Stop listening task
|
||||
if self._snoo_subscription:
|
||||
self._snoo_subscription.cancel()
|
||||
|
||||
# Close the session
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
async def reinit(self) -> None:
|
||||
"""Reinitialize the app by terminating everything and starting up again."""
|
||||
await self.terminate()
|
||||
await self.initialize()
|
||||
|
||||
async def authorize_session(self) -> SnooAuthSession:
|
||||
# Get token and updator
|
||||
token = Token(self._token_path)
|
||||
|
||||
# Create and pre-authorize session
|
||||
session = SnooAuthSession(token.value() or {}, token.updater)
|
||||
if not session.authorized:
|
||||
new_token = await session.fetch_token(self._username, self._password)
|
||||
token.updater(new_token)
|
||||
self.log("got inital token")
|
||||
|
||||
return session
|
||||
|
||||
async def birth_callback(self, event_name, data, *args) -> None:
|
||||
"""Callback listening for hass status messages.
|
||||
|
||||
Should republish discovery and initial status for every device."""
|
||||
self.log("Read hass status event: %s, %s", event_name, data)
|
||||
|
||||
self._publish_discovery(self._devices)
|
||||
await self._publish_discovery(self._devices)
|
||||
for device in self._devices:
|
||||
pubnub = self._pubnubs[device.serial_number]
|
||||
for activity_state in await pubnub.history(1):
|
||||
self._create_activity_callback(device)(activity_state)
|
||||
|
||||
async def _schedule_updates(self, device: Device, pubnub: SnooPubNub):
|
||||
async def snoo_subscribe(self, snoo: Snoo) -> None:
|
||||
"""Creates AppDaemon subscription task to listen to Snoo updates."""
|
||||
# Subscribe and start listening to Snoo updates
|
||||
self._snoo_subscription = self.create_task(self.snoo_real_subscribe(snoo))
|
||||
|
||||
async def snoo_real_subscribe(self, snoo: Snoo):
|
||||
"""Coroutine that subscribes to updates from Snoo.
|
||||
|
||||
Never returns."""
|
||||
# This should only run after initialize, so this should not be None
|
||||
# Subscribe to updates
|
||||
for device in self._devices:
|
||||
pubnub = _device_to_pubnub(snoo, device)
|
||||
self._pubnubs[device.serial_number] = pubnub
|
||||
|
||||
if self._polling:
|
||||
self.log("Scheduling updates using polling")
|
||||
self._snoo_poll_timer_handle = await self._schedule_updates(device)
|
||||
else:
|
||||
self.log("Scheduling updates listening")
|
||||
await self._subscribe_and_listen(device, pubnub)
|
||||
|
||||
async def _schedule_updates(self, device: Device) -> str:
|
||||
"""Schedules from pubnub history periodic updates."""
|
||||
cb = self._create_activity_callback(device)
|
||||
|
||||
async def poll_history(*args):
|
||||
for activity_state in await pubnub.history(1):
|
||||
# TODO: Maybe try/except here for auth errors and then reconnect self.terminate() self.initialize()
|
||||
for activity_state in await self._pubnubs[device.serial_number].history(1):
|
||||
cb(activity_state)
|
||||
|
||||
self.run_every(poll_history, "now", 30)
|
||||
return await self.run_every(poll_history, "now", 30)
|
||||
|
||||
async def _subscribe_and_listen(self, device: Device, pubnub: SnooPubNub):
|
||||
"""Subscribes to pubnub activity and listens to new events."""
|
||||
@ -273,7 +339,29 @@ class SnooMQTT(mqtt.Mqtt):
|
||||
for activity_state in await pubnub.history(1):
|
||||
cb(activity_state)
|
||||
|
||||
# TODO: Maybe try/except here for auth errors and then reconnect self.terminate() self.initialize()
|
||||
await pubnub.subscribe_and_await_connect()
|
||||
self.log("Done subscribing and listening...")
|
||||
await pubnub.await_disconnect()
|
||||
self.log("Disconnected! Maybe re-initialize here")
|
||||
# self.create_task(self.reinit())
|
||||
|
||||
async def _disconnect_listener(self):
|
||||
# Wait for everything to disconnect
|
||||
for pubnub in self._pubnubs.values():
|
||||
await pubnub.await_disconnect()
|
||||
|
||||
self.log("Disconnected! Maybe re-initialize here")
|
||||
|
||||
# Refresh token
|
||||
assert self._session and self._session.auto_refresh_url and self._session.token_updater
|
||||
token = await self._session.refresh_token(self._session.auto_refresh_url)
|
||||
self._session.token_updater(token)
|
||||
|
||||
for pubnub in self._pubnubs.values():
|
||||
pubnub.set_token(token)
|
||||
|
||||
# self.create_task(self.reinit())
|
||||
|
||||
def _create_activity_callback(self, device: Device):
|
||||
"""Creates an activity callback for a given device."""
|
||||
@ -286,7 +374,7 @@ class SnooMQTT(mqtt.Mqtt):
|
||||
|
||||
return activity_callback
|
||||
|
||||
def _publish_discovery(self, devices: list[Device]):
|
||||
async def _publish_discovery(self, devices: list[Device]):
|
||||
"""Publishes discovery messages for the provided devices."""
|
||||
for device in devices:
|
||||
for entity in ENTITIES:
|
||||
@ -321,20 +409,26 @@ class SnooMQTT(mqtt.Mqtt):
|
||||
self.log("topic: %s payload: %s", topic, payload)
|
||||
self.mqtt_publish(topic, payload=payload)
|
||||
|
||||
# App arguments
|
||||
|
||||
@property
|
||||
def _username(self):
|
||||
def _username(self) -> str:
|
||||
username = self.args.get("username")
|
||||
if not username:
|
||||
raise ValueError("Must provide a username")
|
||||
return username
|
||||
|
||||
@property
|
||||
def _password(self):
|
||||
def _password(self) -> str:
|
||||
password = self.args.get("password")
|
||||
if not password:
|
||||
raise ValueError("Must provide a password")
|
||||
return password
|
||||
|
||||
@property
|
||||
def _token_path(self):
|
||||
def _token_path(self) -> str:
|
||||
return self.args.get("token_path", "snoo_token.json")
|
||||
|
||||
@property
|
||||
def _polling(self) -> bool:
|
||||
return self.args.get("polling", False)
|
||||
|
Loading…
Reference in New Issue
Block a user