341 lines
11 KiB
Python
341 lines
11 KiB
Python
|
import json
|
||
|
import logging
|
||
|
from collections.abc import Callable
|
||
|
from dataclasses import dataclass
|
||
|
from typing import Any
|
||
|
|
||
|
import mqttapi as mqtt
|
||
|
from pubnub.enums import PNReconnectionPolicy
|
||
|
from pysnoo import ActivityState
|
||
|
from pysnoo import Device
|
||
|
from pysnoo import Snoo
|
||
|
from pysnoo import SnooAuthSession
|
||
|
from pysnoo import SnooPubNub as SnooPubNubBase
|
||
|
from pysnoo.models import EventType
|
||
|
|
||
|
|
||
|
# HACK: Avoid an error when a value has no matching EventType member.
# Enum consults ``_missing_`` when a value lookup fails; this replacement
# answers every such lookup with EventType.ACTIVITY.
def _fallback_event_type(_value):
    return EventType.ACTIVITY


EventType._missing_ = _fallback_event_type

# Set global log level to debug (this is applied to the root logger).
logging.root.setLevel(logging.DEBUG)

# Use polling (True) or listening through a PubNub subscription (False).
POLLING = True
|
||
|
|
||
|
|
||
|
# HACK: Subclass to modify original pubnub policy
|
||
|
class SnooPubNub(SnooPubNubBase):
    """Snoo PubNub client that uses an exponential reconnect policy.

    Behaves like the base class except for the reconnect policy set on the
    configuration object produced by ``_setup_pnconfig``.
    """

    @staticmethod
    def _setup_pnconfig(access_token, uuid):
        """Build the base configuration and override its reconnect policy.

        Args:
            access_token: Passed through unchanged to the base implementation.
            uuid: Passed through unchanged to the base implementation.

        Returns:
            The configuration object created by the base class, with
            ``reconnect_policy`` set to ``PNReconnectionPolicy.EXPONENTIAL``.
        """
        config = SnooPubNubBase._setup_pnconfig(access_token, uuid)
        config.reconnect_policy = PNReconnectionPolicy.EXPONENTIAL
        return config
|
||
|
|
||
|
|
||
|
def _serialize_device_info(device: Device) -> dict[str, Any]:
    """Build the ``device`` section embedded in discovery payloads.

    Args:
        device: Object providing ``serial_number`` and ``baby`` attributes.

    Returns:
        A dict with ``identifiers`` (a one-element list holding the serial
        number) and ``name`` (the value of ``device.baby``).
    """
    identifiers = [device.serial_number]
    return dict(identifiers=identifiers, name=device.baby)
|
||
|
|
||
|
|
||
|
@dataclass
class SnooActivityEntity:
    """Describes one entity that is kept up to date from Activity messages."""

    # Name of the value this entity exposes; also used to derive the default
    # value template, the unique id and the topic names below.
    field: str
    # Component segment of the discovery topic (e.g. "sensor", "switch").
    component: str
    # Optional device class copied verbatim into the discovery payload.
    device_class: str | None = None
    # Optional display name; falls back to ``friendly_name`` when unset.
    name: str | None = None
    # Optional template; falls back to "{{ value_json.<field> }}" when unset.
    value_template: str | None = None
    # Optional coroutine invoked for messages on the command topic.
    command_callback: Callable | None = None

    @property
    def friendly_name(self) -> str:
        """Return ``field`` with underscores as spaces, capitalized."""
        words = self.field.replace("_", " ")
        return words.capitalize()

    @property
    def config_topic(self) -> str:
        """Return the discovery config topic for this entity.

        The field ``"state"`` is published under the name ``"snoo_state"``.
        Note that the topic does not contain a device serial number.
        """
        if self.field == "state":
            entity_name = "snoo_state"
        else:
            entity_name = self.field
        return f"homeassistant/{self.component}/{entity_name}/config"

    def device_topic(self, device: Device) -> str:
        """Return the base topic shared by every entity of ``device``."""
        return f"homeassistant/snoo_{device.serial_number}"

    def command_topic(self, device: Device) -> str | None:
        """Return the command topic for switches, ``None`` for anything else."""
        if self.component == "switch":
            return f"{self.device_topic(device)}/{self.field}/set"
        return None

    def discovery_message(self, device: Device) -> dict[str, Any]:
        """Assemble the discovery payload for this entity on ``device``.

        Args:
            device: Device the entity belongs to.

        Returns:
            A dict holding the base topic (``"~"``), name, device class,
            state topic, unique id, device info and value template. Binary
            sensors and switches also get ``payload_on``/``payload_off``, and
            entities with a command topic get ``command_topic``.
        """
        template = self.value_template or "{{ value_json." + self.field + " }}"
        payload: dict[str, Any] = {
            "~": self.device_topic(device),
            "name": self.name or self.friendly_name,
            "device_class": self.device_class,
            "state_topic": "~/state",
            "unique_id": f"{device.serial_number}_{self.field}",
            "device": _serialize_device_info(device),
            "value_template": template,
        }

        if self.component in {"binary_sensor", "switch"}:
            # Use "True" and "False" for binary payloads. Safety classes are
            # actually True if unsafe, so the two payloads are flipped there.
            flipped = self.device_class == "safety"
            payload["payload_on"] = "False" if flipped else "True"
            payload["payload_off"] = "True" if flipped else "False"

        command_topic = self.command_topic(device)
        if command_topic:
            payload["command_topic"] = command_topic

        return payload
