#! /usr/bin/env python3
from argparse import ArgumentParser
from os import environ
from typing import Any
from typing import cast

import requests


NOMAD_ADDR = environ.get("NOMAD_ADDR", "http://127.0.0.1:4646")
NOMAD_TOKEN = environ.get("NOMAD_TOKEN")


def nomad_req(
    *path: str,
    params: dict[str, Any] | None = None,
    method: str = "GET",
    timeout: float | tuple[float, float] | None = 30.0,
) -> list[dict[str, Any]] | dict[str, Any] | str:
    """Send a request to the Nomad HTTP API and return the decoded body.

    Args:
        *path: Path segments joined with ``/`` and appended to
            ``{NOMAD_ADDR}/v1/``.
        params: Optional query-string parameters.
        method: HTTP method to use.
        timeout: Seconds to wait for the server, passed straight to
            ``requests``. ``None`` waits indefinitely.

    Returns:
        The parsed JSON body, or the raw response text when the body is not
        valid JSON.

    Raises:
        requests.HTTPError: If the response status is 4xx or 5xx.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    headers: dict[str, str] = {}
    if NOMAD_TOKEN:
        headers["X-Nomad-Token"] = NOMAD_TOKEN

    response = requests.request(
        method,
        f"{NOMAD_ADDR}/v1/{'/'.join(path)}",
        params=params,
        headers=headers,
        # requests never times out on its own; without this an unresponsive
        # agent would block the whole check indefinitely.
        timeout=timeout,
    )
    response.raise_for_status()

    try:
        return response.json()
    except requests.exceptions.JSONDecodeError:
        # Fall back to the raw text when the body is not valid JSON.
        return response.text


def extract_job_services(job: dict[str, Any]) -> dict[str, str]:
    """Map each service name declared in *job* to the name of its task group.

    Services are collected both from the group level and from every task in
    the group. A missing or ``None`` ``"Services"`` entry counts as no
    services. When a name is declared more than once, the last declaration
    encountered determines the group it maps to.
    """
    group_by_service: dict[str, str] = {}

    for task_group in job["TaskGroups"]:
        group_level = task_group.get("Services") or []
        group_by_service.update(
            (svc["Name"], task_group["Name"]) for svc in group_level
        )

        for task in task_group["Tasks"]:
            task_level = task.get("Services") or []
            group_by_service.update(
                (svc["Name"], task_group["Name"]) for svc in task_level
            )

    return group_by_service

# Process exit status: switched to 1 as soon as any expected service is missing.
exit_code = 0
parser = ArgumentParser(
    description="Checks for missing services and optionally restarts their allocs.",
)
parser.add_argument(
    "-r",
    "--restart",
    action="store_true",
    help="Restart allocs for missing services",
)
args = parser.parse_args()

for job in nomad_req("jobs"):
    job = cast(dict[str, Any], job)

    # Batch-style jobs are not checked for services.
    if job["Type"] in ("batch", "sysbatch"):
        continue

    # Jobs that are not running are reported but not checked further.
    if job["Status"] != "running":
        print(f"WARNING: job {job['Name']} is {job['Status']}")
        continue

    job_detail = nomad_req("job", job["ID"])
    job_detail = cast(dict[str, Any], job_detail)

    # Service name -> task group name, taken from the job definition.
    expected_services = extract_job_services(job_detail)

    # Names of the services Nomad currently reports for this job.
    found_services: set[str] = set()
    for service in nomad_req("job", job_detail["ID"], "services"):
        service = cast(dict[str, Any], service)
        found_services.add(service["ServiceName"])

    missing_services = set(expected_services) - found_services
    restart_groups: set[str] = set()
    # Sorted so the report order is stable between runs.
    for missing_service in sorted(missing_services):
        print(f"ERROR: Missing service {missing_service} for job {job_detail['Name']}")
        exit_code = 1

        # Remember the group that declares the missing service.
        restart_groups.add(expected_services[missing_service])

    if not restart_groups or not args.restart:
        continue

    # Get running allocs for the groups that are missing services.
    restart_allocs: set[str] = set()
    for allocation in nomad_req("job", job_detail["ID"], "allocations"):
        allocation = cast(dict[str, Any], allocation)
        if (
            allocation["ClientStatus"] == "running"
            and allocation["TaskGroup"] in restart_groups
        ):
            restart_allocs.add(allocation["ID"])

    # Restart allocs associated with missing services.
    for allocation_id in sorted(restart_allocs):
        print(f"INFO: Restarting allocation {allocation_id}")
        # Nomad's restart endpoint only accepts POST/PUT; the default GET is
        # rejected with 405 Method Not Allowed.
        nomad_req("client", "allocation", allocation_id, "restart", method="POST")


# ``exit`` is a helper injected by the ``site`` module and is not guaranteed to
# exist; raising SystemExit is the portable way to set the exit status.
raise SystemExit(exit_code)