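Yuyao Huang (Sam) 64297e5e27 feat: implement core components of the multi-host architecture
Add router, host client, and shared protocol modules to support multi-host deployment:
- The router acts as the central node, managing host connections and message routing
- The host client acts as a worker node that runs the local agent
- The shared protocol defines the communication message formats
- Add a standalone run mode (standalone.py)
- Update the configuration system to support router mode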
2026-03-28 14:08:47 +08:00
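The commit message says the shared protocol defines the message formats, but the `shared` module is not part of this file. As a minimal sketch only, assuming a JSON wire format with a `kind` discriminator (the real `encode`/`decode` may work differently), the text-frame round trip the endpoint relies on could look like this. The `kind` key, the tag strings, and the trimmed-down message classes are hypothetical.

```python
import json
from dataclasses import asdict, dataclass, fields
from typing import Dict, Type, Union


# Hypothetical, trimmed-down versions of two message types from `shared`.
@dataclass
class RegisterMessage:
    node_id: str


@dataclass
class Heartbeat:
    type: str  # "ping" or "pong"


Message = Union[RegisterMessage, Heartbeat]

# Hypothetical wire tags; the real protocol may name or encode these differently.
_TAGS: Dict[str, Type] = {"register": RegisterMessage, "heartbeat": Heartbeat}
_TAG_OF = {cls: tag for tag, cls in _TAGS.items()}


def encode(msg: Message) -> str:
    """Serialize a message as JSON, adding a "kind" discriminator first."""
    return json.dumps({"kind": _TAG_OF[type(msg)], **asdict(msg)})


def decode(data: str) -> Message:
    """Parse JSON text and rebuild the dataclass named by "kind".

    Raises KeyError for a missing or unknown kind; unknown fields are dropped.
    """
    payload = json.loads(data)
    cls = _TAGS[payload.pop("kind")]
    names = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in payload.items() if k in names})
```

Dropping unknown fields lets a newer peer add fields without breaking an older one, while an unknown `kind` raises `KeyError`, which the endpoint's `except Exception` around `decode` would log and skip.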


"""WebSocket endpoint for host client connections.

Handles:
- Connection authentication
- Node registration
- Message forwarding
- Heartbeat
"""
from __future__ import annotations

import asyncio
import logging
from typing import Optional

from fastapi import WebSocket, WebSocketDisconnect

from router.nodes import get_node_registry
from router.rpc import handle_task_complete
from shared import (
    RegisterMessage,
    ForwardResponse,
    TaskComplete,
    Heartbeat,
    NodeStatus,
    decode,
    encode,
)

logger = logging.getLogger(__name__)


async def ws_node_endpoint(websocket: WebSocket) -> None:
    """WebSocket endpoint for host client connections."""
    await websocket.accept()
    registry = get_node_registry()

    # Authenticate with the router secret sent as "Authorization: Bearer <secret>".
    secret = websocket.headers.get("authorization", "")
    if secret.startswith("Bearer "):
        secret = secret[len("Bearer "):]
    if not registry.validate_secret(secret):
        logger.warning("Invalid router secret, rejecting connection")
        await websocket.close(code=4001, reason="Invalid secret")
        return

    node_id: Optional[str] = None
    heartbeat_task: Optional[asyncio.Task] = None

    async def send_heartbeat() -> None:
        """Ping the host client every 30 seconds until a send fails or the task is cancelled."""
        try:
            while True:
                await asyncio.sleep(30)
                try:
                    await websocket.send_text(encode(Heartbeat(type="ping")))
                except Exception:
                    break  # connection is gone; stop pinging
        except asyncio.CancelledError:
            pass

    try:
        async for data in websocket.iter_text():
            try:
                msg = decode(data)
            except Exception as e:
                logger.error("Failed to decode message: %s", e)
                continue

            if isinstance(msg, RegisterMessage):
                node_id = msg.node_id
                await registry.register(websocket, msg)
                # Start a single ping loop per connection, even if the client re-registers.
                if heartbeat_task is None:
                    heartbeat_task = asyncio.create_task(send_heartbeat())
            elif isinstance(msg, ForwardResponse):
                from router.rpc import resolve_response

                await resolve_response(msg)
            elif isinstance(msg, TaskComplete):
                await handle_task_complete(msg)
            elif isinstance(msg, Heartbeat):
                # Only pongs from a registered node are recorded.
                if msg.type == "pong" and node_id:
                    await registry.update_heartbeat(node_id)
            elif isinstance(msg, NodeStatus):
                await registry.update_status(msg)
            else:
                logger.debug("Received unhandled message type: %s", type(msg).__name__)
    except WebSocketDisconnect:
        # Starlette's iter_text() ends the loop quietly on a normal close, so this
        # branch only sees disconnects raised while a message is being handled.
        pass
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        logger.info("WebSocket disconnected (node_id=%s)", node_id)
        if heartbeat_task:
            heartbeat_task.cancel()
        if node_id:
            await registry.unregister(node_id)
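The endpoint's two connection-level mechanisms, bearer-secret extraction and the periodic ping task, can be exercised without FastAPI. This is a minimal sketch, assuming a generic async `send` callable in place of `websocket.send_text`. `extract_bearer`, `secret_matches`, and `heartbeat_loop` are hypothetical names. Because `registry.validate_secret` is not shown, the use of `hmac.compare_digest` is only a suggestion for how a secret check could avoid timing leaks.

```python
import asyncio
import hmac
from typing import Awaitable, Callable, List, Tuple


def extract_bearer(header: str) -> str:
    """Strip an optional "Bearer " prefix, mirroring the endpoint's header handling."""
    prefix = "Bearer "
    return header[len(prefix):] if header.startswith(prefix) else header


def secret_matches(candidate: str, expected: str) -> bool:
    """Hypothetical secret check: constant-time comparison of the two secrets."""
    return hmac.compare_digest(candidate.encode(), expected.encode())


async def heartbeat_loop(send: Callable[[str], Awaitable[None]], interval: float) -> None:
    """Send "ping" every `interval` seconds until a send fails.

    CancelledError is deliberately not caught, so cancellation propagates to the awaiter.
    """
    while True:
        await asyncio.sleep(interval)
        try:
            await send("ping")
        except Exception:
            return  # connection is gone; stop quietly


async def demo() -> Tuple[List[str], bool]:
    """Run the ping loop briefly against an in-memory sender, then cancel it."""
    sent: List[str] = []

    async def fake_send(text: str) -> None:
        sent.append(text)

    task = asyncio.create_task(heartbeat_loop(fake_send, interval=0.01))
    await asyncio.sleep(0.035)  # let a few pings go out
    task.cancel()
    try:
        await task  # wait until the cancellation has actually finished
    except asyncio.CancelledError:
        pass
    return sent, task.cancelled()
```

Unlike `send_heartbeat` in the file, `heartbeat_loop` lets `CancelledError` propagate and its owner awaits the cancelled task, so `task.cancelled()` reports the outcome accurately. The asyncio documentation discourages suppressing cancellation outright.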