PhoneWork/router/nodes.py
Yuyao Huang (Sam) 2a8f745b3d feat(router): 添加用户追踪和节点通知功能
在ROUTER_MODE启用时跟踪用户消息,并在节点注册/注销时通知相关用户。新增_known_users集合记录活跃用户,重构通知逻辑以支持所有已知用户或特定服务用户的通知。
2026-03-29 18:11:00 +08:00

247 lines
8.8 KiB
Python

"""Node registry for managing connected host clients.
Maintains:
- Connected nodes with their WebSocket connections
- User-to-node mapping (which users each node serves)
- Active node preference per user
"""
from __future__ import annotations
import asyncio
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
from shared import RegisterMessage, NodeStatus
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class NodeConnection:
    """A single host client connected to the router.

    Bundles the client's WebSocket handle with the metadata it announced when
    registering, plus liveness and session figures that heartbeats refresh.
    An empty ``serves_users`` set is how a node advertises that it serves
    every user.
    """

    node_id: str
    ws: Any
    display_name: str = ""
    serves_users: Set[str] = field(default_factory=set)
    working_dir: str = ""
    capabilities: list[str] = field(default_factory=list)
    connected_at: float = field(default_factory=time.time)
    last_heartbeat: float = field(default_factory=time.time)
    sessions: int = 0
    active_sessions: list[dict[str, Any]] = field(default_factory=list)

    @property
    def is_online(self) -> bool:
        """Whether the latest heartbeat arrived less than 60 seconds ago."""
        seconds_since_heartbeat = time.time() - self.last_heartbeat
        return seconds_since_heartbeat < 60

    def to_dict(self) -> dict[str, Any]:
        """Build the summary of this node that API responses expose."""
        if self.is_online:
            status = "online"
        else:
            status = "offline"
        return dict(
            node_id=self.node_id,
            display_name=self.display_name,
            status=status,
            users=len(self.serves_users),
            sessions=self.sessions,
            capabilities=self.capabilities,
            connected_at=self.connected_at,
            last_heartbeat=self.last_heartbeat,
        )