|
||
|
|
||
|
|
||
|
async def status_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]):
    """Handle a command for the ``status`` switch.

    Args:
        pubnub: Client used to publish the resulting command.
        data: Command data; its ``"state"`` entry selects the action.

    A state of ``"ON"`` publishes a start command; any other state publishes
    a go-to-state command with ``"ONLINE"``.
    """
    turn_on = data["state"] == "ON"
    if not turn_on:
        # NOTE(review): a plain string is passed here while the hold switch
        # passes the state object taken from history - confirm that
        # publish_goto_state accepts both.
        await pubnub.publish_goto_state("ONLINE")
        return

    await pubnub.publish_start()
|
||
|
|
||
|
|
||
|
async def hold_switch_callback(pubnub: SnooPubNub, data: dict[str, Any]):
    """Handle a command for the ``hold`` switch.

    The most recent activity state is fetched from history and, when its
    state reports an active level, that same state is re-published with the
    hold flag switched on or off.

    Args:
        pubnub: Client used to read history and publish the command.
        data: Command data; ``data["state"] == "ON"`` enables hold, any other
            value disables it.

    Returns:
        None. Nothing is published when history is empty or the current state
        is not an active level.
    """
    history = await pubnub.history(1)
    if not history:
        # No activity has been recorded yet, so there is no state to hold.
        # The other history consumers in this file iterate the result and
        # therefore also tolerate an empty list.
        return

    current_state = history[0].state_machine.state
    if not current_state.is_active_level():
        return

    await pubnub.publish_goto_state(current_state, data["state"] == "ON")
|
||
|
|
||
|
|
||
|
# Every entity announced through discovery for each device. The two switch
# entries carry the callback that handles their command topic.
ENTITIES = [
    SnooActivityEntity(
        field="left_safety_clip",
        component="binary_sensor",
        device_class="safety",
    ),
    SnooActivityEntity(
        field="right_safety_clip",
        component="binary_sensor",
        device_class="safety",
    ),
    SnooActivityEntity(
        field="weaning",
        component="binary_sensor",
        value_template="{{ value_json.state_machine.weaning }}",
    ),
    SnooActivityEntity(
        field="is_active_session",
        component="binary_sensor",
        device_class="running",
        name="Active Session",
        value_template="{{ value_json.state_machine.is_active_session }}",
    ),
    SnooActivityEntity(
        field="state",
        component="sensor",
        value_template="{{ value_json.state_machine.state }}",
    ),
    SnooActivityEntity(
        field="status",
        component="switch",
        value_template="{{ value_json.state_machine.state != 'ONLINE' }}",
        command_callback=status_switch_callback,
    ),
    SnooActivityEntity(
        field="hold",
        component="switch",
        value_template="{{ value_json.state_machine.hold }}",
        command_callback=hold_switch_callback,
    ),
]
|
||
|
|
||
|
|
||
|
def _get_token(token_file: str):
    """Read auth token from JSON file.

    Args:
        token_file: Path of the JSON file holding the token.

    Returns:
        The decoded JSON content, or ``None`` when the file does not exist
        or reading it raises ``ValueError`` (which includes invalid JSON).
    """
    try:
        with open(token_file) as token_fh:
            contents = token_fh.read()
        return json.loads(contents)
    except (FileNotFoundError, ValueError):
        # A missing or unreadable token is reported as "no token".
        return None
