新增路由器、主机客户端和共享协议模块,支持多主机部署模式: - 路由器作为中央节点管理主机连接和消息路由 - 主机客户端作为工作节点运行本地代理 - 共享协议定义通信消息格式 - 新增独立运行模式standalone.py - 更新配置系统支持路由模式
210 lines
7.3 KiB
Python
210 lines
7.3 KiB
Python
"""Node registry for managing connected host clients.
|
|
|
|
Maintains:
|
|
- Connected nodes with their WebSocket connections
|
|
- User-to-node mapping (which users each node serves)
|
|
- Active node preference per user
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List, Optional, Set
|
|
|
|
from shared import RegisterMessage, NodeStatus
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class NodeConnection:
    """Represents a connected host client.

    Attributes:
        node_id: Identifier of the node.
        ws: Connection object associated with the node (stored as-is; this
            class never calls into it).
        display_name: Human-readable name of the node.
        serves_users: IDs of the users this node serves.
        working_dir: Working directory associated with the node.
        capabilities: Capability names associated with the node.
        connected_at: ``time.time()`` value taken when the record was created.
        last_heartbeat: ``time.time()`` value of the most recent heartbeat.
        sessions: Session count.
        active_sessions: Per-session dictionaries.
    """

    # Seconds without a heartbeat after which the node counts as offline.
    # Deliberately left un-annotated: ``@dataclass`` only turns *annotated*
    # class attributes into fields, so this stays a plain, overridable class
    # constant and does not become an ``__init__`` parameter.
    HEARTBEAT_TIMEOUT_SECONDS = 60

    node_id: str
    ws: Any
    display_name: str = ""
    serves_users: Set[str] = field(default_factory=set)
    working_dir: str = ""
    capabilities: List[str] = field(default_factory=list)
    connected_at: float = field(default_factory=time.time)
    last_heartbeat: float = field(default_factory=time.time)
    sessions: int = 0
    active_sessions: List[Dict[str, Any]] = field(default_factory=list)

    @property
    def is_online(self) -> bool:
        """Return True while the last heartbeat is recent enough.

        "Recent enough" means strictly less than
        ``HEARTBEAT_TIMEOUT_SECONDS`` (60 by default) seconds ago.
        """
        seconds_since_heartbeat = time.time() - self.last_heartbeat
        return seconds_since_heartbeat < self.HEARTBEAT_TIMEOUT_SECONDS

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for API responses.

        Returns:
            A dictionary with the node's identity, an ``"online"`` /
            ``"offline"`` status string derived from ``is_online``, the
            number of served users (not their IDs), the session count, the
            capabilities and both timestamps.
        """
        return {
            "node_id": self.node_id,
            "display_name": self.display_name,
            "status": "online" if self.is_online else "offline",
            "users": len(self.serves_users),
            "sessions": self.sessions,
            "capabilities": self.capabilities,
            "connected_at": self.connected_at,
            "last_heartbeat": self.last_heartbeat,
        }
