Yuyao Huang (Sam) 80e4953cf9 feat: 优化WebSocket连接和心跳机制
- 在main.py和standalone.py中添加ws_ping_interval和ws_ping_timeout配置
- 调整ws.py中的心跳发送逻辑,先发送ping再等待
- 在host_client中优化消息处理,使用任务队列处理转发请求
- 更新WebTool以适配新的API格式并增加搜索结果限制
- 在agent.py中添加日期显示和web调用次数限制
- 修复bot/handler.py中的事件循环问题
2026-03-28 15:53:44 +08:00

103 lines
2.9 KiB
Python

"""WebSocket endpoint for host client connections.
Handles:
- Connection authentication
- Node registration
- Message forwarding
- Heartbeat
"""
from __future__ import annotations
import asyncio
import logging
from typing import Optional
from fastapi import WebSocket, WebSocketDisconnect, WebSocketException
from router.nodes import get_node_registry
from router.rpc import handle_task_complete
from shared import (
RegisterMessage,
ForwardRequest,
ForwardResponse,
TaskComplete,
Heartbeat,
NodeStatus,
decode,
encode,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
async def ws_node_endpoint(websocket: WebSocket) -> None:
    """WebSocket endpoint for host client connections.

    Accepts the socket and authenticates it. The router secret is taken from
    the ``Authorization`` header, with an optional ``Bearer `` prefix stripped.
    Each incoming text frame is then decoded and dispatched until the client
    disconnects:

    * ``RegisterMessage``: registers the node with the registry and starts
      (once per connection) a background task that sends a ``ping`` heartbeat
      immediately and then every 30 seconds.
    * ``ForwardResponse``: passed to ``router.rpc.resolve_response``.
    * ``TaskComplete``: passed to ``handle_task_complete``.
    * ``Heartbeat``: a ``pong`` refreshes the registered node's heartbeat.
    * ``NodeStatus``: passed to ``registry.update_status``.

    Frames that fail to decode are logged and skipped. On exit, whether by
    disconnect or by error, the heartbeat task is cancelled. The node id this
    connection registered, if any, is then removed from the registry.

    Args:
        websocket: The incoming FastAPI WebSocket connection.
    """
    await websocket.accept()
    registry = get_node_registry()

    # Accept either "Bearer <secret>" or a bare secret in the header.
    secret = websocket.headers.get("authorization", "")
    if secret.startswith("Bearer "):
        secret = secret[len("Bearer "):]
    if not registry.validate_secret(secret):
        logger.warning("Invalid router secret, rejecting connection")
        # 4001 lies in the application-defined close-code range (4000-4999).
        await websocket.close(code=4001, reason="Invalid secret")
        return

    # Node id successfully registered by *this* connection (None until then).
    node_id: Optional[str] = None
    heartbeat_task: Optional[asyncio.Task[None]] = None

    async def send_heartbeat() -> None:
        """Send a ping now and then every 30 s until cancelled or a send fails."""
        try:
            while True:
                try:
                    await websocket.send_text(encode(Heartbeat(type="ping")))
                except Exception as exc:
                    # Most likely the socket is closed. The receive loop is
                    # expected to notice and run cleanup, so just stop pinging.
                    logger.debug("Heartbeat send failed, stopping pings: %s", exc)
                    break
                await asyncio.sleep(30)
        except asyncio.CancelledError:
            # Cancelled by the cleanup below; exit quietly.
            pass

    try:
        async for data in websocket.iter_text():
            try:
                msg = decode(data)
            except Exception as e:
                logger.error("Failed to decode message: %s", e)
                continue

            if isinstance(msg, RegisterMessage):
                if node_id is not None and node_id != msg.node_id:
                    # Re-registration under a different id: drop the old
                    # entry so it does not linger after this socket closes.
                    await registry.unregister(node_id)
                    node_id = None
                await registry.register(websocket, msg)
                # Record ownership only after registration succeeded, so the
                # cleanup never unregisters an id this connection did not own.
                node_id = msg.node_id
                # One heartbeat loop per connection; a repeated registration
                # must not spawn an extra (never-cancelled) task.
                if heartbeat_task is None:
                    heartbeat_task = asyncio.create_task(send_heartbeat())
            elif isinstance(msg, ForwardResponse):
                # Imported lazily. NOTE(review): presumably to avoid a
                # circular import with router.rpc; confirm.
                from router.rpc import resolve_response

                await resolve_response(msg)
            elif isinstance(msg, TaskComplete):
                await handle_task_complete(msg)
            elif isinstance(msg, Heartbeat):
                # Pongs before registration have no node to update.
                if msg.type == "pong" and node_id:
                    await registry.update_heartbeat(node_id)
            elif isinstance(msg, NodeStatus):
                await registry.update_status(msg)
            else:
                logger.debug("Received unhandled message type: %s", type(msg).__name__)
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        if heartbeat_task is not None:
            heartbeat_task.cancel()
        if node_id:
            await registry.unregister(node_id)