"""Host client main module.

Connects to the router via WebSocket, receives forwarded messages,
runs the local mailboy LLM, and sends responses back.
"""

from __future__ import annotations

import asyncio
import logging
import secrets
import time
from typing import Optional

import websockets
from websockets.client import WebSocketClientProtocol

from agent.manager import manager
from agent.scheduler import scheduler
from agent.task_runner import task_runner
from host_client.config import HostConfig, get_host_config
from orchestrator.agent import run as run_mailboy
from orchestrator.tools import set_current_user, set_current_chat
from shared import (
    RegisterMessage,
    ForwardRequest,
    ForwardResponse,
    TaskComplete,
    Heartbeat,
    NodeStatus,
    encode,
    decode,
)

logger = logging.getLogger(__name__)


class NodeClient:
    """WebSocket client that connects to the router and handles messages.

    Lifecycle: ``run()`` starts the local manager and scheduler, then loops
    connect -> register -> serve until ``stop()`` clears the running flag.
    Each failed or finished connection is followed by a sleep whose length
    doubles up to ``MAX_RECONNECT_DELAY`` and resets on a successful connect.
    """

    # Seconds between application-level "ping" heartbeats.
    HEARTBEAT_INTERVAL = 30
    # Seconds between NodeStatus updates.
    STATUS_INTERVAL = 60
    # Upper bound (seconds) for the exponential reconnect back-off.
    MAX_RECONNECT_DELAY = 60

    def __init__(self, config: HostConfig):
        """Store *config* and initialise connection state (no I/O)."""
        self.config = config
        self.ws: Optional[WebSocketClientProtocol] = None
        self._running = False
        # Updated when a ping is sent and when a pong is received.
        self._last_heartbeat = time.time()
        self._reconnect_delay = 1.0

    async def connect(self) -> bool:
        """Connect to the router WebSocket.

        Sends ``Authorization: Bearer <router_secret>`` when a secret is
        configured.

        Returns:
            True if the connection was established (the reconnect delay is
            reset to 1 second), False if connecting raised; the error is
            logged rather than propagated.
        """
        headers = {}
        if self.config.router_secret:
            headers["Authorization"] = f"Bearer {self.config.router_secret}"

        try:
            self.ws = await websockets.connect(
                self.config.router_url,
                extra_headers=headers,
                ping_interval=30,
                ping_timeout=10,
            )
        except Exception as e:
            logger.error("Failed to connect to router: %s", e)
            return False

        logger.info("Connected to router: %s", self.config.router_url)
        self._reconnect_delay = 1.0
        return True

    async def register(self) -> bool:
        """Send the registration message to the router.

        Returns:
            True if the message was sent, False if there is no connection
            or sending raised (the error is logged).
        """
        if not self.ws:
            return False

        msg = RegisterMessage(
            node_id=self.config.node_id,
            display_name=self.config.display_name,
            serves_users=self.config.serves_users,
            working_dir=self.config.working_dir,
            capabilities=self.config.capabilities,
        )
        try:
            await self.ws.send(encode(msg))
        except Exception as e:
            logger.error("Failed to send registration: %s", e)
            return False

        logger.info("Sent registration for node: %s", self.config.node_id)
        return True

    async def handle_forward(self, request: ForwardRequest) -> None:
        """Handle a forwarded message from the router.

        Sets the current user/chat, runs the mailboy on the request text and
        sends back a ``ForwardResponse`` carrying either the reply or the
        stringified exception. Never raises for processing or send errors;
        both are logged.
        """
        logger.info(
            "Received forward request %s from user %s",
            request.id,
            request.user_id,
        )

        set_current_user(request.user_id)
        set_current_chat(request.chat_id)

        try:
            reply = await run_mailboy(request.user_id, request.text)
            response = ForwardResponse(id=request.id, reply=reply, error="")
        except Exception as e:
            logger.exception("Error processing forward request %s", request.id)
            response = ForwardResponse(id=request.id, reply="", error=str(e))

        if self.ws:
            try:
                await self.ws.send(encode(response))
            except Exception as e:
                logger.error("Failed to send response: %s", e)

    async def send_heartbeat(self) -> None:
        """Send a ping heartbeat to the router and record the send time."""
        if not self.ws:
            return
        try:
            await self.ws.send(encode(Heartbeat(type="ping")))
            self._last_heartbeat = time.time()
        except Exception as e:
            logger.error("Failed to send heartbeat: %s", e)

    async def send_status(self) -> None:
        """Send a node status update (session count and details) to the router."""
        if not self.ws:
            return

        sessions = manager.list_sessions()
        active_sessions = [
            {"conv_id": s["conv_id"], "working_dir": s["working_dir"]}
            for s in sessions
        ]
        status = NodeStatus(
            node_id=self.config.node_id,
            sessions=len(sessions),
            active_sessions=active_sessions,
        )
        try:
            await self.ws.send(encode(status))
        except Exception as e:
            logger.error("Failed to send status: %s", e)

    async def handle_message(self, data: str) -> None:
        """Decode one incoming router message and dispatch it.

        ``ForwardRequest`` is processed via ``handle_forward``; a heartbeat
        "ping" is answered with "pong"; a "pong" refreshes the last-heartbeat
        timestamp. Anything else is logged at debug level. Undecodable data
        is logged and dropped.
        """
        try:
            msg = decode(data)
        except Exception as e:
            logger.error("Failed to decode message: %s", e)
            return

        if isinstance(msg, ForwardRequest):
            # NOTE(review): awaited inline, so the receive loop handles no
            # other router message until this request has been answered.
            await self.handle_forward(msg)
        elif isinstance(msg, Heartbeat):
            if msg.type == "ping":
                if self.ws:
                    try:
                        await self.ws.send(encode(Heartbeat(type="pong")))
                    except Exception as e:
                        logger.error("Failed to send pong: %s", e)
            elif msg.type == "pong":
                self._last_heartbeat = time.time()
        else:
            # Fall back to the class name so a message object without a
            # ``type`` attribute cannot raise here and end the receive loop.
            logger.debug(
                "Received message type: %s",
                getattr(msg, "type", type(msg).__name__),
            )

    async def receive_loop(self) -> None:
        """Main receive loop for incoming messages.

        Returns when the connection is closed or an unexpected error occurs;
        both are logged, neither is propagated.
        """
        if not self.ws:
            return

        try:
            async for data in self.ws:
                await self.handle_message(data)
        except websockets.ConnectionClosed as e:
            logger.warning("Connection closed: %s", e)
        except Exception as e:
            logger.exception("Error in receive loop: %s", e)

    async def heartbeat_loop(self) -> None:
        """Send a heartbeat every ``HEARTBEAT_INTERVAL`` seconds while running."""
        while self._running:
            await asyncio.sleep(self.HEARTBEAT_INTERVAL)
            if self.ws and self.ws.open:
                await self.send_heartbeat()

    async def status_loop(self) -> None:
        """Send a status update every ``STATUS_INTERVAL`` seconds while running."""
        while self._running:
            await asyncio.sleep(self.STATUS_INTERVAL)
            if self.ws and self.ws.open:
                await self.send_status()

    async def _run_session(self) -> None:
        """Serve one established connection until any of its loops ends.

        The heartbeat and status loops only exit when the client is stopped,
        so waiting for all three (as ``asyncio.gather`` does) would block
        forever after a disconnect. Instead, return as soon as the first
        loop finishes and cancel the rest.
        """
        tasks = [
            asyncio.create_task(self.receive_loop()),
            asyncio.create_task(self.heartbeat_loop()),
            asyncio.create_task(self.status_loop()),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Also runs when this coroutine itself is cancelled, so the
            # background loops never outlive the session.
            for t in tasks:
                if not t.done():
                    t.cancel()
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, Exception) and not isinstance(
                    result, asyncio.CancelledError
                ):
                    logger.error("Session task failed: %r", result)

    async def _close_connection(self) -> None:
        """Best-effort close of the current socket; clears ``self.ws``."""
        ws, self.ws = self.ws, None
        if ws is None:
            return
        try:
            await ws.close()
        except Exception as e:
            logger.debug("Error while closing connection: %s", e)

    async def run(self) -> None:
        """Main run loop with reconnection.

        Starts the manager and scheduler, installs the task-completion
        notification handler, then repeatedly connects, registers and serves
        the connection until ``stop()`` is called. After every attempt the
        socket is closed and the loop sleeps for the current back-off delay,
        which then doubles up to ``MAX_RECONNECT_DELAY``.
        """
        self._running = True

        await manager.start()
        await scheduler.start()
        task_runner.set_notification_handler(self._send_task_complete)

        while self._running:
            if await self.connect():
                try:
                    if await self.register():
                        await self._run_session()
                except Exception:
                    # Keep the reconnect loop alive, but leave a trace.
                    logger.exception("Connection session failed")
                finally:
                    await self._close_connection()

            if self._running:
                logger.info(
                    "Reconnecting in %.1f seconds...", self._reconnect_delay
                )
                await asyncio.sleep(self._reconnect_delay)
                self._reconnect_delay = min(
                    self._reconnect_delay * 2, self.MAX_RECONNECT_DELAY
                )

    async def _send_task_complete(self, task) -> None:
        """Send a TaskComplete notification for *task* to the router.

        Reads ``task_id``, ``user_id``, ``notify_chat_id``, ``result`` and
        ``error`` from *task*; missing/falsy values are sent as "". Does
        nothing when there is no connection; send errors are logged.
        """
        if not self.ws:
            return

        msg = TaskComplete(
            task_id=task.task_id,
            user_id=task.user_id or "",
            chat_id=task.notify_chat_id or "",
            result=task.result or task.error or "",
        )
        try:
            await self.ws.send(encode(msg))
            logger.info("Sent TaskComplete for task %s", task.task_id)
        except Exception as e:
            logger.error("Failed to send TaskComplete: %s", e)

    async def stop(self) -> None:
        """Stop the client: clear the running flag, close the socket and
        stop the manager and scheduler."""
        self._running = False
        if self.ws:
            await self.ws.close()
        await manager.stop()
        await scheduler.stop()
        logger.info("Node client stopped")

    @classmethod
    def from_keyring(
        cls,
        router_url: Optional[str] = None,
        secret: Optional[str] = None,
    ) -> "NodeClient":
        """Create a client from keyring.yaml (for standalone mode).

        Args:
            router_url: Overrides the configured router URL when truthy.
            secret: Overrides the configured router secret when truthy.
        """
        config = HostConfig.from_keyring()
        if router_url:
            config.router_url = router_url
        if secret:
            config.router_secret = secret
        return cls(config)


async def main() -> None:
    """Entry point for standalone host client."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    client = NodeClient(get_host_config())
    try:
        await client.run()
    except KeyboardInterrupt:
        await client.stop()
    except asyncio.CancelledError:
        # Under asyncio.run() on Python 3.11+, Ctrl-C cancels the main task
        # instead of raising KeyboardInterrupt inside it; shut down cleanly
        # and let the cancellation propagate.
        await client.stop()
        raise


if __name__ == "__main__":
    asyncio.run(main())