|
|
|
|
|
|
class NodeRegistry:
    """Registry of connected host clients.

    State kept by the registry:

    - ``_nodes``: node_id -> ``NodeConnection`` for every registered node.
    - ``_user_nodes``: user_id -> set of node_ids that serve that user.
    - ``_active_node``: user_id -> node_id explicitly chosen for that user.

    The coroutine methods that mutate state serialize on one
    ``asyncio.Lock``; the synchronous getters read the dictionaries directly.
    """

    def __init__(self, router_secret: str = ""):
        """Create an empty registry.

        Args:
            router_secret: Secret that ``validate_secret`` compares against.
                An empty value disables the check.
        """
        self._nodes: Dict[str, NodeConnection] = {}
        self._user_nodes: Dict[str, Set[str]] = {}
        self._active_node: Dict[str, str] = {}
        self._secret = router_secret
        self._lock = asyncio.Lock()
        # Strong references to fire-and-forget notification tasks. The event
        # loop only keeps weak references to tasks, so a task nobody else
        # references may be garbage-collected before it has finished.
        self._background_tasks: Set[asyncio.Task] = set()

    def validate_secret(self, secret: str) -> bool:
        """Validate router secret.

        Args:
            secret: Secret presented by the connecting peer.

        Returns:
            True when no secret is configured, or when ``secret`` equals the
            configured one; False otherwise.
        """
        if not self._secret:
            return True
        if not isinstance(secret, str) or not isinstance(self._secret, str):
            # Non-string input cannot be encoded for the constant-time
            # comparison below; fall back to plain equality.
            return secret == self._secret

        # Function-scope import so the module-level import block is untouched.
        import hmac

        # Constant-time comparison: ``==`` returns as soon as the first
        # differing character is found, which leaks timing information.
        # Compare bytes because compare_digest rejects non-ASCII ``str``.
        return hmac.compare_digest(
            secret.encode("utf-8"), self._secret.encode("utf-8")
        )

    def _spawn_notification(self, coro: Any) -> None:
        """Schedule ``coro`` as a task and hold a reference until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _unlink_user(self, user_id: str, node_id: str) -> None:
        """Remove ``node_id`` from a user's node set, dropping empty sets."""
        node_ids = self._user_nodes.get(user_id)
        if node_ids is None:
            return
        node_ids.discard(node_id)
        if not node_ids:
            del self._user_nodes[user_id]

    async def register(self, ws: Any, msg: RegisterMessage) -> NodeConnection:
        """Register a new node connection.

        If a node with the same ``node_id`` is already registered, its record
        is replaced, user mappings the node no longer declares are removed,
        and every served user gets a reconnect notification.

        Args:
            ws: Connection object to store on the node record.
            msg: Registration message; ``node_id``, ``display_name``,
                ``serves_users``, ``working_dir`` and ``capabilities`` are
                read from it.

        Returns:
            The newly created ``NodeConnection``.
        """
        async with self._lock:
            previous = self._nodes.get(msg.node_id)
            served_users = set(msg.serves_users)

            node = NodeConnection(
                node_id=msg.node_id,
                ws=ws,
                display_name=msg.display_name or msg.node_id,
                serves_users=served_users,
                working_dir=msg.working_dir,
                # Copy so later changes to the message cannot alter the record.
                capabilities=list(msg.capabilities or []),
            )
            self._nodes[msg.node_id] = node

            if previous is not None:
                # The node may come back serving fewer users than before;
                # without this the old users would still map to it.
                for user_id in previous.serves_users - served_users:
                    self._unlink_user(user_id, msg.node_id)

            for user_id in served_users:
                self._user_nodes.setdefault(user_id, set()).add(msg.node_id)

            logger.info(
                "Node registered: %s (users: %s, capabilities: %s)",
                msg.node_id,
                msg.serves_users,
                msg.capabilities,
            )

            if previous is not None:
                # Iterating the set sends one notification per user even if
                # the message listed a user more than once.
                for user_id in served_users:
                    self._spawn_notification(
                        self._notify_reconnect(user_id, node.display_name)
                    )

            return node

    async def _notify_reconnect(self, user_id: str, node_name: str) -> None:
        """Notify user about node reconnect; failures are logged, not raised."""
        try:
            from bot.feishu import send_text
            await send_text(user_id, "open_id", f"✅ Node \"{node_name}\" reconnected.")
        except Exception as e:
            logger.error("Failed to send reconnect notification: %s", e)

    async def unregister(self, node_id: str, ws: Any = None) -> None:
        """Unregister a node connection.

        Removes the node, its user mappings and any active-node preference
        pointing at it, then notifies every user it served. Unknown
        ``node_id`` values are ignored.

        Args:
            node_id: ID of the node to remove.
            ws: Optional connection object of the caller. When given and the
                registered record holds a *different* connection, nothing is
                removed: the node has re-registered in the meantime and the
                caller is the stale connection closing late.
        """
        async with self._lock:
            node = self._nodes.get(node_id)
            if node is None:
                return
            if ws is not None and node.ws is not ws:
                logger.info(
                    "Ignoring unregister for %s: connection was replaced",
                    node_id,
                )
                return

            del self._nodes[node_id]
            affected_users = list(node.serves_users)

            for user_id in affected_users:
                self._unlink_user(user_id, node_id)

            # Collect first: the dict cannot be resized while iterating it.
            orphaned = [
                user_id
                for user_id, active_id in self._active_node.items()
                if active_id == node_id
            ]
            for user_id in orphaned:
                del self._active_node[user_id]

            logger.info("Node unregistered: %s", node_id)

            for user_id in affected_users:
                self._spawn_notification(
                    self._notify_disconnect(user_id, node.display_name)
                )

    async def _notify_disconnect(self, user_id: str, node_name: str) -> None:
        """Notify user about node disconnect; failures are logged, not raised."""
        try:
            from bot.feishu import send_text
            await send_text(user_id, "open_id", f"⚠️ Node \"{node_name}\" disconnected.")
        except Exception as e:
            logger.error("Failed to send disconnect notification: %s", e)

    async def update_status(self, msg: NodeStatus) -> None:
        """Update node status from heartbeat.

        Copies ``sessions`` and ``active_sessions`` from ``msg`` onto the
        node named by ``msg.node_id`` and refreshes its heartbeat timestamp.
        Unknown nodes are ignored.
        """
        async with self._lock:
            node = self._nodes.get(msg.node_id)
            if node is None:
                return
            node.sessions = msg.sessions
            node.active_sessions = msg.active_sessions
            node.last_heartbeat = time.time()

    async def update_heartbeat(self, node_id: str) -> None:
        """Update node heartbeat timestamp; unknown nodes are ignored."""
        async with self._lock:
            node = self._nodes.get(node_id)
            if node is not None:
                node.last_heartbeat = time.time()

    def get_node(self, node_id: str) -> Optional[NodeConnection]:
        """Get a node by ID, or None when it is not registered."""
        return self._nodes.get(node_id)

    def get_nodes_for_user(self, user_id: str) -> List[NodeConnection]:
        """Get all registered nodes that serve a user (empty list if none)."""
        node_ids = self._user_nodes.get(user_id, set())
        return [self._nodes[nid] for nid in node_ids if nid in self._nodes]

    def get_active_node(self, user_id: str) -> Optional[NodeConnection]:
        """Get the active node for a user.

        An explicit preference set through ``set_active_node`` wins and is
        returned without checking ``is_online``. Without a preference, the
        first online node among those serving the user is returned.

        Returns:
            The selected ``NodeConnection``, or None when there is none.
        """
        preferred_id = self._active_node.get(user_id)
        if preferred_id:
            return self._nodes.get(preferred_id)

        return next(
            (node for node in self.get_nodes_for_user(user_id) if node.is_online),
            None,
        )

    def set_active_node(self, user_id: str, node_id: str) -> bool:
        """Set the active node for a user.

        Returns:
            False when ``node_id`` is not registered (nothing is changed),
            True once the preference has been stored.
        """
        if node_id not in self._nodes:
            return False
        self._active_node[user_id] = node_id
        logger.info("Active node for user %s set to %s", user_id, node_id)
        return True

    def list_nodes(self) -> List[Dict[str, Any]]:
        """List all nodes as dictionaries produced by ``to_dict``."""
        return [node.to_dict() for node in self._nodes.values()]

    def get_affected_users(self, node_id: str) -> List[str]:
        """Get the users served by a node (empty list for unknown nodes)."""
        node = self._nodes.get(node_id)
        if node is None:
            return []
        return list(node.serves_users)
|
|
|
|
|
|
# Module-level cache for the shared registry; stays ``None`` until
# ``get_node_registry`` is called for the first time.
node_registry: Optional[NodeRegistry] = None


def get_node_registry() -> NodeRegistry:
    """Return the global node registry, creating it on first use.

    The instance is built with ``NodeRegistry()`` (all default arguments) and
    stored in the module-level ``node_registry`` variable, so later calls
    return the same object.
    """
    global node_registry
    if node_registry is not None:
        return node_registry
    node_registry = NodeRegistry()
    return node_registry
|