WIP: Hackier overriding for connections
This commit is contained in:
parent
4af9ffc0f5
commit
39157aa368
@ -1,27 +1,32 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from asyncio import Future
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any, Never, NoReturn
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import mqttapi as mqtt
|
import mqttapi as mqtt
|
||||||
from pubnub.enums import PNReconnectionPolicy
|
from pubnub.enums import PNReconnectionPolicy
|
||||||
from pysnoo import ActivityState
|
from pysnoo import ActivityState
|
||||||
from pysnoo import Device
|
from pysnoo import Device
|
||||||
|
from pysnoo import SessionLevel
|
||||||
from pysnoo import Snoo
|
from pysnoo import Snoo
|
||||||
from pysnoo import SnooAuthSession
|
from pysnoo import SnooAuthSession
|
||||||
from pysnoo import SnooPubNub as SnooPubNubBase
|
from pysnoo import SnooPubNub as SnooPubNubBase
|
||||||
from pysnoo.models import EventType
|
from pysnoo.models import EventType
|
||||||
|
|
||||||
|
|
||||||
# HACK: Avoid error on missing EventType
|
|
||||||
EventType._missing_ = lambda x: EventType.ACTIVITY
|
|
||||||
|
|
||||||
# Set global log level to debug
|
# Set global log level to debug
|
||||||
logging.getLogger().setLevel(logging.DEBUG)
|
logging.getLogger().setLevel(logging.DEBUG)
|
||||||
|
|
||||||
# Use polling or listening
|
|
||||||
POLLING = True
|
# HACK: Avoid error on missing EventType
|
||||||
|
EventType._missing_ = lambda value: EventType.ACTIVITY
|
||||||
|
|
||||||
|
# TODO: Catch and handle this:
|
||||||
|
# Error in Snoo PubNub Listener of Category: 3
|
||||||
|
# Exception in subscribe loop: HTTP Client Error (403): {'message': 'Forbidden', 'payload': {'channels': ['ActivityState.7460194284235017']}, 'error': True, 'service': 'Access Manager', 'status': 403}
|
||||||
|
|
||||||
|
|
||||||
# HACK: Subclass to modify original pubnub policy
|
# HACK: Subclass to modify original pubnub policy
|
||||||
@ -35,6 +40,14 @@ class SnooPubNub(SnooPubNubBase):
|
|||||||
pnconfig.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL
|
pnconfig.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL
|
||||||
return pnconfig
|
return pnconfig
|
||||||
|
|
||||||
|
async def await_disconnect(self):
|
||||||
|
"""Await disconnect"""
|
||||||
|
if self._listener.is_connected():
|
||||||
|
await self._listener.wait_for_disconnect()
|
||||||
|
|
||||||
|
def set_token(self, token):
|
||||||
|
return self._pubnub.set_token(token)
|
||||||
|
|
||||||
|
|
||||||
def _serialize_device_info(device: Device) -> dict[str, Any]:
|
def _serialize_device_info(device: Device) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
@ -110,7 +123,7 @@ async def status_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]):
|
|||||||
if data["state"] == "ON":
|
if data["state"] == "ON":
|
||||||
await pubnub.publish_start()
|
await pubnub.publish_start()
|
||||||
else:
|
else:
|
||||||
await pubnub.publish_goto_state("ONLINE")
|
await pubnub.publish_goto_state(SessionLevel.ONLINE)
|
||||||
|
|
||||||
|
|
||||||
async def hold_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]):
|
async def hold_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]):
|
||||||
@ -167,31 +180,39 @@ ENTITIES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _get_token(token_file: str):
|
class Token:
|
||||||
"""Read auth token from JSON file."""
|
def __init__(self, file_path: str) -> None:
|
||||||
try:
|
self._path = file_path
|
||||||
with open(token_file) as infile:
|
|
||||||
return json.load(infile)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
def value(self) -> dict | None:
|
||||||
|
"""Read auth token from JSON file."""
|
||||||
|
try:
|
||||||
|
with open(self._path) as infile:
|
||||||
|
return json.load(infile)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
def _get_token_updater(token_file: str):
|
return None
|
||||||
def token_updater(token):
|
|
||||||
with open(token_file, "w") as outfile:
|
|
||||||
json.dump(token, outfile)
|
|
||||||
|
|
||||||
return token_updater
|
@property
|
||||||
|
def updater(self) -> Callable[[dict], None]:
|
||||||
|
def token_updater(token: dict):
|
||||||
|
with open(self._path, "w") as outfile:
|
||||||
|
json.dump(token, outfile)
|
||||||
|
|
||||||
|
return token_updater
|
||||||
|
|
||||||
|
|
||||||
def _device_to_pubnub(snoo: Snoo, device: Device) -> SnooPubNub:
|
def _device_to_pubnub(snoo: Snoo, device: Device) -> SnooPubNub:
|
||||||
|
assert snoo.auth.access_token
|
||||||
pubnub = SnooPubNub(
|
pubnub = SnooPubNub(
|
||||||
snoo.auth.access_token,
|
snoo.auth.access_token,
|
||||||
device.serial_number,
|
device.serial_number,
|
||||||
f"pn-pysnoo-{device.serial_number}",
|
f"pn-pysnoo-{device.serial_number}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Add some way of passing this config into pysnoo
|
# TODO: Add some way of passing this config into pysnoo
|
||||||
# Another temp option to subclass SnooPubNub
|
# Another temp option to subclass SnooPubNub
|
||||||
# pubnub._pubnub.config.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL
|
# pubnub._pubnub.config.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL
|
||||||
@ -200,69 +221,114 @@ def _device_to_pubnub(snoo: Snoo, device: Device) -> SnooPubNub:
|
|||||||
|
|
||||||
class SnooMQTT(mqtt.Mqtt):
|
class SnooMQTT(mqtt.Mqtt):
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
self._snoo: Device | None = None
|
|
||||||
self._devices: list[Device] = []
|
self._devices: list[Device] = []
|
||||||
self._pubnubs: dict[str, SnooPubNub] = {}
|
self._pubnubs: dict[str, SnooPubNub] = {}
|
||||||
|
self._session: SnooAuthSession | None = None
|
||||||
|
self._snoo_subscription: Future | None = None
|
||||||
|
self._snoo_poll_timer_handle: str|None = None
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
token_path = self._token_path
|
self._session = await self.authorize_session()
|
||||||
token = _get_token(token_path)
|
snoo = Snoo(self._session)
|
||||||
token_updater = _get_token_updater(token_path)
|
|
||||||
|
|
||||||
async with SnooAuthSession(token, token_updater) as auth:
|
# Get devices to be monitored
|
||||||
if not auth.authorized:
|
self._devices = await snoo.get_devices()
|
||||||
new_token = await auth.fetch_token(self._username, self._password)
|
|
||||||
token_updater(new_token)
|
|
||||||
self.log("got inital token")
|
|
||||||
|
|
||||||
self._snoo = Snoo(auth)
|
|
||||||
|
|
||||||
self._devices = await self._snoo.get_devices()
|
|
||||||
if not self._devices:
|
if not self._devices:
|
||||||
raise ValueError("No Snoo devices connected to account")
|
raise ValueError("No Snoo devices connected to account")
|
||||||
|
|
||||||
# Publish discovery information
|
# Publish discovery information
|
||||||
self._publish_discovery(self._devices)
|
await self._publish_discovery(self._devices)
|
||||||
|
|
||||||
# Subscribe to updates
|
|
||||||
for device in self._devices:
|
|
||||||
pubnub = _device_to_pubnub(self._snoo, device)
|
|
||||||
self._pubnubs[device.serial_number] = pubnub
|
|
||||||
|
|
||||||
if POLLING:
|
|
||||||
await self._schedule_updates(device, pubnub)
|
|
||||||
else:
|
|
||||||
await self._subscribe_and_listen(device, pubnub)
|
|
||||||
|
|
||||||
# Listen for home assistant status to republish discovery
|
# Listen for home assistant status to republish discovery
|
||||||
self.listen_event(
|
self._birth_listener_handle = await self.listen_event(
|
||||||
self.birth_callback,
|
self.birth_callback,
|
||||||
"MQTT_MESSAGE",
|
"MQTT_MESSAGE",
|
||||||
topic="homeassistant/status",
|
topic="homeassistant/status",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await self.snoo_subscribe(snoo)
|
||||||
|
|
||||||
|
async def terminate(self) -> None:
|
||||||
|
# Stop listening to birth events
|
||||||
|
if self._birth_listener_handle:
|
||||||
|
await self.cancel_listen_event(self._birth_listener_handle)
|
||||||
|
|
||||||
|
# Stop polling for updates
|
||||||
|
if self._snoo_poll_timer_handle:
|
||||||
|
await self.cancel_timer(self._snoo_poll_timer_handle)
|
||||||
|
|
||||||
|
# Stop listening task
|
||||||
|
if self._snoo_subscription:
|
||||||
|
self._snoo_subscription.cancel()
|
||||||
|
|
||||||
|
# Close the session
|
||||||
|
if self._session:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
async def reinit(self) -> None:
|
||||||
|
"""Reinitialize the app by terminating everything and starting up again."""
|
||||||
|
await self.terminate()
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
async def authorize_session(self) -> SnooAuthSession:
|
||||||
|
# Get token and updator
|
||||||
|
token = Token(self._token_path)
|
||||||
|
|
||||||
|
# Create and pre-authorize session
|
||||||
|
session = SnooAuthSession(token.value() or {}, token.updater)
|
||||||
|
if not session.authorized:
|
||||||
|
new_token = await session.fetch_token(self._username, self._password)
|
||||||
|
token.updater(new_token)
|
||||||
|
self.log("got inital token")
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
async def birth_callback(self, event_name, data, *args) -> None:
|
async def birth_callback(self, event_name, data, *args) -> None:
|
||||||
"""Callback listening for hass status messages.
|
"""Callback listening for hass status messages.
|
||||||
|
|
||||||
Should republish discovery and initial status for every device."""
|
Should republish discovery and initial status for every device."""
|
||||||
self.log("Read hass status event: %s, %s", event_name, data)
|
self.log("Read hass status event: %s, %s", event_name, data)
|
||||||
|
|
||||||
self._publish_discovery(self._devices)
|
await self._publish_discovery(self._devices)
|
||||||
for device in self._devices:
|
for device in self._devices:
|
||||||
pubnub = self._pubnubs[device.serial_number]
|
pubnub = self._pubnubs[device.serial_number]
|
||||||
for activity_state in await pubnub.history(1):
|
for activity_state in await pubnub.history(1):
|
||||||
self._create_activity_callback(device)(activity_state)
|
self._create_activity_callback(device)(activity_state)
|
||||||
|
|
||||||
async def _schedule_updates(self, device: Device, pubnub: SnooPubNub):
|
async def snoo_subscribe(self, snoo: Snoo) -> None:
|
||||||
|
"""Creates AppDaemon subscription task to listen to Snoo updates."""
|
||||||
|
# Subscribe and start listening to Snoo updates
|
||||||
|
self._snoo_subscription = self.create_task(self.snoo_real_subscribe(snoo))
|
||||||
|
|
||||||
|
async def snoo_real_subscribe(self, snoo: Snoo):
|
||||||
|
"""Coroutine that subscribes to updates from Snoo.
|
||||||
|
|
||||||
|
Never returns."""
|
||||||
|
# This should only run after initialize, so this should not be None
|
||||||
|
# Subscribe to updates
|
||||||
|
for device in self._devices:
|
||||||
|
pubnub = _device_to_pubnub(snoo, device)
|
||||||
|
self._pubnubs[device.serial_number] = pubnub
|
||||||
|
|
||||||
|
if self._polling:
|
||||||
|
self.log("Scheduling updates using polling")
|
||||||
|
self._snoo_poll_timer_handle = await self._schedule_updates(device)
|
||||||
|
else:
|
||||||
|
self.log("Scheduling updates listening")
|
||||||
|
await self._subscribe_and_listen(device, pubnub)
|
||||||
|
|
||||||
|
async def _schedule_updates(self, device: Device) -> str:
|
||||||
"""Schedules from pubnub history periodic updates."""
|
"""Schedules from pubnub history periodic updates."""
|
||||||
cb = self._create_activity_callback(device)
|
cb = self._create_activity_callback(device)
|
||||||
|
|
||||||
async def poll_history(*args):
|
async def poll_history(*args):
|
||||||
for activity_state in await pubnub.history(1):
|
# TODO: Maybe try/except here for auth errors and then reconnect self.terminate() self.initialize()
|
||||||
|
for activity_state in await self._pubnubs[device.serial_number].history(1):
|
||||||
cb(activity_state)
|
cb(activity_state)
|
||||||
|
|
||||||
self.run_every(poll_history, "now", 30)
|
return await self.run_every(poll_history, "now", 30)
|
||||||
|
|
||||||
async def _subscribe_and_listen(self, device: Device, pubnub: SnooPubNub):
|
async def _subscribe_and_listen(self, device: Device, pubnub: SnooPubNub):
|
||||||
"""Subscribes to pubnub activity and listens to new events."""
|
"""Subscribes to pubnub activity and listens to new events."""
|
||||||
@ -273,7 +339,29 @@ class SnooMQTT(mqtt.Mqtt):
|
|||||||
for activity_state in await pubnub.history(1):
|
for activity_state in await pubnub.history(1):
|
||||||
cb(activity_state)
|
cb(activity_state)
|
||||||
|
|
||||||
|
# TODO: Maybe try/except here for auth errors and then reconnect self.terminate() self.initialize()
|
||||||
await pubnub.subscribe_and_await_connect()
|
await pubnub.subscribe_and_await_connect()
|
||||||
|
self.log("Done subscribing and listening...")
|
||||||
|
await pubnub.await_disconnect()
|
||||||
|
self.log("Disconnected! Maybe re-initialize here")
|
||||||
|
# self.create_task(self.reinit())
|
||||||
|
|
||||||
|
async def _disconnect_listener(self):
|
||||||
|
# Wait for everything to disconnect
|
||||||
|
for pubnub in self._pubnubs.values():
|
||||||
|
await pubnub.await_disconnect()
|
||||||
|
|
||||||
|
self.log("Disconnected! Maybe re-initialize here")
|
||||||
|
|
||||||
|
# Refresh token
|
||||||
|
assert self._session and self._session.auto_refresh_url and self._session.token_updater
|
||||||
|
token = await self._session.refresh_token(self._session.auto_refresh_url)
|
||||||
|
self._session.token_updater(token)
|
||||||
|
|
||||||
|
for pubnub in self._pubnubs.values():
|
||||||
|
pubnub.set_token(token)
|
||||||
|
|
||||||
|
# self.create_task(self.reinit())
|
||||||
|
|
||||||
def _create_activity_callback(self, device: Device):
|
def _create_activity_callback(self, device: Device):
|
||||||
"""Creates an activity callback for a given device."""
|
"""Creates an activity callback for a given device."""
|
||||||
@ -286,7 +374,7 @@ class SnooMQTT(mqtt.Mqtt):
|
|||||||
|
|
||||||
return activity_callback
|
return activity_callback
|
||||||
|
|
||||||
def _publish_discovery(self, devices: list[Device]):
|
async def _publish_discovery(self, devices: list[Device]):
|
||||||
"""Publishes discovery messages for the provided devices."""
|
"""Publishes discovery messages for the provided devices."""
|
||||||
for device in devices:
|
for device in devices:
|
||||||
for entity in ENTITIES:
|
for entity in ENTITIES:
|
||||||
@ -321,20 +409,26 @@ class SnooMQTT(mqtt.Mqtt):
|
|||||||
self.log("topic: %s payload: %s", topic, payload)
|
self.log("topic: %s payload: %s", topic, payload)
|
||||||
self.mqtt_publish(topic, payload=payload)
|
self.mqtt_publish(topic, payload=payload)
|
||||||
|
|
||||||
|
# App arguments
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _username(self):
|
def _username(self) -> str:
|
||||||
username = self.args.get("username")
|
username = self.args.get("username")
|
||||||
if not username:
|
if not username:
|
||||||
raise ValueError("Must provide a username")
|
raise ValueError("Must provide a username")
|
||||||
return username
|
return username
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _password(self):
|
def _password(self) -> str:
|
||||||
password = self.args.get("password")
|
password = self.args.get("password")
|
||||||
if not password:
|
if not password:
|
||||||
raise ValueError("Must provide a password")
|
raise ValueError("Must provide a password")
|
||||||
return password
|
return password
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _token_path(self):
|
def _token_path(self) -> str:
|
||||||
return self.args.get("token_path", "snoo_token.json")
|
return self.args.get("token_path", "snoo_token.json")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _polling(self) -> bool:
|
||||||
|
return self.args.get("polling", False)
|
||||||
|
Loading…
Reference in New Issue
Block a user