class NodeRegistry:
    """Registry of connected host clients.

    State kept by the registry:

    - ``_nodes``: node_id -> current NodeConnection for that id.
    - ``_user_nodes``: user_id -> ids of nodes that list the user explicitly
      in their ``serves_users``.
    - ``_active_node``: user_id -> node_id the user explicitly selected.
    - ``_known_users``: users recorded via :meth:`track_user`.

    A node whose ``serves_users`` set is empty serves every user.
    """

    def __init__(self, router_secret: str = ""):
        self._nodes: dict[str, NodeConnection] = {}
        self._user_nodes: dict[str, Set[str]] = {}
        self._active_node: dict[str, str] = {}
        self._known_users: Set[str] = set()
        self._secret = router_secret
        self._lock = asyncio.Lock()
        # Strong references to fire-and-forget notification tasks: the event
        # loop keeps only weak references to tasks, so an unreferenced task
        # may be garbage-collected before it finishes.
        self._background_tasks: Set[Any] = set()

    def validate_secret(self, secret: str) -> bool:
        """Check a node-supplied secret against the router secret.

        Every secret is accepted when no router secret is configured.
        Non-string secrets are rejected. The comparison uses
        ``hmac.compare_digest`` so its duration does not reveal how many
        leading characters matched.
        """
        if not self._secret:
            return True
        if not isinstance(secret, str):
            return False
        import hmac

        return hmac.compare_digest(
            secret.encode("utf-8"), self._secret.encode("utf-8")
        )

    def track_user(self, user_id: str) -> None:
        """Record a user as known (has sent at least one message)."""
        self._known_users.add(user_id)

    @staticmethod
    def _serves(node: NodeConnection, user_id: str) -> bool:
        """Return True if *node* serves *user_id* (empty serves_users = everyone)."""
        return not node.serves_users or user_id in node.serves_users

    def _get_notifiable_users(self, node: NodeConnection) -> Set[str]:
        """Get users to notify about a node event.

        If the node has explicit serves_users, return a copy of those.
        Otherwise (serves everyone), return a copy of all known users.
        """
        if node.serves_users:
            return set(node.serves_users)
        return self._known_users.copy()

    def _spawn(self, coro: Any) -> None:
        """Schedule *coro* as a background task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _unmap_users(self, node: NodeConnection) -> None:
        """Remove *node* from the user -> node index for every user it served."""
        for user_id in node.serves_users:
            node_ids = self._user_nodes.get(user_id)
            if node_ids is None:
                continue
            node_ids.discard(node.node_id)
            if not node_ids:
                del self._user_nodes[user_id]

    async def register(self, ws: Any, msg: RegisterMessage) -> NodeConnection:
        """Register (or re-register) a node connection and notify its users.

        If a node with the same ``node_id`` is already registered, this is a
        reconnect. The previous connection object is replaced, the user index
        is rebuilt from the new ``serves_users``, and an active preference
        pointing at this node is dropped for users it no longer serves.

        Args:
            ws: The node's WebSocket connection.
            msg: The node's registration message.

        Returns:
            The newly created NodeConnection.
        """
        async with self._lock:
            previous = self._nodes.get(msg.node_id)
            is_reconnect = previous is not None
            node = NodeConnection(
                node_id=msg.node_id,
                ws=ws,
                display_name=msg.display_name or msg.node_id,
                serves_users=set(msg.serves_users),
                working_dir=msg.working_dir,
                # Copy so later changes to the message cannot alter the node.
                capabilities=list(msg.capabilities or ()),
            )
            if previous is not None:
                # Otherwise users dropped from serves_users on reconnect would
                # keep a stale mapping to this node.
                self._unmap_users(previous)
            self._nodes[msg.node_id] = node
            for user_id in node.serves_users:
                self._user_nodes.setdefault(user_id, set()).add(msg.node_id)
            for user_id, active_id in list(self._active_node.items()):
                if active_id == msg.node_id and not self._serves(node, user_id):
                    del self._active_node[user_id]
            logger.info(
                "Node registered: %s (users: %s, capabilities: %s)",
                msg.node_id,
                msg.serves_users,
                msg.capabilities,
            )
            notify = self._notify_reconnect if is_reconnect else self._notify_new_node
            for user_id in self._get_notifiable_users(node):
                self._spawn(notify(user_id, node.display_name))
        return node

    async def _send_notice(self, user_id: str, text: str, kind: str) -> None:
        """Send *text* to *user_id* via Feishu; failures are logged, not raised."""
        try:
            # Imported lazily, as elsewhere in this module.
            from bot.feishu import send_text

            await send_text(user_id, "open_id", text)
        except Exception:
            # Best-effort notification: log with traceback and carry on.
            logger.exception("Failed to send %s notification to %s", kind, user_id)

    async def _notify_new_node(self, user_id: str, node_name: str) -> None:
        """Notify user about a new node coming online."""
        await self._send_notice(user_id, f"🟢 Node \"{node_name}\" is online.", "new node")

    async def _notify_reconnect(self, user_id: str, node_name: str) -> None:
        """Notify user about node reconnect."""
        await self._send_notice(user_id, f"✅ Node \"{node_name}\" reconnected.", "reconnect")

    async def unregister(self, node_id: str, ws: Any = None) -> None:
        """Unregister a node connection and notify its users.

        Args:
            node_id: Id of the node to remove. Unknown ids are ignored.
            ws: Optional connection that is going away. When given, the node
                is removed only if this is still its current connection. This
                lets a superseded connection's cleanup run without removing a
                node that has already reconnected.
        """
        async with self._lock:
            node = self._nodes.get(node_id)
            if node is None:
                return
            if ws is not None and node.ws is not ws:
                logger.info(
                    "Ignoring unregister for node %s from a superseded connection",
                    node_id,
                )
                return
            del self._nodes[node_id]
            users_to_notify = self._get_notifiable_users(node)
            self._unmap_users(node)
            for user_id, active_id in list(self._active_node.items()):
                if active_id == node_id:
                    del self._active_node[user_id]
            logger.info("Node unregistered: %s", node_id)
            for user_id in users_to_notify:
                self._spawn(self._notify_disconnect(user_id, node.display_name))

    async def _notify_disconnect(self, user_id: str, node_name: str) -> None:
        """Notify user about node disconnect."""
        await self._send_notice(user_id, f"⚠️ Node \"{node_name}\" disconnected.", "disconnect")

    async def update_status(self, msg: NodeStatus) -> None:
        """Update session info and heartbeat time from a status message."""
        async with self._lock:
            node = self._nodes.get(msg.node_id)
            if node:
                node.sessions = msg.sessions
                node.active_sessions = msg.active_sessions
                node.last_heartbeat = time.time()

    async def update_heartbeat(self, node_id: str) -> None:
        """Update node heartbeat timestamp (unknown ids are ignored)."""
        async with self._lock:
            node = self._nodes.get(node_id)
            if node:
                node.last_heartbeat = time.time()

    def get_node(self, node_id: str) -> Optional[NodeConnection]:
        """Get a node by ID."""
        return self._nodes.get(node_id)

    def get_nodes_for_user(self, user_id: str) -> list[NodeConnection]:
        """Get all nodes that serve a user, in registration order.

        Includes nodes that list the user explicitly and nodes that serve
        everyone (empty serves_users).
        """
        return [node for node in self._nodes.values() if self._serves(node, user_id)]

    def get_active_node(self, user_id: str) -> Optional[NodeConnection]:
        """Get the node that should handle *user_id*'s messages.

        The user's explicit preference wins, even when that node is offline,
        as long as it still exists and still serves the user. Otherwise the
        first online node serving the user is returned, or None.
        """
        preferred_id = self._active_node.get(user_id)
        if preferred_id:
            preferred = self._nodes.get(preferred_id)
            if preferred is not None and self._serves(preferred, user_id):
                return preferred
        return next(
            (node for node in self.get_nodes_for_user(user_id) if node.is_online),
            None,
        )

    def set_active_node(self, user_id: str, node_id: str) -> bool:
        """Set the active node for a user.

        Returns False, leaving the preference unchanged, when the node is
        unknown or does not serve the user.
        """
        node = self._nodes.get(node_id)
        if node is None or not self._serves(node, user_id):
            return False
        self._active_node[user_id] = node_id
        logger.info("Active node for user %s set to %s", user_id, node_id)
        return True

    def list_nodes(self) -> list[dict[str, Any]]:
        """List all nodes with their status."""
        return [node.to_dict() for node in self._nodes.values()]

    def get_affected_users(self, node_id: str) -> list[str]:
        """Get users affected by a node disconnect, sorted.

        Uses the same rule as disconnect notifications: explicit serves_users,
        or all known users for a node that serves everyone. Unknown node ids
        yield an empty list.
        """
        node = self._nodes.get(node_id)
        if node is None:
            return []
        return sorted(self._get_notifiable_users(node))
# Process-wide registry; None until first requested or assigned elsewhere.
node_registry: Optional[NodeRegistry] = None


def get_node_registry() -> NodeRegistry:
    """Return the shared NodeRegistry, creating a default one on first use.

    NOTE(review): a registry created here has an empty router secret, so its
    validate_secret() accepts any secret. Presumably callers assign
    ``node_registry`` with a configured secret before first use — confirm.
    """
    global node_registry
    registry = node_registry
    if registry is None:
        registry = NodeRegistry()
        node_registry = registry
    return registry