"""WebSocket endpoint for host client connections. Handles: - Connection authentication - Node registration - Message forwarding - Heartbeat """ from __future__ import annotations import asyncio import logging from typing import Optional from fastapi import WebSocket, WebSocketDisconnect, WebSocketException from router.nodes import get_node_registry from router.rpc import handle_task_complete from shared import ( RegisterMessage, ForwardRequest, ForwardResponse, TaskComplete, Heartbeat, NodeStatus, decode, encode, ) logger = logging.getLogger(__name__) async def ws_node_endpoint(websocket: WebSocket) -> None: """WebSocket endpoint for host client connections.""" await websocket.accept() registry = get_node_registry() secret = websocket.headers.get("authorization", "") if secret.startswith("Bearer "): secret = secret[7:] if not registry.validate_secret(secret): logger.warning("Invalid router secret, rejecting connection") await websocket.close(code=4001, reason="Invalid secret") return node_id: Optional[str] = None heartbeat_task: Optional[asyncio.Task] = None async def send_heartbeat(): """Send periodic pings to the host client.""" try: while True: await asyncio.sleep(30) try: await websocket.send_text(encode(Heartbeat(type="ping"))) except Exception: break except asyncio.CancelledError: pass try: async for data in websocket.iter_text(): try: msg = decode(data) except Exception as e: logger.error("Failed to decode message: %s", e) continue if isinstance(msg, RegisterMessage): node_id = msg.node_id await registry.register(websocket, msg) heartbeat_task = asyncio.create_task(send_heartbeat()) elif isinstance(msg, ForwardResponse): from router.rpc import resolve_response await resolve_response(msg) elif isinstance(msg, TaskComplete): await handle_task_complete(msg) elif isinstance(msg, Heartbeat): if msg.type == "pong" and node_id: await registry.update_heartbeat(node_id) elif isinstance(msg, NodeStatus): await registry.update_status(msg) else: logger.debug("Received unhandled message type: %s", type(msg).__name__) except WebSocketDisconnect: logger.info("WebSocket disconnected") except Exception as e: logger.exception("WebSocket error: %s", e) finally: if heartbeat_task: heartbeat_task.cancel() if node_id: await registry.unregister(node_id)