- 在main.py和standalone.py中添加ws_ping_interval和ws_ping_timeout配置 - 调整ws.py中的心跳发送逻辑,先发送ping再等待 - 在host_client中优化消息处理,使用任务队列处理转发请求 - 更新WebTool以适配新的API格式并增加搜索结果限制 - 在agent.py中添加日期显示和web调用次数限制 - 修复bot/handler.py中的事件循环问题
103 lines
2.9 KiB
Python
103 lines
2.9 KiB
Python
"""WebSocket endpoint for host client connections.

Handles:

- Connection authentication: the ``Authorization`` header (optionally
  ``Bearer``-prefixed) is checked against the node registry's secret.
- Node registration: ``RegisterMessage`` frames register the connection
  with the node registry; the node is unregistered when the socket closes.
- Message forwarding: ``ForwardResponse`` and ``TaskComplete`` frames are
  handed to ``router.rpc``.
- Heartbeat: the server sends periodic ``ping`` frames; ``pong`` replies
  refresh the node's heartbeat in the registry.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Optional
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect, WebSocketException
|
|
|
|
from router.nodes import get_node_registry
|
|
from router.rpc import handle_task_complete
|
|
from shared import (
|
|
RegisterMessage,
|
|
ForwardRequest,
|
|
ForwardResponse,
|
|
TaskComplete,
|
|
Heartbeat,
|
|
NodeStatus,
|
|
decode,
|
|
encode,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Seconds to wait between server-initiated "ping" heartbeats to a host client.
_HEARTBEAT_INTERVAL_SECONDS = 30


async def ws_node_endpoint(websocket: WebSocket) -> None:
    """WebSocket endpoint for host client connections.

    Accepts the connection, then validates the ``Authorization`` header
    (an optional ``"Bearer "`` prefix is stripped) against the node
    registry. An invalid secret closes the socket with code 4001.

    After authentication, each incoming text frame is decoded and dispatched:

    * ``RegisterMessage`` -- registers the node and starts a single
      heartbeat task for this connection.
    * ``ForwardResponse`` -- passed to ``router.rpc.resolve_response``.
    * ``TaskComplete`` -- passed to ``handle_task_complete``.
    * ``Heartbeat`` with ``type == "pong"`` -- refreshes the registered
      node's heartbeat.
    * ``NodeStatus`` -- forwarded to ``registry.update_status``.

    Frames that fail to decode are logged and skipped. However the
    connection ends, the heartbeat task is cancelled and awaited, and the
    registered node (if any) is unregistered.

    Args:
        websocket: The incoming host-client WebSocket connection.
    """
    await websocket.accept()

    registry = get_node_registry()

    bearer_prefix = "Bearer "
    secret = websocket.headers.get("authorization", "")
    if secret.startswith(bearer_prefix):
        secret = secret[len(bearer_prefix):]

    if not registry.validate_secret(secret):
        logger.warning("Invalid router secret, rejecting connection")
        await websocket.close(code=4001, reason="Invalid secret")
        return

    node_id: Optional[str] = None
    heartbeat_task: Optional[asyncio.Task[None]] = None

    async def send_heartbeat() -> None:
        """Send a ping right away, then every ``_HEARTBEAT_INTERVAL_SECONDS``.

        Returns quietly on the first failed send. Cancellation is not
        swallowed, so the owner can cancel and then wait for the task.
        """
        while True:
            try:
                await websocket.send_text(encode(Heartbeat(type="ping")))
            except Exception as exc:
                # The socket is most likely gone; the receive loop below will
                # notice and run cleanup, so just stop pinging.
                logger.debug("Heartbeat send failed for node %s: %s", node_id, exc)
                return
            await asyncio.sleep(_HEARTBEAT_INTERVAL_SECONDS)

    try:
        async for data in websocket.iter_text():
            try:
                msg = decode(data)
            except Exception as e:
                logger.error("Failed to decode message: %s", e)
                continue

            if isinstance(msg, RegisterMessage):
                if node_id is not None and node_id != msg.node_id:
                    # This connection is re-registering under a new id. Only
                    # the latest id is unregistered on disconnect, so drop
                    # the stale one now rather than leaking it.
                    await registry.unregister(node_id)
                node_id = msg.node_id
                await registry.register(websocket, msg)
                # One ping loop per connection. Re-registering must not
                # spawn extra loops that would never be cancelled.
                if heartbeat_task is None:
                    heartbeat_task = asyncio.create_task(send_heartbeat())

            elif isinstance(msg, ForwardResponse):
                # Local import, resolved only when a response actually arrives.
                from router.rpc import resolve_response

                await resolve_response(msg)

            elif isinstance(msg, TaskComplete):
                await handle_task_complete(msg)

            elif isinstance(msg, Heartbeat):
                # Only pongs from an already-registered node refresh liveness.
                if msg.type == "pong" and node_id:
                    await registry.update_heartbeat(node_id)

            elif isinstance(msg, NodeStatus):
                await registry.update_status(msg)

            else:
                logger.debug("Received unhandled message type: %s", type(msg).__name__)

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        if heartbeat_task is not None:
            heartbeat_task.cancel()
            # Wait until the ping loop has really stopped, so it cannot touch
            # the socket after the node is unregistered. asyncio.wait does not
            # re-raise the task's CancelledError into this coroutine.
            await asyncio.wait({heartbeat_task})
        if node_id:
            # NOTE(review): if the same node_id has already reconnected on a
            # new socket, this may remove the newer registration -- depends
            # on registry.unregister semantics; confirm.
            await registry.unregister(node_id)