Yuyao Huang (Sam) 09b63341cd refactor: 统一使用现代类型注解替代传统类型注解
- 将 Dict、List 等传统类型注解替换为 dict、list 等现代类型注解
- 更新类型注解以更精确地反映变量类型
- 修复部分类型注解与实际使用不匹配的问题
- 优化部分代码逻辑以提高类型安全性
2026-03-28 14:27:21 +08:00

103 lines
2.9 KiB
Python

"""WebSocket endpoint for host client connections.
Handles:
- Connection authentication
- Node registration
- Message forwarding
- Heartbeat
"""
from __future__ import annotations
import asyncio
import logging
from typing import Optional
from fastapi import WebSocket, WebSocketDisconnect, WebSocketException
from router.nodes import get_node_registry
from router.rpc import handle_task_complete
from shared import (
RegisterMessage,
ForwardRequest,
ForwardResponse,
TaskComplete,
Heartbeat,
NodeStatus,
decode,
encode,
)
# Module-scoped logger; its name is this module's dotted import path.
logger = logging.getLogger(name=__name__)
async def ws_node_endpoint(websocket: WebSocket) -> None:
    """Serve one host-client WebSocket connection for its whole lifetime.

    The connection is accepted and then authenticated with the router secret
    read from the ``Authorization`` header (an optional ``"Bearer "`` prefix
    is stripped). If ``registry.validate_secret`` rejects it, the socket is
    closed with code 4001 and the function returns.

    After authentication, each incoming text frame is decoded and dispatched:

    - ``RegisterMessage``: records ``node_id``, registers the node with the
      registry, and starts (at most one) background task that sends a
      ``ping`` heartbeat every 30 seconds.
    - ``ForwardResponse``: passed to ``router.rpc.resolve_response``.
    - ``TaskComplete``: passed to ``handle_task_complete``.
    - ``Heartbeat`` with ``type == "pong"``: refreshes the heartbeat of the
      registered node (ignored if no node has registered yet).
    - ``NodeStatus``: passed to ``registry.update_status``.
    - anything else: logged at DEBUG level and ignored.

    Frames that fail to decode are logged and skipped. When the receive loop
    ends (client disconnect or unexpected error) the heartbeat task is
    cancelled and the registered node, if any, is unregistered.

    Args:
        websocket: The incoming FastAPI WebSocket connection.
    """
    await websocket.accept()
    registry = get_node_registry()

    bearer_prefix = "Bearer "
    secret = websocket.headers.get("authorization", "")
    if secret.startswith(bearer_prefix):
        secret = secret[len(bearer_prefix):]
    if not registry.validate_secret(secret):
        logger.warning("Invalid router secret, rejecting connection")
        await websocket.close(code=4001, reason="Invalid secret")
        return

    node_id: Optional[str] = None
    heartbeat_task: Optional[asyncio.Task[None]] = None

    async def send_heartbeat() -> None:
        """Send a ``ping`` Heartbeat to the host client every 30 seconds.

        Exits quietly when a send fails or when the task is cancelled.
        """
        try:
            while True:
                await asyncio.sleep(30)
                try:
                    await websocket.send_text(encode(Heartbeat(type="ping")))
                except Exception:
                    # Sending failed (presumably the socket is closing); stop
                    # pinging and let the receive loop handle cleanup.
                    break
        except asyncio.CancelledError:
            # Cancelled from the endpoint's cleanup path; nothing to do.
            pass

    try:
        async for data in websocket.iter_text():
            try:
                msg = decode(data)
            except Exception as e:
                logger.error("Failed to decode message: %s", e)
                continue

            if isinstance(msg, RegisterMessage):
                # NOTE(review): if a client re-registers under a different
                # node_id, the earlier id is not unregistered here — only the
                # last one is cleaned up in `finally`. Confirm whether the
                # registry handles that case.
                node_id = msg.node_id
                await registry.register(websocket, msg)
                # Keep at most one pinger per connection: a repeated
                # RegisterMessage must not spawn additional heartbeat tasks
                # that would never be cancelled.
                if heartbeat_task is None or heartbeat_task.done():
                    heartbeat_task = asyncio.create_task(send_heartbeat())
            elif isinstance(msg, ForwardResponse):
                # Imported lazily, presumably to avoid a circular import with
                # router.rpc — confirm before hoisting to module level.
                from router.rpc import resolve_response
                await resolve_response(msg)
            elif isinstance(msg, TaskComplete):
                await handle_task_complete(msg)
            elif isinstance(msg, Heartbeat):
                if msg.type == "pong" and node_id:
                    await registry.update_heartbeat(node_id)
            elif isinstance(msg, NodeStatus):
                await registry.update_status(msg)
            else:
                logger.debug("Received unhandled message type: %s", type(msg).__name__)
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        if heartbeat_task:
            heartbeat_task.cancel()
        if node_id:
            await registry.unregister(node_id)