"""Host client main module. Connects to the router via WebSocket, receives forwarded messages, runs the local mailboy LLM, and sends responses back. """ from __future__ import annotations import asyncio import logging import secrets import time from typing import Any, Optional import websockets from agent.manager import manager from agent.scheduler import scheduler from agent.task_runner import task_runner from host_client.config import HostConfig, get_host_config from orchestrator.agent import agent from orchestrator.tools import set_current_user, set_current_chat from shared import ( RegisterMessage, ForwardRequest, ForwardResponse, TaskComplete, Heartbeat, NodeStatus, encode, decode, ) logger = logging.getLogger(__name__) class NodeClient: """WebSocket client that connects to the router and handles messages.""" def __init__(self, config: HostConfig): self.config = config self.ws: Any = None self._running = False self._last_heartbeat = time.time() self._reconnect_delay = 1.0 async def connect(self) -> bool: """Connect to the router WebSocket.""" headers = {} if self.config.router_secret: headers["Authorization"] = f"Bearer {self.config.router_secret}" try: self.ws = await websockets.connect( self.config.router_url, extra_headers=headers, ping_interval=30, ping_timeout=10, ) logger.info("Connected to router: %s", self.config.router_url) self._reconnect_delay = 1.0 return True except Exception as e: logger.error("Failed to connect to router: %s", e) return False async def register(self) -> bool: """Send registration message to the router.""" if not self.ws: return False msg = RegisterMessage( node_id=self.config.node_id, display_name=self.config.display_name, serves_users=self.config.serves_users, working_dir=self.config.working_dir, capabilities=self.config.capabilities, ) try: await self.ws.send(encode(msg)) logger.info("Sent registration for node: %s", self.config.node_id) return True except Exception as e: logger.error("Failed to send registration: %s", e) return False async def handle_forward(self, request: ForwardRequest) -> None: """Handle a forwarded message from the router.""" logger.info("Received forward request %s from user %s", request.id, request.user_id) set_current_user(request.user_id) set_current_chat(request.chat_id) try: reply = await agent.run(request.user_id, request.text) response = ForwardResponse( id=request.id, reply=reply, error="", ) except Exception as e: logger.exception("Error processing forward request %s", request.id) response = ForwardResponse( id=request.id, reply="", error=str(e), ) if self.ws: try: await self.ws.send(encode(response)) except Exception as e: logger.error("Failed to send response: %s", e) async def send_heartbeat(self) -> None: """Send a ping heartbeat to the router.""" if self.ws: try: await self.ws.send(encode(Heartbeat(type="ping"))) self._last_heartbeat = time.time() except Exception as e: logger.error("Failed to send heartbeat: %s", e) async def send_status(self) -> None: """Send node status update to the router.""" if not self.ws: return sessions = manager.list_sessions() active_sessions = [ {"conv_id": s["conv_id"], "working_dir": s["cwd"]} for s in sessions ] status = NodeStatus( node_id=self.config.node_id, sessions=len(sessions), active_sessions=active_sessions, ) try: await self.ws.send(encode(status)) except Exception as e: logger.error("Failed to send status: %s", e) async def handle_message(self, data: str) -> None: """Handle an incoming message from the router.""" try: msg = decode(data) except Exception as e: logger.error("Failed to decode message: %s", e) return if isinstance(msg, ForwardRequest): await self.handle_forward(msg) elif isinstance(msg, Heartbeat): if msg.type == "ping": if self.ws: try: await self.ws.send(encode(Heartbeat(type="pong"))) except Exception as e: logger.error("Failed to send pong: %s", e) elif msg.type == "pong": self._last_heartbeat = time.time() else: logger.debug("Received message type: %s", msg.type) async def receive_loop(self) -> None: """Main receive loop for incoming messages.""" if not self.ws: return try: async for data in self.ws: await self.handle_message(data) except websockets.ConnectionClosed as e: logger.warning("Connection closed: %s", e) except Exception as e: logger.exception("Error in receive loop: %s", e) async def heartbeat_loop(self) -> None: """Periodic heartbeat loop.""" while self._running: await asyncio.sleep(30) if self.ws and self.ws.open: await self.send_heartbeat() async def status_loop(self) -> None: """Periodic status update loop.""" while self._running: await asyncio.sleep(60) if self.ws and self.ws.open: await self.send_status() async def run(self) -> None: """Main run loop with reconnection.""" self._running = True await manager.start() await scheduler.start() task_runner.set_notification_handler(self._send_task_complete) while self._running: if await self.connect(): if await self.register(): try: await asyncio.gather( self.receive_loop(), self.heartbeat_loop(), self.status_loop(), ) except Exception: pass if self._running: logger.info("Reconnecting in %.1f seconds...", self._reconnect_delay) await asyncio.sleep(self._reconnect_delay) self._reconnect_delay = min(self._reconnect_delay * 2, 60) async def _send_task_complete(self, task) -> None: """Send TaskComplete notification to router.""" if not self.ws: return from shared import TaskComplete, encode msg = TaskComplete( task_id=task.task_id, user_id=task.user_id or "", chat_id=task.notify_chat_id or "", result=task.result or task.error or "", ) try: await self.ws.send(encode(msg)) logger.info("Sent TaskComplete for task %s", task.task_id) except Exception as e: logger.error("Failed to send TaskComplete: %s", e) async def stop(self) -> None: """Stop the client.""" self._running = False if self.ws: await self.ws.close() await manager.stop() await scheduler.stop() logger.info("Node client stopped") @classmethod def from_keyring(cls, router_url: Optional[str] = None, secret: Optional[str] = None) -> "NodeClient": """Create a client from keyring.yaml (for standalone mode).""" config = HostConfig.from_keyring() if router_url: config.router_url = router_url if secret: config.router_secret = secret return cls(config) async def main() -> None: """Entry point for standalone host client.""" logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) client = NodeClient(get_host_config()) try: await client.run() except KeyboardInterrupt: await client.stop() if __name__ == "__main__": asyncio.run(main())