|
||
|
|
||
|
|
||
|
def _get_token_updater(token_file: str):
    """Create a callback that stores a token as JSON in ``token_file``.

    Args:
        token_file: Path of the file each call of the callback overwrites.

    Returns:
        A one-argument function that writes the given token to the file and
        returns ``None``.
    """

    def token_updater(token):
        # Mode "w" truncates the file, so it only ever holds the latest token.
        with open(token_file, mode="w") as token_fh:
            json.dump(token, fp=token_fh)

    return token_updater
|
||
|
|
||
|
|
||
|
def _device_to_pubnub(snoo: Snoo, device: Device) -> SnooPubNub:
    """Create the PubNub client for a single device.

    Args:
        snoo: API wrapper whose ``auth.access_token`` is handed to the client.
        device: Device whose serial number identifies the client.

    Returns:
        A ``SnooPubNub`` built from the access token, the device serial
        number and the id ``pn-pysnoo-<serial>``.
    """
    serial = device.serial_number
    # TODO: Add some way of passing the reconnect policy into pysnoo.
    # For now the SnooPubNub subclass in this file sets it instead of
    # patching the configuration of the created client here.
    return SnooPubNub(snoo.auth.access_token, serial, f"pn-pysnoo-{serial}")
|
||
|
|
||
|
|
||
|
class SnooMQTT(mqtt.Mqtt):
    """MQTT app that publishes Snoo activity for Home Assistant discovery.

    On initialization the app authenticates, fetches the account's devices,
    publishes a discovery message for every entry of ``ENTITIES`` and then
    either polls PubNub history or subscribes to live updates, depending on
    ``POLLING``. Settings are read from ``self.args``: ``username``,
    ``password`` and the optional ``token_path``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Prepare empty state and defer everything else to the base class."""
        self._snoo: Snoo | None = None
        self._devices: list[Device] = []
        self._pubnubs: dict[str, SnooPubNub] = {}
        # Command topics that already have a listener. Discovery is published
        # again on every birth message and must not add duplicate listeners.
        self._command_topics: set[str] = set()
        super().__init__(*args, **kwargs)

    async def initialize(self) -> None:
        """Authenticate, publish discovery and start receiving updates.

        Raises:
            ValueError: If the account has no devices, or if a required
                credential is missing when a new token has to be fetched.
        """
        token_path = self._token_path
        token = _get_token(token_path)
        token_updater = _get_token_updater(token_path)

        async with SnooAuthSession(token, token_updater) as auth:
            if not auth.authorized:
                new_token = await auth.fetch_token(self._username, self._password)
                token_updater(new_token)
                self.log("got inital token")

            self._snoo = Snoo(auth)

            self._devices = await self._snoo.get_devices()
            if not self._devices:
                raise ValueError("No Snoo devices connected to account")

            # Publish discovery information
            self._publish_discovery(self._devices)

            # Subscribe to updates
            for device in self._devices:
                pubnub = _device_to_pubnub(self._snoo, device)
                self._pubnubs[device.serial_number] = pubnub

                if POLLING:
                    await self._schedule_updates(device, pubnub)
                else:
                    await self._subscribe_and_listen(device, pubnub)

            # Listen for home assistant status to republish discovery
            self.listen_event(
                self.birth_callback,
                "MQTT_MESSAGE",
                topic="homeassistant/status",
            )

    async def birth_callback(self, event_name, data, *args) -> None:
        """Callback listening for hass status messages.

        Republishes discovery and the latest known status of every device.
        """
        self.log("Read hass status event: %s, %s", event_name, data)

        self._publish_discovery(self._devices)
        for device in self._devices:
            pubnub = self._pubnubs[device.serial_number]
            publish_state = self._create_activity_callback(device)
            for activity_state in await pubnub.history(1):
                publish_state(activity_state)

    async def _schedule_updates(self, device: Device, pubnub: SnooPubNub):
        """Schedule periodic updates read from pubnub history.

        The newest history entry is published right away and then again
        every 30 seconds.
        """
        cb = self._create_activity_callback(device)

        async def poll_history(*args):
            for activity_state in await pubnub.history(1):
                cb(activity_state)

        self.run_every(poll_history, "now", 30)

    async def _subscribe_and_listen(self, device: Device, pubnub: SnooPubNub):
        """Subscribe to pubnub activity and listen for new events."""
        cb = self._create_activity_callback(device)
        pubnub.add_listener(cb)

        # Publish first status message
        for activity_state in await pubnub.history(1):
            cb(activity_state)

        await pubnub.subscribe_and_await_connect()

    def _create_activity_callback(self, device: Device):
        """Create a callback that publishes activity states of ``device``.

        Returns:
            A function taking an ``ActivityState`` and publishing it to the
            device's state topic.
        """
        state_topic = f"homeassistant/snoo_{device.serial_number}/state"

        def activity_callback(activity_state: ActivityState):
            self._publish_object(state_topic, activity_state)

        return activity_callback

    def _publish_discovery(self, devices: list[Device]):
        """Publish discovery messages for the provided devices.

        Entities that accept commands also get a listener on their command
        topic; see ``_listen_for_commands``.
        """
        for device in devices:
            for entity in ENTITIES:
                self._publish_object(
                    entity.config_topic,
                    entity.discovery_message(device),
                )
                self._listen_for_commands(device, entity)

    def _listen_for_commands(
        self, device: Device, entity: SnooActivityEntity
    ) -> None:
        """Register the command listener of ``entity`` for ``device`` once.

        Nothing happens when the entity has no command callback, has no
        command topic for this device, or the topic is already listened to.
        """
        command_callback = entity.command_callback
        # command_topic is a method: it has to be called to learn whether the
        # entity accepts commands at all.
        topic = entity.command_topic(device)
        if command_callback is None or topic is None:
            return
        if topic in self._command_topics:
            return
        self._command_topics.add(topic)

        # Defined inside this helper so that every listener keeps its own
        # device and callback instead of sharing the caller's loop variables.
        # It is a plain function, hence no ``self`` parameter.
        async def on_command(event_name, data, *args):
            await command_callback(
                self._pubnubs[device.serial_number],
                self._decode_command(data),
            )

        self.listen_event(on_command, "MQTT_MESSAGE", topic=topic)

    @staticmethod
    def _decode_command(data: Any) -> Any:
        """Turn the data of a command event into the decoded command.

        A mapping containing ``"payload"`` is reduced to that payload; string
        or bytes payloads are decoded as JSON; anything else is returned
        unchanged.

        Raises:
            ValueError: If a string payload is not valid JSON.
        """
        # NOTE(review): the event data is presumably a mapping with a
        # "payload" key - confirm against the mqttapi documentation.
        # NOTE(review): the switch callbacks read data["state"] == "ON" while
        # discovery advertises "True"/"False" payloads - confirm the format
        # that is actually sent to the command topic.
        payload = data.get("payload", data) if isinstance(data, dict) else data
        if isinstance(payload, (str, bytes, bytearray)):
            return json.loads(payload)
        return payload

    def _publish_object(self, topic, payload_object):
        """Publish an object, serializing its payload as JSON.

        Objects providing a callable ``to_dict`` are converted through it;
        everything else is serialized as is.
        """
        # Look the converter up instead of catching AttributeError around the
        # call, so errors raised inside to_dict() are not mistaken for a
        # missing method.
        to_dict = getattr(payload_object, "to_dict", None)
        serializable = to_dict() if callable(to_dict) else payload_object
        payload = json.dumps(serializable)

        self.log("topic: %s payload: %s", topic, payload)
        self.mqtt_publish(topic, payload=payload)

    @property
    def _username(self):
        """Return the configured username.

        Raises:
            ValueError: If ``username`` is missing or empty in the app args.
        """
        username = self.args.get("username")
        if not username:
            raise ValueError("Must provide a username")
        return username

    @property
    def _password(self):
        """Return the configured password.

        Raises:
            ValueError: If ``password`` is missing or empty in the app args.
        """
        password = self.args.get("password")
        if not password:
            raise ValueError("Must provide a password")
        return password

    @property
    def _token_path(self):
        """Return the token file path, defaulting to ``snoo_token.json``."""
        return self.args.get("token_path", "snoo_token.json")
|