Yuyao Huang (Sam) 64297e5e27 feat: 实现多主机架构的核心组件
新增路由器、主机客户端和共享协议模块,支持多主机部署模式:
- 路由器作为中央节点管理主机连接和消息路由
- 主机客户端作为工作节点运行本地代理
- 共享协议定义通信消息格式
- 新增独立运行模式standalone.py
- 更新配置系统支持路由模式
2026-03-28 14:08:47 +08:00

281 lines
8.7 KiB
Python

"""Host client main module.
Connects to the router via WebSocket, receives forwarded messages,
runs the local mailboy LLM, and sends responses back.
"""
from __future__ import annotations
import asyncio
import logging
import secrets
import time
from typing import Optional
import websockets
from websockets.client import WebSocketClientProtocol
from agent.manager import manager
from agent.scheduler import scheduler
from agent.task_runner import task_runner
from host_client.config import HostConfig, get_host_config
from orchestrator.agent import run as run_mailboy
from orchestrator.tools import set_current_user, set_current_chat
from shared import (
RegisterMessage,
ForwardRequest,
ForwardResponse,
TaskComplete,
Heartbeat,
NodeStatus,
encode,
decode,
)
logger = logging.getLogger(__name__)
class NodeClient:
    """WebSocket client that connects to the router and handles messages.

    Lifecycle: ``run()`` starts the local ``manager`` and ``scheduler``,
    then repeatedly connects, registers and services one router session
    until ``stop()`` clears ``_running``.  Each session consists of three
    concurrent loops (receive, heartbeat, status); the session ends as soon
    as any one of them finishes, after which the client reconnects.
    """

    def __init__(self, config: HostConfig):
        """Store *config* and initialise connection state (not yet connected)."""
        self.config = config
        # Active router connection; None until connect() succeeds.
        self.ws: Optional[WebSocketClientProtocol] = None
        # Cleared by stop() to make every loop exit.
        self._running = False
        # Wall-clock time of the last ping sent or pong received.
        self._last_heartbeat = time.time()
        # Seconds to wait before the next reconnect attempt (doubles up to 60).
        self._reconnect_delay = 1.0

    async def connect(self) -> bool:
        """Open the WebSocket to ``config.router_url``.

        Sends ``config.router_secret`` as a Bearer token when it is set.

        Returns:
            True on success (also resets the reconnect delay to 1.0),
            False if the connection attempt raised.
        """
        headers = {}
        if self.config.router_secret:
            headers["Authorization"] = f"Bearer {self.config.router_secret}"
        try:
            # NOTE(review): `extra_headers` (and `ws.open` used by the loops)
            # look like the legacy websockets client API — confirm the pinned
            # websockets version still provides them.
            self.ws = await websockets.connect(
                self.config.router_url,
                extra_headers=headers,
                ping_interval=30,
                ping_timeout=10,
            )
        except Exception as e:
            logger.error("Failed to connect to router: %s", e)
            return False
        logger.info("Connected to router: %s", self.config.router_url)
        self._reconnect_delay = 1.0
        return True

    async def register(self) -> bool:
        """Send a ``RegisterMessage`` describing this node to the router.

        Returns:
            True if the message was sent, False when there is no connection
            or the send raised.
        """
        if not self.ws:
            return False
        msg = RegisterMessage(
            node_id=self.config.node_id,
            display_name=self.config.display_name,
            serves_users=self.config.serves_users,
            working_dir=self.config.working_dir,
            capabilities=self.config.capabilities,
        )
        try:
            await self.ws.send(encode(msg))
        except Exception as e:
            logger.error("Failed to send registration: %s", e)
            return False
        logger.info("Sent registration for node: %s", self.config.node_id)
        return True

    async def handle_forward(self, request: ForwardRequest) -> None:
        """Run the local mailboy for a forwarded request and send the reply.

        Any exception raised by the mailboy is reported to the router in the
        ``error`` field of the ``ForwardResponse`` instead of propagating.
        """
        logger.info("Received forward request %s from user %s", request.id, request.user_id)
        set_current_user(request.user_id)
        set_current_chat(request.chat_id)
        try:
            reply = await run_mailboy(request.user_id, request.text)
            response = ForwardResponse(
                id=request.id,
                reply=reply,
                error="",
            )
        except Exception as e:
            logger.exception("Error processing forward request %s", request.id)
            response = ForwardResponse(
                id=request.id,
                reply="",
                error=str(e),
            )
        if self.ws:
            try:
                await self.ws.send(encode(response))
            except Exception as e:
                logger.error("Failed to send response: %s", e)

    async def send_heartbeat(self) -> None:
        """Send a ping heartbeat to the router and record the send time."""
        if self.ws:
            try:
                await self.ws.send(encode(Heartbeat(type="ping")))
                self._last_heartbeat = time.time()
            except Exception as e:
                logger.error("Failed to send heartbeat: %s", e)

    async def send_status(self) -> None:
        """Send a ``NodeStatus`` built from ``manager.list_sessions()``."""
        if not self.ws:
            return
        sessions = manager.list_sessions()
        active_sessions = [
            {"conv_id": s["conv_id"], "working_dir": s["working_dir"]}
            for s in sessions
        ]
        status = NodeStatus(
            node_id=self.config.node_id,
            sessions=len(sessions),
            active_sessions=active_sessions,
        )
        try:
            await self.ws.send(encode(status))
        except Exception as e:
            logger.error("Failed to send status: %s", e)

    async def handle_message(self, data: str) -> None:
        """Decode one raw router message and dispatch it by type.

        ``ForwardRequest`` is processed inline, ``Heartbeat`` pings are
        answered with a pong, pongs refresh ``_last_heartbeat``; anything
        else is only logged at debug level.  Undecodable data is logged and
        dropped.
        """
        try:
            msg = decode(data)
        except Exception as e:
            logger.error("Failed to decode message: %s", e)
            return
        if isinstance(msg, ForwardRequest):
            # NOTE(review): awaited inline, so no further messages are read
            # from the socket until the mailboy returns.
            await self.handle_forward(msg)
        elif isinstance(msg, Heartbeat):
            if msg.type == "ping":
                if self.ws:
                    try:
                        await self.ws.send(encode(Heartbeat(type="pong")))
                    except Exception as e:
                        logger.error("Failed to send pong: %s", e)
            elif msg.type == "pong":
                self._last_heartbeat = time.time()
        else:
            # Fall back to the class name so a message object without a
            # `type` attribute cannot raise and tear down the receive loop.
            logger.debug(
                "Received message type: %s",
                getattr(msg, "type", type(msg).__name__),
            )

    async def receive_loop(self) -> None:
        """Read messages from the socket until it closes or an error occurs."""
        if not self.ws:
            return
        try:
            async for data in self.ws:
                await self.handle_message(data)
        except websockets.ConnectionClosed as e:
            logger.warning("Connection closed: %s", e)
        except Exception as e:
            logger.exception("Error in receive loop: %s", e)

    async def heartbeat_loop(self) -> None:
        """Send a heartbeat every 30 seconds while running and connected."""
        while self._running:
            await asyncio.sleep(30)
            if self.ws and self.ws.open:
                await self.send_heartbeat()

    async def status_loop(self) -> None:
        """Send a status update every 60 seconds while running and connected."""
        while self._running:
            await asyncio.sleep(60)
            if self.ws and self.ws.open:
                await self.send_status()

    async def _run_session(self) -> None:
        """Run the receive/heartbeat/status loops for one connection.

        Returns as soon as ANY of the three loops finishes (normally the
        receive loop, when the connection closes) and cancels the others.
        Waiting for all of them would never return, because the periodic
        loops keep sleeping for as long as ``_running`` is True.
        Exceptions raised by the loops are logged, not propagated.
        """
        tasks = [
            asyncio.ensure_future(self.receive_loop()),
            asyncio.ensure_future(self.heartbeat_loop()),
            asyncio.ensure_future(self.status_loop()),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Also runs when this coroutine itself is cancelled, so no loop
            # task is left behind.
            for task in tasks:
                if not task.done():
                    task.cancel()
            results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception) and not isinstance(
                result, asyncio.CancelledError
            ):
                logger.error("Session task failed: %r", result, exc_info=result)

    async def _close_ws(self) -> None:
        """Close the current socket, if any, logging instead of raising."""
        if not self.ws:
            return
        try:
            await self.ws.close()
        except Exception as e:
            logger.warning("Error while closing router connection: %s", e)

    async def run(self) -> None:
        """Main run loop with reconnection.

        Starts ``manager`` and ``scheduler``, installs the task-completion
        notifier, then loops: connect -> register -> service the session.
        After each attempt (successful or not) it sleeps
        ``_reconnect_delay`` seconds and doubles the delay, capped at 60.
        Returns once ``_running`` has been cleared by ``stop()``.
        """
        self._running = True
        await manager.start()
        await scheduler.start()
        task_runner.set_notification_handler(self._send_task_complete)
        while self._running:
            if await self.connect():
                try:
                    if await self.register():
                        await self._run_session()
                finally:
                    # Release the socket whether registration failed or the
                    # session ended, so a reconnect never leaks the old one.
                    await self._close_ws()
            if self._running:
                logger.info("Reconnecting in %.1f seconds...", self._reconnect_delay)
                await asyncio.sleep(self._reconnect_delay)
                self._reconnect_delay = min(self._reconnect_delay * 2, 60)

    async def _send_task_complete(self, task) -> None:
        """Send TaskComplete notification to router.

        Missing ``user_id`` / ``notify_chat_id`` become empty strings; the
        result field carries ``task.result``, else ``task.error``, else "".
        """
        if not self.ws:
            return
        msg = TaskComplete(
            task_id=task.task_id,
            user_id=task.user_id or "",
            chat_id=task.notify_chat_id or "",
            result=task.result or task.error or "",
        )
        try:
            await self.ws.send(encode(msg))
            logger.info("Sent TaskComplete for task %s", task.task_id)
        except Exception as e:
            logger.error("Failed to send TaskComplete: %s", e)

    async def stop(self) -> None:
        """Stop the client.

        Clears ``_running``, closes the socket, then stops ``manager`` and
        ``scheduler``.  The two stop calls are attempted even if an earlier
        step fails.
        """
        self._running = False
        try:
            await self._close_ws()
        finally:
            try:
                await manager.stop()
            finally:
                await scheduler.stop()
        logger.info("Node client stopped")

    @classmethod
    def from_keyring(cls, router_url: Optional[str] = None, secret: Optional[str] = None) -> "NodeClient":
        """Create a client from keyring.yaml (for standalone mode).

        Args:
            router_url: Overrides the configured router URL when truthy.
            secret: Overrides the configured router secret when truthy.
        """
        config = HostConfig.from_keyring()
        if router_url:
            config.router_url = router_url
        if secret:
            config.router_secret = secret
        return cls(config)
async def main() -> None:
    """Entry point for standalone host client.

    Configures INFO-level logging, builds a ``NodeClient`` from
    ``get_host_config()`` and runs it until interrupted.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    client = NodeClient(get_host_config())
    try:
        await client.run()
    except KeyboardInterrupt:
        await client.stop()
    except asyncio.CancelledError:
        # Under asyncio.run(), Ctrl-C cancels the main task, so the
        # interruption arrives here as CancelledError rather than
        # KeyboardInterrupt.  Shut down, then let the cancellation propagate.
        await client.stop()
        raise
if __name__ == "__main__":
    # Executed as a script: drive the async entry point to completion.
    asyncio.run(main())