"""Node registry for managing connected host clients. Maintains: - Connected nodes with their WebSocket connections - User-to-node mapping (which users each node serves) - Active node preference per user """ from __future__ import annotations import asyncio import logging import time from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Set from shared import RegisterMessage, NodeStatus logger = logging.getLogger(__name__) @dataclass class NodeConnection: """Represents a connected host client.""" node_id: str ws: Any display_name: str = "" serves_users: Set[str] = field(default_factory=set) working_dir: str = "" capabilities: list[str] = field(default_factory=list) connected_at: float = field(default_factory=time.time) last_heartbeat: float = field(default_factory=time.time) sessions: int = 0 active_sessions: list[dict[str, Any]] = field(default_factory=list) @property def is_online(self) -> bool: """Check if node is still considered online (heartbeat within 60s).""" return time.time() - self.last_heartbeat < 60 def to_dict(self) -> dict[str, Any]: """Serialize for API responses.""" return { "node_id": self.node_id, "display_name": self.display_name, "status": "online" if self.is_online else "offline", "users": len(self.serves_users), "sessions": self.sessions, "capabilities": self.capabilities, "connected_at": self.connected_at, "last_heartbeat": self.last_heartbeat, } class NodeRegistry: """Registry of connected host clients.""" def __init__(self, router_secret: str = ""): self._nodes: dict[str, NodeConnection] = {} self._user_nodes: dict[str, Set[str]] = {} self._active_node: dict[str, str] = {} self._known_users: Set[str] = set() self._secret = router_secret self._lock = asyncio.Lock() def validate_secret(self, secret: str) -> bool: """Validate router secret.""" if not self._secret: return True return secret == self._secret def track_user(self, user_id: str) -> None: """Record a user as known (has sent at least one message).""" self._known_users.add(user_id) def _get_notifiable_users(self, node: NodeConnection) -> Set[str]: """Get users to notify about a node event. If the node has explicit serves_users, return those. Otherwise (serves everyone), return all known users. """ if node.serves_users: return node.serves_users return self._known_users.copy() async def register(self, ws: Any, msg: RegisterMessage) -> NodeConnection: """Register a new node connection.""" async with self._lock: is_reconnect = msg.node_id in self._nodes node = NodeConnection( node_id=msg.node_id, ws=ws, display_name=msg.display_name or msg.node_id, serves_users=set(msg.serves_users), working_dir=msg.working_dir, capabilities=msg.capabilities, ) self._nodes[msg.node_id] = node for user_id in msg.serves_users: if user_id not in self._user_nodes: self._user_nodes[user_id] = set() self._user_nodes[user_id].add(msg.node_id) logger.info( "Node registered: %s (users: %s, capabilities: %s)", msg.node_id, msg.serves_users, msg.capabilities, ) users_to_notify = self._get_notifiable_users(node) if is_reconnect: for user_id in users_to_notify: asyncio.create_task(self._notify_reconnect(user_id, node.display_name)) else: for user_id in users_to_notify: asyncio.create_task(self._notify_new_node(user_id, node.display_name)) return node async def _notify_new_node(self, user_id: str, node_name: str) -> None: """Notify user about a new node coming online.""" try: from bot.feishu import send_text await send_text(user_id, "open_id", f"🟢 Node \"{node_name}\" is online.") except Exception as e: logger.error("Failed to send new node notification: %s", e) async def _notify_reconnect(self, user_id: str, node_name: str) -> None: """Notify user about node reconnect.""" try: from bot.feishu import send_text await send_text(user_id, "open_id", f"✅ Node \"{node_name}\" reconnected.") except Exception as e: logger.error("Failed to send reconnect notification: %s", e) async def unregister(self, node_id: str) -> None: """Unregister a node connection.""" async with self._lock: node = self._nodes.pop(node_id, None) if node: users_to_notify = self._get_notifiable_users(node) for user_id in node.serves_users: if user_id in self._user_nodes: self._user_nodes[user_id].discard(node_id) if not self._user_nodes[user_id]: del self._user_nodes[user_id] for user_id in list(self._active_node.keys()): if self._active_node[user_id] == node_id: del self._active_node[user_id] logger.info("Node unregistered: %s", node_id) for user_id in users_to_notify: asyncio.create_task(self._notify_disconnect(user_id, node.display_name)) async def _notify_disconnect(self, user_id: str, node_name: str) -> None: """Notify user about node disconnect.""" try: from bot.feishu import send_text await send_text(user_id, "open_id", f"⚠️ Node \"{node_name}\" disconnected.") except Exception as e: logger.error("Failed to send disconnect notification: %s", e) async def update_status(self, msg: NodeStatus) -> None: """Update node status from heartbeat.""" async with self._lock: node = self._nodes.get(msg.node_id) if node: node.sessions = msg.sessions node.active_sessions = msg.active_sessions node.last_heartbeat = time.time() async def update_heartbeat(self, node_id: str) -> None: """Update node heartbeat timestamp.""" async with self._lock: node = self._nodes.get(node_id) if node: node.last_heartbeat = time.time() def get_node(self, node_id: str) -> Optional[NodeConnection]: """Get a node by ID.""" return self._nodes.get(node_id) def get_nodes_for_user(self, user_id: str) -> list[NodeConnection]: """Get all nodes that serve a user.""" # Get nodes explicitly mapped to this user user_node_ids = self._user_nodes.get(user_id, set()) # Get nodes that serve all users (empty serves_users set) all_users_node_ids = set() for node_id, node in self._nodes.items(): if not node.serves_users: all_users_node_ids.add(node_id) # Combine both sets node_ids = user_node_ids | all_users_node_ids return [self._nodes[nid] for nid in node_ids if nid in self._nodes] def get_active_node(self, user_id: str) -> Optional[NodeConnection]: """Get the active node for a user.""" node_id = self._active_node.get(user_id) if node_id: return self._nodes.get(node_id) nodes = self.get_nodes_for_user(user_id) if nodes: online = [n for n in nodes if n.is_online] if online: return online[0] return None def set_active_node(self, user_id: str, node_id: str) -> bool: """Set the active node for a user.""" if node_id not in self._nodes: return False self._active_node[user_id] = node_id logger.info("Active node for user %s set to %s", user_id, node_id) return True def list_nodes(self) -> list[dict[str, Any]]: """List all nodes with their status.""" return [node.to_dict() for node in self._nodes.values()] def get_affected_users(self, node_id: str) -> list[str]: """Get users affected by a node disconnect.""" node = self._nodes.get(node_id) if node: return list(node.serves_users) return [] node_registry: Optional[NodeRegistry] = None def get_node_registry() -> NodeRegistry: """Get the global node registry instance.""" global node_registry if node_registry is None: node_registry = NodeRegistry() return node_registry