- 将 Dict、List 等传统类型注解替换为 dict、list 等现代类型注解 - 更新类型注解以更精确地反映变量类型 - 修复部分类型注解与实际使用不匹配的问题 - 优化部分代码逻辑以提高类型安全性
103 lines
2.9 KiB
Python
103 lines
2.9 KiB
Python
"""WebSocket endpoint for host client connections.
|
|
|
|
Handles:
|
|
- Connection authentication
|
|
- Node registration
|
|
- Message forwarding
|
|
- Heartbeat
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Optional
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect, WebSocketException
|
|
|
|
from router.nodes import get_node_registry
|
|
from router.rpc import handle_task_complete
|
|
from shared import (
|
|
RegisterMessage,
|
|
ForwardRequest,
|
|
ForwardResponse,
|
|
TaskComplete,
|
|
Heartbeat,
|
|
NodeStatus,
|
|
decode,
|
|
encode,
|
|
)
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def ws_node_endpoint(websocket: WebSocket) -> None:
    """WebSocket endpoint for host client connections.

    Flow:

    1. Accept the socket, then validate the ``authorization`` header (an
       optional ``"Bearer "`` prefix is stripped) with the node registry.
       On failure the socket is closed with code 4001 and the function
       returns.
    2. Read text frames and dispatch each decoded message by type:
       registration, forward responses, task completion, heartbeat pongs
       and node status updates. Frames that fail to decode are logged and
       skipped.
    3. On exit (disconnect or error) cancel the ping loop and unregister
       the node if one was registered on this connection.

    Args:
        websocket: The incoming WebSocket connection.
    """
    # The socket is accepted before authentication so that a rejection can
    # be delivered as a close frame with an application-specific code.
    await websocket.accept()

    registry = get_node_registry()

    secret = websocket.headers.get("authorization", "")
    if secret.startswith("Bearer "):
        secret = secret[7:]

    if not registry.validate_secret(secret):
        logger.warning("Invalid router secret, rejecting connection")
        await websocket.close(code=4001, reason="Invalid secret")
        return

    # Both stay None until the client sends a RegisterMessage.
    node_id: Optional[str] = None
    heartbeat_task: Optional[asyncio.Task[None]] = None

    async def send_heartbeat() -> None:
        """Send a ping every 30 seconds until a send fails or the task is cancelled."""
        try:
            while True:
                await asyncio.sleep(30)
                try:
                    await websocket.send_text(encode(Heartbeat(type="ping")))
                except Exception:
                    # Best effort: a failed send ends the ping loop, but
                    # leave a trace instead of failing silently.
                    logger.debug(
                        "Heartbeat send failed, stopping ping loop",
                        exc_info=True,
                    )
                    break
        except asyncio.CancelledError:
            # Cancellation is the normal shutdown path (see ``finally`` below).
            pass

    try:
        async for data in websocket.iter_text():
            try:
                msg = decode(data)
            except Exception as e:
                logger.error("Failed to decode message: %s", e)
                continue

            if isinstance(msg, RegisterMessage):
                node_id = msg.node_id
                await registry.register(websocket, msg)
                # A client may register more than once on the same socket.
                # Keep exactly one ping loop per connection: start one only
                # if none exists or the previous one has already finished.
                if heartbeat_task is None or heartbeat_task.done():
                    heartbeat_task = asyncio.create_task(send_heartbeat())

            elif isinstance(msg, ForwardResponse):
                # NOTE(review): imported locally although ``router.rpc`` is
                # also imported at module level; confirm whether this is
                # still required before hoisting it.
                from router.rpc import resolve_response

                await resolve_response(msg)

            elif isinstance(msg, TaskComplete):
                await handle_task_complete(msg)

            elif isinstance(msg, Heartbeat):
                # Pongs are only recorded once a node has registered.
                if msg.type == "pong" and node_id:
                    await registry.update_heartbeat(node_id)

            elif isinstance(msg, NodeStatus):
                await registry.update_status(msg)

            else:
                logger.debug("Received unhandled message type: %s", type(msg).__name__)

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        if heartbeat_task is not None:
            heartbeat_task.cancel()
        # Truthiness check on purpose: matches the pong branch above, so an
        # empty node id is treated as "not registered".
        if node_id:
            await registry.unregister(node_id)
|