feat: 实现多主机架构的核心组件
新增路由器、主机客户端和共享协议模块,支持多主机部署模式: - 路由器作为中央节点管理主机连接和消息路由 - 主机客户端作为工作节点运行本地代理 - 共享协议定义通信消息格式 - 新增独立运行模式standalone.py - 更新配置系统支持路由模式
This commit is contained in:
parent
8ecc701d5e
commit
64297e5e27
33
ROADMAP.md
33
ROADMAP.md
@ -529,3 +529,36 @@ PhoneWork/
|
|||||||
- [ ] Host client reconnects → re-registered, messages flow again
|
- [ ] Host client reconnects → re-registered, messages flow again
|
||||||
- [ ] Long CC task on node finishes → router forwards completion notification to Feishu
|
- [ ] Long CC task on node finishes → router forwards completion notification to Feishu
|
||||||
- [ ] Wrong `ROUTER_SECRET` → connection rejected with 401
|
- [ ] Wrong `ROUTER_SECRET` → connection rejected with 401
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## M3 Implementation Notes (from M2 code review)
|
||||||
|
|
||||||
|
Three concrete details discovered from reading the actual M2 code that must be handled
|
||||||
|
during M3 implementation:
|
||||||
|
|
||||||
|
### 1. `bot/commands.py` accesses node-local state directly
|
||||||
|
|
||||||
|
The current `commands.py` calls `agent._active_conv`, `manager.list_sessions()`,
|
||||||
|
`task_runner.list_tasks()`, `scheduler` — all of which move to the host client in M3.
|
||||||
|
|
||||||
|
**Resolution:** At the router, `bot/commands.py` is reduced to two commands:
|
||||||
|
`/nodes` and `/node <name>`. All other slash commands (`/new`, `/status`, `/close`,
|
||||||
|
`/switch`, `/direct`, `/smart`, `/shell`, `/tasks`, `/remind`) are forwarded to the
|
||||||
|
active node as-is — the node's mailboy handles them using its local `commands.py`.
|
||||||
|
The node's command handler remains unchanged from M2.
|
||||||
|
|
||||||
|
### 2. `chat_id` must be forwarded to the node
|
||||||
|
|
||||||
|
`bot/handler.py` calls `set_current_chat(chat_id)` before invoking the agent.
|
||||||
|
In M3, `handler.py` stays at the router but the agent (and `set_current_chat`) moves
|
||||||
|
to the node. The `chat_id` travels in `ForwardRequest` (already planned), and
|
||||||
|
`host_client/main.py` must call `set_current_chat(msg.chat_id)` before invoking the
|
||||||
|
local `agent.run()`. This is essential for `FileSendTool` and `SchedulerTool` to work.
|
||||||
|
|
||||||
|
### 3. `orchestrator/tools.py` imports `config.WORKING_DIR`
|
||||||
|
|
||||||
|
`_resolve_dir()` imports `WORKING_DIR` from root `config.py`. When `orchestrator/`
|
||||||
|
moves to the host client, this import must switch to `host_client/config.py`.
|
||||||
|
In standalone mode, `host_client/config.py` can re-export from root `config.py` to
|
||||||
|
keep a single `keyring.yaml`.
|
||||||
|
|||||||
@ -30,6 +30,7 @@ class BackgroundTask:
|
|||||||
result: Optional[str] = None
|
result: Optional[str] = None
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
notify_chat_id: Optional[str] = None
|
notify_chat_id: Optional[str] = None
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def elapsed(self) -> float:
|
def elapsed(self) -> float:
|
||||||
@ -44,12 +45,18 @@ class TaskRunner:
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._tasks: Dict[str, BackgroundTask] = {}
|
self._tasks: Dict[str, BackgroundTask] = {}
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
self._notification_handler: Optional[Callable] = None
|
||||||
|
|
||||||
|
def set_notification_handler(self, handler: Optional[Callable]) -> None:
|
||||||
|
"""Set custom notification handler for M3 mode (host client -> router)."""
|
||||||
|
self._notification_handler = handler
|
||||||
|
|
||||||
async def submit(
|
async def submit(
|
||||||
self,
|
self,
|
||||||
coro: Callable[[], Any],
|
coro: Callable[[], Any],
|
||||||
description: str,
|
description: str,
|
||||||
notify_chat_id: Optional[str] = None,
|
notify_chat_id: Optional[str] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Submit a coroutine as a background task."""
|
"""Submit a coroutine as a background task."""
|
||||||
task_id = str(uuid.uuid4())[:8]
|
task_id = str(uuid.uuid4())[:8]
|
||||||
@ -59,6 +66,7 @@ class TaskRunner:
|
|||||||
started_at=time.time(),
|
started_at=time.time(),
|
||||||
status=TaskStatus.PENDING,
|
status=TaskStatus.PENDING,
|
||||||
notify_chat_id=notify_chat_id,
|
notify_chat_id=notify_chat_id,
|
||||||
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
@ -94,7 +102,10 @@ class TaskRunner:
|
|||||||
logger.exception("Task %s failed: %s", task_id, exc)
|
logger.exception("Task %s failed: %s", task_id, exc)
|
||||||
|
|
||||||
if task.notify_chat_id:
|
if task.notify_chat_id:
|
||||||
await self._send_notification(task)
|
if self._notification_handler:
|
||||||
|
await self._notification_handler(task)
|
||||||
|
else:
|
||||||
|
await self._send_notification(task)
|
||||||
|
|
||||||
async def _send_notification(self, task: BackgroundTask) -> None:
|
async def _send_notification(self, task: BackgroundTask) -> None:
|
||||||
"""Send Feishu notification about task completion."""
|
"""Send Feishu notification about task completion."""
|
||||||
|
|||||||
@ -67,6 +67,8 @@ async def handle_command(user_id: str, text: str) -> Optional[str]:
|
|||||||
return await _cmd_shell(args)
|
return await _cmd_shell(args)
|
||||||
elif cmd == "/remind":
|
elif cmd == "/remind":
|
||||||
return await _cmd_remind(args)
|
return await _cmd_remind(args)
|
||||||
|
elif cmd in ("/nodes", "/node"):
|
||||||
|
return await _cmd_nodes(user_id, args)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -278,6 +280,39 @@ async def _cmd_remind(args: str) -> str:
|
|||||||
return f"⏰ Reminder #{job_id} set for {value}{unit} from now"
|
return f"⏰ Reminder #{job_id} set for {value}{unit} from now"
|
||||||
|
|
||||||
|
|
||||||
|
async def _cmd_nodes(user_id: str, args: str) -> str:
    """Handle `/nodes` and `/node <name>`.

    With an argument, try to make that node the user's active node and
    report success or failure. Without one, render the list of nodes known
    to the router registry, marking the user's active node.

    Args:
        user_id: ID of the user issuing the command.
        args: Text following the command; empty to list nodes.

    Returns:
        The reply text to send back to the user.
    """
    # Imported lazily, as in the rest of this command: the router package is
    # only needed when router mode is enabled.
    from config import ROUTER_MODE

    if not ROUTER_MODE:
        return "Not in router mode. Run standalone.py for multi-host support."

    from router.nodes import get_node_registry

    registry = get_node_registry()

    if args:
        target = args.strip()
        switched = registry.set_active_node(user_id, target)
        if switched:
            return f"✓ Active node set to: {target}"
        return f"Error: Node '{target}' not found"

    known_nodes = registry.list_nodes()
    if not known_nodes:
        return "No nodes connected."

    current = registry.get_active_node(user_id)
    current_id = current.node_id if current else None

    def _render(info) -> str:
        # One line per node: active marker, name, online indicator, session count.
        marker = "→ " if info["node_id"] == current_id else " "
        status = "🟢" if info["status"] == "online" else "🔴"
        return f"{marker}{info['display_name']} {status} sessions={info['sessions']}"

    report = ["**Connected Nodes:**\n"]
    report.extend(_render(info) for info in known_nodes)
    report.append("\nUse `/node <name>` to switch active node.")
    return "\n".join(report)
|
||||||
|
|
||||||
|
|
||||||
def _cmd_help() -> str:
|
def _cmd_help() -> str:
|
||||||
"""Show help."""
|
"""Show help."""
|
||||||
return """**Commands:**
|
return """**Commands:**
|
||||||
@ -290,5 +325,7 @@ def _cmd_help() -> str:
|
|||||||
/shell <cmd> - Run shell command (bypasses LLM)
|
/shell <cmd> - Run shell command (bypasses LLM)
|
||||||
/remind <time> <msg> - Set reminder (e.g. /remind 10m check build)
|
/remind <time> <msg> - Set reminder (e.g. /remind 10m check build)
|
||||||
/tasks - List background tasks
|
/tasks - List background tasks
|
||||||
|
/nodes - List connected host nodes
|
||||||
|
/node <name> - Switch active node
|
||||||
/retry - Retry last message
|
/retry - Retry last message
|
||||||
/help - Show this help"""
|
/help - Show this help"""
|
||||||
|
|||||||
@ -86,7 +86,7 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def _process_message(user_id: str, chat_id: str, text: str) -> None:
|
async def _process_message(user_id: str, chat_id: str, text: str) -> None:
|
||||||
"""Process message: check allowlist, then commands, then agent."""
|
"""Process message: check allowlist, then commands, then route to node or local agent."""
|
||||||
try:
|
try:
|
||||||
set_current_chat(chat_id)
|
set_current_chat(chat_id)
|
||||||
|
|
||||||
@ -96,10 +96,48 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
reply = await handle_command(user_id, text)
|
reply = await handle_command(user_id, text)
|
||||||
if reply is None:
|
if reply is not None:
|
||||||
|
if reply:
|
||||||
|
await send_text(chat_id, "chat_id", reply)
|
||||||
|
return
|
||||||
|
|
||||||
|
from config import ROUTER_MODE
|
||||||
|
if ROUTER_MODE:
|
||||||
|
from router.routing_agent import route
|
||||||
|
from router.rpc import forward
|
||||||
|
from router.nodes import get_node_registry
|
||||||
|
|
||||||
|
node_id, reason = await route(user_id, chat_id, text)
|
||||||
|
|
||||||
|
if node_id is None:
|
||||||
|
await send_text(chat_id, "chat_id", f"No host available: {reason}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if node_id == "meta":
|
||||||
|
registry = get_node_registry()
|
||||||
|
nodes = registry.list_nodes()
|
||||||
|
if nodes:
|
||||||
|
lines = ["Connected Nodes:"]
|
||||||
|
for n in nodes:
|
||||||
|
marker = " → " if n.get("node_id") == registry.get_active_node(user_id) else " "
|
||||||
|
lines.append(f"{marker}{n['display_name']} sessions={n['sessions']} {n['status']}")
|
||||||
|
lines.append("\nUse \"/node <name>\" to switch active node.")
|
||||||
|
await send_text(chat_id, "chat_id", "\n".join(lines))
|
||||||
|
else:
|
||||||
|
await send_text(chat_id, "chat_id", "No nodes connected.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
reply = await forward(node_id, user_id, chat_id, text)
|
||||||
|
if reply:
|
||||||
|
await send_text(chat_id, "chat_id", reply)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to forward to node %s", node_id)
|
||||||
|
await send_text(chat_id, "chat_id", f"Error communicating with node: {e}")
|
||||||
|
else:
|
||||||
reply = await agent.run(user_id, text)
|
reply = await agent.run(user_id, text)
|
||||||
if reply:
|
if reply:
|
||||||
await send_text(chat_id, "chat_id", reply)
|
await send_text(chat_id, "chat_id", reply)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error processing message for user %s", user_id)
|
logger.exception("Error processing message for user %s", user_id)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
_CONFIG_PATH = Path(__file__).parent / "keyring.yaml"
|
_CONFIG_PATH = Path(__file__).parent / "keyring.yaml"
|
||||||
|
|
||||||
@ -20,6 +20,9 @@ OPENAI_MODEL: str = _cfg.get("OPENAI_MODEL", "glm-4.7")
|
|||||||
WORKING_DIR: Path = Path(_cfg.get("WORKING_DIR", Path.home())).expanduser().resolve()
|
WORKING_DIR: Path = Path(_cfg.get("WORKING_DIR", Path.home())).expanduser().resolve()
|
||||||
METASO_API_KEY: str = _cfg.get("METASO_API_KEY", "")
|
METASO_API_KEY: str = _cfg.get("METASO_API_KEY", "")
|
||||||
|
|
||||||
|
ROUTER_MODE: bool = _cfg.get("ROUTER_MODE", False)
|
||||||
|
ROUTER_SECRET: str = _cfg.get("ROUTER_SECRET", "")
|
||||||
|
|
||||||
ALLOWED_OPEN_IDS: List[str] = _cfg.get("ALLOWED_OPEN_IDS", [])
|
ALLOWED_OPEN_IDS: List[str] = _cfg.get("ALLOWED_OPEN_IDS", [])
|
||||||
if ALLOWED_OPEN_IDS and not isinstance(ALLOWED_OPEN_IDS, list):
|
if ALLOWED_OPEN_IDS and not isinstance(ALLOWED_OPEN_IDS, list):
|
||||||
ALLOWED_OPEN_IDS = [str(ALLOWED_OPEN_IDS)]
|
ALLOWED_OPEN_IDS = [str(ALLOWED_OPEN_IDS)]
|
||||||
|
|||||||
6
host_client/__init__.py
Normal file
6
host_client/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
"""Host client module - connects to router and runs local mailboy."""
|
||||||
|
|
||||||
|
from host_client.config import HostConfig, get_host_config
|
||||||
|
from host_client.main import NodeClient
|
||||||
|
|
||||||
|
__all__ = ["HostConfig", "get_host_config", "NodeClient"]
|
||||||
97
host_client/config.py
Normal file
97
host_client/config.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
"""Host client configuration loader.
|
||||||
|
|
||||||
|
Loads host_config.yaml which contains:
|
||||||
|
- NODE_ID, DISPLAY_NAME
|
||||||
|
- ROUTER_URL, ROUTER_SECRET
|
||||||
|
- LLM config (OPENAI_*)
|
||||||
|
- WORKING_DIR, METASO_API_KEY
|
||||||
|
- SERVES_USERS list
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
class HostConfig:
    """Configuration for a host client node.

    Two sources are supported:

    * ``host_config.yaml`` (the default constructor), which describes a
      dedicated host client;
    * ``keyring.yaml`` (``from_keyring``), used in standalone mode so a single
      file configures both the router side and the local node.

    Both populate the same attributes: ``node_id``, ``display_name``,
    ``router_url``, ``router_secret``, ``openai_base_url``, ``openai_api_key``,
    ``openai_model``, ``working_dir``, ``metaso_api_key``, ``serves_users`` and
    ``capabilities``.
    """

    _DEFAULT_ROUTER_URL = "ws://127.0.0.1:8000/ws/node"
    _DEFAULT_OPENAI_BASE_URL = "https://open.bigmodel.cn/api/paas/v4/"
    _DEFAULT_OPENAI_MODEL = "glm-4.7"
    _DEFAULT_CAPABILITIES = ("claude_code", "shell", "file_ops", "web", "scheduler")

    def __init__(self, config_path: Optional[Path] = None):
        """Load the configuration from ``config_path``.

        Args:
            config_path: YAML file to read. Defaults to ``host_config.yaml``
                in the parent of this module's directory.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file's top-level YAML value is not a mapping.
        """
        config_path = config_path or Path(__file__).parent.parent / "host_config.yaml"
        self._load(config_path)

    def _load(self, config_path: Path) -> None:
        """Read ``config_path`` and set every attribute from it."""
        data = self._read_yaml(config_path, f"Config file not found: {config_path}")
        self._populate(
            data,
            default_node_id="unknown-node",
            default_display_name=None,
            users_key="SERVES_USERS",
        )

        capabilities = data.get("CAPABILITIES")
        # A missing (or null) key means "all default capabilities"; an explicit
        # empty list is respected.
        if capabilities is None:
            self.capabilities: List[str] = list(self._DEFAULT_CAPABILITIES)
        else:
            self.capabilities = self._as_str_list(capabilities)

    @classmethod
    def from_keyring(cls, keyring_path: Optional[Path] = None) -> "HostConfig":
        """Create config from keyring.yaml (for standalone mode).

        The node serves the users listed under ``ALLOWED_OPEN_IDS`` and always
        advertises the default capability set.

        Args:
            keyring_path: YAML file to read. Defaults to ``keyring.yaml`` in
                the parent of this module's directory.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file's top-level YAML value is not a mapping.
        """
        keyring_path = keyring_path or Path(__file__).parent.parent / "keyring.yaml"
        data = cls._read_yaml(keyring_path, f"keyring.yaml not found: {keyring_path}")

        # Bypass __init__, which would try to load host_config.yaml.
        config = cls.__new__(cls)
        config._populate(
            data,
            default_node_id="local-node",
            default_display_name="Local Machine",
            users_key="ALLOWED_OPEN_IDS",
        )
        config.capabilities = list(cls._DEFAULT_CAPABILITIES)
        return config

    def _populate(
        self,
        data: dict,
        *,
        default_node_id: str,
        default_display_name: Optional[str],
        users_key: str,
    ) -> None:
        """Set the attributes shared by both configuration sources.

        Args:
            data: Parsed YAML mapping.
            default_node_id: Value for ``node_id`` when ``NODE_ID`` is absent.
            default_display_name: Value for ``display_name`` when
                ``DISPLAY_NAME`` is absent; ``None`` falls back to ``node_id``.
            users_key: YAML key holding the list of served user IDs.
        """
        self.node_id: str = data.get("NODE_ID", default_node_id)
        fallback_name = self.node_id if default_display_name is None else default_display_name
        self.display_name: str = data.get("DISPLAY_NAME", fallback_name)
        self.router_url: str = data.get("ROUTER_URL", self._DEFAULT_ROUTER_URL)
        self.router_secret: str = data.get("ROUTER_SECRET", "")

        self.openai_base_url: str = data.get("OPENAI_BASE_URL", self._DEFAULT_OPENAI_BASE_URL)
        self.openai_api_key: str = data.get("OPENAI_API_KEY", "")
        self.openai_model: str = data.get("OPENAI_MODEL", self._DEFAULT_OPENAI_MODEL)

        self.working_dir: str = data.get("WORKING_DIR", str(Path.home() / "projects"))
        self.metaso_api_key: Optional[str] = data.get("METASO_API_KEY")

        self.serves_users: List[str] = self._as_str_list(data.get(users_key, []))

    @staticmethod
    def _read_yaml(path: Path, missing_message: str) -> dict:
        """Parse ``path`` as YAML and return its top-level mapping."""
        if not path.exists():
            raise FileNotFoundError(missing_message)

        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}

        if not isinstance(data, dict):
            raise ValueError(f"Expected a YAML mapping at the top level of {path}")
        return data

    @staticmethod
    def _as_str_list(value: object) -> List[str]:
        """Normalize a YAML value to a list.

        A list is copied as-is, a non-empty scalar becomes a one-element list
        of its string form, and anything falsy becomes an empty list.
        """
        if isinstance(value, list):
            return list(value)
        if value:
            return [str(value)]
        return []
|
||||||
|
|
||||||
|
|
||||||
|
host_config: Optional[HostConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_host_config() -> HostConfig:
    """Return the process-wide HostConfig, creating it on first use.

    The instance is cached in the module-level ``host_config`` variable, so
    the default configuration file is read at most once.
    """
    global host_config
    if host_config is not None:
        return host_config
    host_config = HostConfig()
    return host_config
|
||||||
280
host_client/main.py
Normal file
280
host_client/main.py
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
"""Host client main module.
|
||||||
|
|
||||||
|
Connects to the router via WebSocket, receives forwarded messages,
|
||||||
|
runs the local mailboy LLM, and sends responses back.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
from websockets.client import WebSocketClientProtocol
|
||||||
|
|
||||||
|
from agent.manager import manager
|
||||||
|
from agent.scheduler import scheduler
|
||||||
|
from agent.task_runner import task_runner
|
||||||
|
from host_client.config import HostConfig, get_host_config
|
||||||
|
from orchestrator.agent import run as run_mailboy
|
||||||
|
from orchestrator.tools import set_current_user, set_current_chat
|
||||||
|
from shared import (
|
||||||
|
RegisterMessage,
|
||||||
|
ForwardRequest,
|
||||||
|
ForwardResponse,
|
||||||
|
TaskComplete,
|
||||||
|
Heartbeat,
|
||||||
|
NodeStatus,
|
||||||
|
encode,
|
||||||
|
decode,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NodeClient:
    """WebSocket client that connects to the router and handles messages.

    Lifecycle: ``run()`` starts the local session manager and scheduler, then
    repeatedly connects, registers, and serves the connection until it drops,
    backing off exponentially between attempts. ``stop()`` ends the loop.
    """

    # Seconds between application-level pings and status reports.
    HEARTBEAT_INTERVAL = 30
    STATUS_INTERVAL = 60
    # Upper bound, in seconds, for the exponential reconnect back-off.
    MAX_RECONNECT_DELAY = 60

    def __init__(self, config: HostConfig):
        self.config = config
        self.ws: Optional[WebSocketClientProtocol] = None
        self._running = False
        self._last_heartbeat = time.time()
        self._reconnect_delay = 1.0

    async def connect(self) -> bool:
        """Connect to the router WebSocket.

        Returns:
            True on success (and resets the reconnect back-off), False if the
            connection attempt raised.
        """
        headers = {}
        if self.config.router_secret:
            headers["Authorization"] = f"Bearer {self.config.router_secret}"

        try:
            # NOTE(review): `extra_headers` (and `.open` used below) belong to
            # the legacy websockets client API — confirm the pinned version.
            self.ws = await websockets.connect(
                self.config.router_url,
                extra_headers=headers,
                ping_interval=30,
                ping_timeout=10,
            )
            logger.info("Connected to router: %s", self.config.router_url)
            self._reconnect_delay = 1.0
            return True
        except Exception as e:
            logger.error("Failed to connect to router: %s", e)
            return False

    async def register(self) -> bool:
        """Send registration message to the router.

        Returns:
            True if the message was sent, False if there is no connection or
            the send failed.
        """
        if not self.ws:
            return False

        msg = RegisterMessage(
            node_id=self.config.node_id,
            display_name=self.config.display_name,
            serves_users=self.config.serves_users,
            working_dir=self.config.working_dir,
            capabilities=self.config.capabilities,
        )

        if await self._send(msg, "registration"):
            logger.info("Sent registration for node: %s", self.config.node_id)
            return True
        return False

    async def handle_forward(self, request: ForwardRequest) -> None:
        """Run the local mailboy on a forwarded message and send the reply back."""
        logger.info("Received forward request %s from user %s", request.id, request.user_id)

        # Tools running inside the agent read the current user/chat from these.
        set_current_user(request.user_id)
        set_current_chat(request.chat_id)

        try:
            reply = await run_mailboy(request.user_id, request.text)
            response = ForwardResponse(id=request.id, reply=reply, error="")
        except Exception as e:
            logger.exception("Error processing forward request %s", request.id)
            response = ForwardResponse(id=request.id, reply="", error=str(e))

        await self._send(response, "response")

    async def send_heartbeat(self) -> None:
        """Send a ping heartbeat to the router."""
        if await self._send(Heartbeat(type="ping"), "heartbeat"):
            self._last_heartbeat = time.time()

    async def send_status(self) -> None:
        """Send node status update to the router."""
        if not self.ws:
            return

        sessions = manager.list_sessions()
        active_sessions = [
            {"conv_id": s["conv_id"], "working_dir": s["working_dir"]}
            for s in sessions
        ]

        status = NodeStatus(
            node_id=self.config.node_id,
            sessions=len(sessions),
            active_sessions=active_sessions,
        )
        await self._send(status, "status")

    async def handle_message(self, data: str) -> None:
        """Decode one incoming frame from the router and dispatch it."""
        try:
            msg = decode(data)
        except Exception as e:
            logger.error("Failed to decode message: %s", e)
            return

        if isinstance(msg, ForwardRequest):
            await self.handle_forward(msg)
        elif isinstance(msg, Heartbeat):
            if msg.type == "ping":
                await self._send(Heartbeat(type="pong"), "pong")
            elif msg.type == "pong":
                self._last_heartbeat = time.time()
        else:
            logger.debug(
                "Received message type: %s",
                getattr(msg, "type", type(msg).__name__),
            )

    async def receive_loop(self) -> None:
        """Main receive loop; returns when the connection closes or errors."""
        if not self.ws:
            return

        try:
            async for data in self.ws:
                await self.handle_message(data)
        except websockets.ConnectionClosed as e:
            logger.warning("Connection closed: %s", e)
        except Exception as e:
            logger.exception("Error in receive loop: %s", e)

    async def heartbeat_loop(self) -> None:
        """Periodic heartbeat loop."""
        while self._running:
            await asyncio.sleep(self.HEARTBEAT_INTERVAL)
            if self.ws and self.ws.open:
                await self.send_heartbeat()

    async def status_loop(self) -> None:
        """Periodic status update loop."""
        while self._running:
            await asyncio.sleep(self.STATUS_INTERVAL)
            if self.ws and self.ws.open:
                await self.send_status()

    async def run(self) -> None:
        """Main run loop with reconnection.

        Starts the local manager and scheduler, routes background-task
        notifications through this client, then keeps a router connection
        alive until ``stop()`` is called.
        """
        self._running = True

        await manager.start()
        await scheduler.start()

        task_runner.set_notification_handler(self._send_task_complete)

        while self._running:
            if await self.connect() and await self.register():
                await self._serve_connection()
            # Release the socket before the next attempt replaces self.ws.
            await self._close_ws()

            if self._running:
                logger.info("Reconnecting in %.1f seconds...", self._reconnect_delay)
                await asyncio.sleep(self._reconnect_delay)
                self._reconnect_delay = min(
                    self._reconnect_delay * 2, self.MAX_RECONNECT_DELAY
                )

    async def _serve_connection(self) -> None:
        """Run the receive/heartbeat/status loops until the first one ends.

        The periodic loops only exit when the client is stopped, so waiting
        for all three would never return after a disconnect. As soon as one
        finishes (normally the receive loop, on connection close) the others
        are cancelled so ``run()`` can reconnect.
        """
        workers = [
            asyncio.create_task(self.receive_loop()),
            asyncio.create_task(self.heartbeat_loop()),
            asyncio.create_task(self.status_loop()),
        ]
        try:
            await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for worker in workers:
                if not worker.done():
                    worker.cancel()
            outcomes = await asyncio.gather(*workers, return_exceptions=True)
            for outcome in outcomes:
                # CancelledError is not an Exception subclass, so the
                # cancellations issued above are not reported here.
                if isinstance(outcome, Exception):
                    logger.error("Connection worker failed: %r", outcome)

    async def _send(self, payload, what: str) -> bool:
        """Encode and send ``payload``; log and return False on failure.

        Args:
            payload: A protocol message accepted by ``encode``.
            what: Short label used in the failure log line.
        """
        if not self.ws:
            return False
        try:
            await self.ws.send(encode(payload))
            return True
        except Exception as e:
            logger.error("Failed to send %s: %s", what, e)
            return False

    async def _close_ws(self) -> None:
        """Close the current socket, if any, ignoring close-time errors."""
        if not self.ws:
            return
        try:
            await self.ws.close()
        except Exception as e:
            logger.debug("Error while closing connection: %s", e)

    async def _send_task_complete(self, task) -> None:
        """Send TaskComplete notification to router.

        Installed as the task runner's notification handler by ``run()``.
        """
        if not self.ws:
            return

        msg = TaskComplete(
            task_id=task.task_id,
            user_id=task.user_id or "",
            chat_id=task.notify_chat_id or "",
            result=task.result or task.error or "",
        )

        if await self._send(msg, "TaskComplete"):
            logger.info("Sent TaskComplete for task %s", task.task_id)

    async def stop(self) -> None:
        """Stop the client, close the socket, and shut down local services."""
        self._running = False
        await self._close_ws()
        await manager.stop()
        await scheduler.stop()
        logger.info("Node client stopped")

    @classmethod
    def from_keyring(cls, router_url: Optional[str] = None, secret: Optional[str] = None) -> "NodeClient":
        """Create a client from keyring.yaml (for standalone mode).

        Args:
            router_url: Overrides the router URL from the keyring when given.
            secret: Overrides the router secret from the keyring when given.
        """
        config = HostConfig.from_keyring()
        if router_url:
            config.router_url = router_url
        if secret:
            config.router_secret = secret
        return cls(config)
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
    """Entry point for standalone host client.

    Configures logging, builds a client from the default host configuration,
    and runs it until the process is interrupted.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    client = NodeClient(get_host_config())

    try:
        await client.run()
    finally:
        # Under asyncio.run(), Ctrl+C normally reaches this coroutine as a
        # task cancellation rather than KeyboardInterrupt, so cleanup must
        # not be tied to one exception type.
        await client.stop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
15
router/__init__.py
Normal file
15
router/__init__.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
"""Router module - public-facing component of PhoneWork."""
|
||||||
|
|
||||||
|
from router.nodes import NodeRegistry, NodeConnection, get_node_registry
|
||||||
|
from router.main import create_app
|
||||||
|
from router.rpc import forward
|
||||||
|
from router.routing_agent import route
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"NodeRegistry",
|
||||||
|
"NodeConnection",
|
||||||
|
"get_node_registry",
|
||||||
|
"create_app",
|
||||||
|
"forward",
|
||||||
|
"route",
|
||||||
|
]
|
||||||
75
router/main.py
Normal file
75
router/main.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
"""Router main module - FastAPI app factory.
|
||||||
|
|
||||||
|
Creates the FastAPI application with:
|
||||||
|
- Feishu WebSocket client
|
||||||
|
- Node WebSocket endpoint
|
||||||
|
- Health check endpoints
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import FastAPI, WebSocket
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from bot.handler import start_websocket_client
|
||||||
|
from router.nodes import NodeRegistry, get_node_registry
|
||||||
|
from router.ws import ws_node_endpoint
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_app(router_secret: Optional[str] = None) -> FastAPI:
    """Create the FastAPI application.

    Registers the health and node-listing endpoints, the node WebSocket
    endpoint, and a startup hook that launches the Feishu WebSocket client.

    Args:
        router_secret: Secret for authenticating host client connections.
            When falsy, the registry's existing secret is left untouched.

    Returns:
        The configured FastAPI application.
    """
    app = FastAPI(title="PhoneWork Router", version="3.0.0")

    # NOTE(review): wildcard origins combined with allow_credentials=True lets
    # any site make credentialed requests to this API — confirm this is
    # intended before exposing the router publicly.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    registry = get_node_registry()
    if router_secret:
        registry._secret = router_secret
    if not registry._secret:
        logger.warning(
            "No router secret configured; node connections will not be authenticated"
        )

    @app.get("/health")
    async def health():
        nodes = registry.list_nodes()
        online_count = sum(1 for n in nodes if n["status"] == "online")
        return {
            "status": "ok",
            "nodes": nodes,
            "online_nodes": online_count,
            "total_nodes": len(nodes),
            # Placeholder: the pending-request count is not tracked here.
            "pending_requests": 0,
        }

    @app.get("/nodes")
    async def list_nodes():
        return registry.list_nodes()

    @app.websocket("/ws/node")
    async def ws_node(websocket: WebSocket):
        await ws_node_endpoint(websocket)

    @app.on_event("startup")
    async def startup():
        import asyncio

        # Inside a coroutine the running loop is the one the Feishu client
        # must attach to; get_running_loop() states that explicitly.
        start_websocket_client(asyncio.get_running_loop())
        logger.info("Router started")

    @app.on_event("shutdown")
    async def shutdown():
        logger.info("Router shut down")

    return app
|
||||||
209
router/nodes.py
Normal file
209
router/nodes.py
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
"""Node registry for managing connected host clients.
|
||||||
|
|
||||||
|
Maintains:
|
||||||
|
- Connected nodes with their WebSocket connections
|
||||||
|
- User-to-node mapping (which users each node serves)
|
||||||
|
- Active node preference per user
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
from shared import RegisterMessage, NodeStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class NodeConnection:
    """State the router keeps for one connected host client."""

    node_id: str
    # WebSocket object for this node; typed loosely as Any in this module.
    ws: Any
    display_name: str = ""
    serves_users: Set[str] = field(default_factory=set)
    working_dir: str = ""
    capabilities: List[str] = field(default_factory=list)
    # Both timestamps default to the wall-clock time of construction.
    connected_at: float = field(default_factory=time.time)
    last_heartbeat: float = field(default_factory=time.time)
    sessions: int = 0
    active_sessions: List[Dict[str, Any]] = field(default_factory=list)

    @property
    def is_online(self) -> bool:
        """True while the last heartbeat is less than 60 seconds old."""
        silence = time.time() - self.last_heartbeat
        return silence < 60

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly summary for API responses.

        The WebSocket, the user IDs themselves (only their count is
        reported), the working directory, and per-session details are
        omitted.
        """
        status = "online" if self.is_online else "offline"
        return dict(
            node_id=self.node_id,
            display_name=self.display_name,
            status=status,
            users=len(self.serves_users),
            sessions=self.sessions,
            capabilities=self.capabilities,
            connected_at=self.connected_at,
            last_heartbeat=self.last_heartbeat,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class NodeRegistry:
    """Registry of connected host clients.

    Tracks the live ``NodeConnection`` per node id, which node ids serve each
    user, and each user's explicitly chosen active node. The mutating
    coroutines serialize on an ``asyncio.Lock``; the synchronous getters read
    the dictionaries directly.
    """

    def __init__(self, router_secret: str = ""):
        """Create an empty registry.

        Args:
            router_secret: Secret that ``validate_secret`` compares against.
                An empty value disables the check.
        """
        self._nodes: Dict[str, NodeConnection] = {}
        self._user_nodes: Dict[str, Set[str]] = {}
        self._active_node: Dict[str, str] = {}
        self._secret = router_secret
        self._lock = asyncio.Lock()
        # The event loop only holds weak references to tasks, so notification
        # tasks are pinned here until they finish.
        self._background_tasks: Set[Any] = set()

    def _spawn(self, coro: Any) -> None:
        """Schedule ``coro`` as a background task and keep it referenced."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _unlink_users(self, node: NodeConnection) -> None:
        """Remove ``node`` from the user-to-node index (caller holds the lock)."""
        for user_id in node.serves_users:
            served_by = self._user_nodes.get(user_id)
            if served_by is None:
                continue
            served_by.discard(node.node_id)
            if not served_by:
                del self._user_nodes[user_id]

    def validate_secret(self, secret: str) -> bool:
        """Validate router secret.

        Returns:
            True when no secret is configured or when ``secret`` matches it.
        """
        import hmac

        if not self._secret:
            return True
        # Constant-time comparison; bytes are used because compare_digest
        # rejects non-ASCII str arguments.
        return hmac.compare_digest(
            (secret or "").encode("utf-8"), self._secret.encode("utf-8")
        )

    async def register(self, ws: Any, msg: RegisterMessage) -> NodeConnection:
        """Register a new node connection.

        If ``msg.node_id`` is already registered the old entry is replaced,
        its user mappings are dropped first, and every served user is sent a
        reconnect notification.

        Args:
            ws: WebSocket the node is connected on.
            msg: Registration message sent by the node.

        Returns:
            The newly stored ``NodeConnection``.
        """
        async with self._lock:
            previous = self._nodes.get(msg.node_id)
            is_reconnect = previous is not None
            if previous is not None:
                # The node may re-register with a different user list; without
                # this, users it no longer serves would stay mapped to it.
                self._unlink_users(previous)

            node = NodeConnection(
                node_id=msg.node_id,
                ws=ws,
                display_name=msg.display_name or msg.node_id,
                serves_users=set(msg.serves_users),
                working_dir=msg.working_dir,
                capabilities=list(msg.capabilities),
            )
            self._nodes[msg.node_id] = node

            for user_id in node.serves_users:
                self._user_nodes.setdefault(user_id, set()).add(msg.node_id)

            logger.info(
                "Node registered: %s (users: %s, capabilities: %s)",
                msg.node_id,
                msg.serves_users,
                msg.capabilities,
            )

        if is_reconnect:
            for user_id in node.serves_users:
                self._spawn(self._notify_reconnect(user_id, node.display_name))

        return node

    async def _notify_reconnect(self, user_id: str, node_name: str) -> None:
        """Notify user about node reconnect; failures are logged, not raised."""
        try:
            from bot.feishu import send_text
            await send_text(user_id, "open_id", f"✅ Node \"{node_name}\" reconnected.")
        except Exception as e:
            logger.error("Failed to send reconnect notification: %s", e)

    async def unregister(self, node_id: str) -> None:
        """Unregister a node connection.

        Removes the node, its user mappings and any active-node preference
        pointing at it, then notifies each served user. Unknown ids are
        ignored.
        """
        async with self._lock:
            node = self._nodes.pop(node_id, None)
            if node is None:
                return

            self._unlink_users(node)

            stale_prefs = [
                user_id
                for user_id, active_id in self._active_node.items()
                if active_id == node_id
            ]
            for user_id in stale_prefs:
                del self._active_node[user_id]

            logger.info("Node unregistered: %s", node_id)

        for user_id in list(node.serves_users):
            self._spawn(self._notify_disconnect(user_id, node.display_name))

    async def _notify_disconnect(self, user_id: str, node_name: str) -> None:
        """Notify user about node disconnect; failures are logged, not raised."""
        try:
            from bot.feishu import send_text
            await send_text(user_id, "open_id", f"⚠️ Node \"{node_name}\" disconnected.")
        except Exception as e:
            logger.error("Failed to send disconnect notification: %s", e)

    async def update_status(self, msg: NodeStatus) -> None:
        """Update node status from heartbeat.

        Copies the session count and session list from ``msg`` and refreshes
        the heartbeat timestamp. Unknown node ids are ignored.
        """
        async with self._lock:
            node = self._nodes.get(msg.node_id)
            if node:
                node.sessions = msg.sessions
                node.active_sessions = msg.active_sessions
                node.last_heartbeat = time.time()

    async def update_heartbeat(self, node_id: str) -> None:
        """Update node heartbeat timestamp; unknown node ids are ignored."""
        async with self._lock:
            node = self._nodes.get(node_id)
            if node:
                node.last_heartbeat = time.time()

    def get_node(self, node_id: str) -> Optional[NodeConnection]:
        """Get a node by ID, or None if it is not registered."""
        return self._nodes.get(node_id)

    def get_nodes_for_user(self, user_id: str) -> List[NodeConnection]:
        """Get all registered nodes that serve a user (online or not)."""
        node_ids = self._user_nodes.get(user_id, set())
        return [self._nodes[nid] for nid in node_ids if nid in self._nodes]

    def get_active_node(self, user_id: str) -> Optional[NodeConnection]:
        """Get the active node for a user.

        Returns the user's explicitly selected node when it is still
        registered; otherwise the first online node serving the user, or None.
        """
        preferred_id = self._active_node.get(user_id)
        if preferred_id:
            preferred = self._nodes.get(preferred_id)
            if preferred is not None:
                return preferred
            # Preference points at a node that is gone: fall back below
            # instead of reporting that the user has no node at all.

        return next(
            (n for n in self.get_nodes_for_user(user_id) if n.is_online), None
        )

    def set_active_node(self, user_id: str, node_id: str) -> bool:
        """Set the active node for a user.

        Returns:
            False when ``node_id`` is not registered, True otherwise.
        """
        if node_id not in self._nodes:
            return False
        self._active_node[user_id] = node_id
        logger.info("Active node for user %s set to %s", user_id, node_id)
        return True

    def list_nodes(self) -> List[Dict[str, Any]]:
        """List all nodes with their status, one ``to_dict()`` entry per node."""
        return [node.to_dict() for node in self._nodes.values()]

    def get_affected_users(self, node_id: str) -> List[str]:
        """Get users affected by a node disconnect (empty for unknown ids)."""
        node = self._nodes.get(node_id)
        if node:
            return list(node.serves_users)
        return []
|
||||||
|
|
||||||
|
|
||||||
|
# Process-wide registry instance; populated on first get_node_registry() call
# unless something else assigns it earlier.
node_registry: Optional[NodeRegistry] = None


def get_node_registry() -> NodeRegistry:
    """Return the global node registry, creating it on first use.

    The lazily created instance is built with ``NodeRegistry()`` defaults.
    """
    global node_registry
    registry = node_registry
    if registry is None:
        registry = node_registry = NodeRegistry()
    return registry
|
||||||
128
router/routing_agent.py
Normal file
128
router/routing_agent.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
"""Routing LLM for deciding which node to forward messages to.
|
||||||
|
|
||||||
|
This is a lightweight, one-shot LLM call that decides routing.
|
||||||
|
No history, no multi-step loop. Single call with one tool.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from config import OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL
|
||||||
|
from router.nodes import NodeConnection, get_node_registry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# System prompt for the one-shot routing call. It is a str.format template:
# {nodes_info} is substituted, and the doubled braces render as a literal
# JSON example.
ROUTING_SYSTEM_PROMPT = (
    "You are a routing assistant. A user has sent a message. "
    "Choose which node to forward it to.\n"
    "\n"
    "Connected nodes for this user:\n"
    "{nodes_info}\n"
    "\n"
    "Rules:\n"
    "- If the message references an active session on a node, route to that node.\n"
    '- If the user names a machine explicitly ("on work-server", "@home-pc"), route there.\n'
    "- If only one node is connected, route there without asking.\n"
    "- If ambiguous with multiple idle nodes, ask the user to clarify.\n"
    '- For meta commands (/nodes, /help, /status), respond with "meta" as the node_id.\n'
    "\n"
    "Respond with a JSON object:\n"
    '{{"node_id": "<node_id>", "reason": "<brief reason>"}}\n'
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_nodes_info(nodes: List[NodeConnection], active_node_id: Optional[str] = None) -> str:
|
||||||
|
"""Format node information for the routing prompt."""
|
||||||
|
lines = []
|
||||||
|
for node in nodes:
|
||||||
|
marker = " [ACTIVE]" if node.node_id == active_node_id else ""
|
||||||
|
sessions = ", ".join(
|
||||||
|
s.get("working_dir", "unknown") for s in node.active_sessions[:3]
|
||||||
|
) or "none"
|
||||||
|
lines.append(
|
||||||
|
f"- {node.display_name or node.node_id}{marker}: "
|
||||||
|
f"sessions=[{sessions}], capabilities={node.capabilities}"
|
||||||
|
)
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_routing_reply(content: object) -> dict:
    """Extract the JSON object from the routing model's reply.

    Accepts bare JSON as well as JSON wrapped in a Markdown code fence or
    surrounded by other text, by parsing from the first "{" to the last "}".

    Raises:
        ValueError: If the reply contains no parseable JSON object.
    """
    text = content if isinstance(content, str) else str(content)
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        raise ValueError(f"No JSON object in routing reply: {text!r}")
    # json.JSONDecodeError is a ValueError subclass, so callers see one type.
    return json.loads(text[start:end + 1])


async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], str]:
    """Determine which node to route a message to.

    Short-circuits without calling the LLM when the user has no online node,
    exactly one online node, or the text starts with "/". Otherwise asks the
    LLM once and falls back to the active or first online node if the call
    fails or names a node that is not online.

    Args:
        user_id: User's Feishu open_id
        chat_id: Chat ID for context
        text: User's message text

    Returns:
        Tuple of (node_id, reason). node_id is None if no suitable node found,
        or the string "meta" for meta commands.
    """
    registry = get_node_registry()
    nodes = registry.get_nodes_for_user(user_id)

    if not nodes:
        return None, "No nodes available for this user"

    online_nodes = [n for n in nodes if n.is_online]
    if not online_nodes:
        return None, "All nodes for this user are offline"

    if len(online_nodes) == 1:
        return online_nodes[0].node_id, "Only one node available"

    if text.strip().startswith("/"):
        return "meta", "Meta command"

    active_node = registry.get_active_node(user_id)
    # An explicitly selected node is returned even when its heartbeat is
    # stale; never use an offline node as the hint or as the fallback target.
    if active_node is not None and not active_node.is_online:
        active_node = None
    active_node_id = active_node.node_id if active_node else None

    nodes_info = _format_nodes_info(online_nodes, active_node_id)

    try:
        llm = ChatOpenAI(
            model=OPENAI_MODEL,
            openai_api_key=OPENAI_API_KEY,
            openai_api_base=OPENAI_BASE_URL,
            temperature=0,
        )

        prompt = ROUTING_SYSTEM_PROMPT.format(nodes_info=nodes_info)
        messages = [
            SystemMessage(content=prompt),
            HumanMessage(content=text),
        ]

        response = await llm.ainvoke(messages)
        result = _parse_routing_reply(response.content)
        node_id = result.get("node_id")
        reason = str(result.get("reason") or "")

    except Exception as e:
        # Best-effort boundary: routing must still pick a node when the LLM
        # call or reply parsing fails.
        logger.warning("Routing LLM failed: %s, falling back to active node", e)

        if active_node:
            return active_node.node_id, "Fallback to active node"
        return online_nodes[0].node_id, "Fallback to first available node"

    if node_id == "meta":
        return "meta", reason

    # Accept either the node id or its display name from the model.
    for node in online_nodes:
        if node.node_id == node_id or node.display_name == node_id:
            return node.node_id, reason

    if active_node:
        return active_node.node_id, f"Defaulting to active node (LLM suggested unavailable: {node_id})"

    return online_nodes[0].node_id, f"Defaulting to first available node (LLM suggested: {node_id})"
|
||||||
109
router/rpc.py
Normal file
109
router/rpc.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
"""RPC module for forwarding requests to host clients.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Request correlation with asyncio.Future
|
||||||
|
- Timeout management
|
||||||
|
- Response routing
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from shared import ForwardRequest, ForwardResponse, TaskComplete, encode
|
||||||
|
from router.nodes import get_node_registry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Futures awaiting a ForwardResponse, keyed by request id. Entries are added
# by forward() and always removed in its ``finally`` clause.
_pending_requests: Dict[str, asyncio.Future] = {}
# Default wait, in seconds, for a node's reply.
_default_timeout = 600.0


async def forward(
    node_id: str,
    user_id: str,
    chat_id: str,
    text: str,
    timeout: float = _default_timeout,
) -> str:
    """Forward a message to a host client and wait for response.

    Args:
        node_id: Target node ID
        user_id: User's Feishu open_id
        chat_id: Chat ID for context
        text: Message text to forward
        timeout: Timeout in seconds (default 600s for long CC tasks)

    Returns:
        Reply text from the host client

    Raises:
        asyncio.TimeoutError: If no response within timeout
        RuntimeError: If node is not connected, or the node replied with an
            error (set on the future by ``resolve_response``)
    """
    registry = get_node_registry()
    node = registry.get_node(node_id)

    if not node or not node.ws:
        raise RuntimeError(f"Node not connected: {node_id}")

    request_id = str(uuid.uuid4())
    request = ForwardRequest(
        id=request_id,
        user_id=user_id,
        chat_id=chat_id,
        text=text,
    )

    # get_running_loop() is the documented choice inside a coroutine.
    future: asyncio.Future = asyncio.get_running_loop().create_future()
    # Register before sending so a fast reply cannot arrive unmatched.
    _pending_requests[request_id] = future

    try:
        await node.ws.send_text(encode(request))
        logger.debug("Forwarded request %s to node %s", request_id, node_id)

        return await asyncio.wait_for(future, timeout=timeout)

    except asyncio.TimeoutError:
        logger.warning("Request %s timed out after %ss", request_id, timeout)
        raise

    except Exception as e:
        logger.error("Failed to forward request %s: %s", request_id, e)
        raise

    finally:
        _pending_requests.pop(request_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_response(response: ForwardResponse) -> None:
    """Resolve a pending request with a response.

    Completes the future registered under ``response.id``: with a
    ``RuntimeError`` when ``response.error`` is set, otherwise with
    ``response.reply``. A response with no waiting future (unknown id, or the
    request already finished) is logged and dropped.
    """
    future = _pending_requests.get(response.id)
    if future is None or future.done():
        # Surface late or stray replies instead of dropping them silently.
        logger.warning(
            "Dropping response for unknown or finished request %s", response.id
        )
        return

    if response.error:
        future.set_exception(RuntimeError(response.error))
    else:
        future.set_result(response.reply)
    logger.debug("Resolved request %s", response.id)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_task_complete(msg: TaskComplete) -> None:
    """Relay a finished background task's result to its chat.

    Logs the completion, then sends ``msg.result`` to ``msg.chat_id`` through
    ``bot.feishu.send_text``. A failed send is logged, not raised.
    """
    logger.info("Task %s completed for user %s", msg.task_id, msg.user_id)

    from bot.feishu import send_text

    try:
        await send_text(msg.chat_id, "chat_id", msg.result)
    except Exception as exc:
        logger.error("Failed to send task completion notification: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pending_count() -> int:
    """Return how many forwarded requests are still awaiting a reply."""
    in_flight = _pending_requests
    return len(in_flight)
|
||||||
102
router/ws.py
Normal file
102
router/ws.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
"""WebSocket endpoint for host client connections.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Connection authentication
|
||||||
|
- Node registration
|
||||||
|
- Message forwarding
|
||||||
|
- Heartbeat
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import WebSocket, WebSocketDisconnect, WebSocketException
|
||||||
|
|
||||||
|
from router.nodes import get_node_registry
|
||||||
|
from router.rpc import handle_task_complete
|
||||||
|
from shared import (
|
||||||
|
RegisterMessage,
|
||||||
|
ForwardRequest,
|
||||||
|
ForwardResponse,
|
||||||
|
TaskComplete,
|
||||||
|
Heartbeat,
|
||||||
|
NodeStatus,
|
||||||
|
decode,
|
||||||
|
encode,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def ws_node_endpoint(websocket: WebSocket) -> None:
    """WebSocket endpoint for host client connections.

    Accepts the socket, checks the ``Authorization`` header against the
    registry secret (closing with code 4001 on mismatch), then dispatches
    each decoded message until the socket closes. On exit the heartbeat task
    is cancelled and the node is unregistered, provided this socket is still
    the one registered for it.
    """
    from router.rpc import resolve_response

    await websocket.accept()

    registry = get_node_registry()

    secret = websocket.headers.get("authorization", "")
    if secret.startswith("Bearer "):
        secret = secret[7:]

    if not registry.validate_secret(secret):
        logger.warning("Invalid router secret, rejecting connection")
        await websocket.close(code=4001, reason="Invalid secret")
        return

    node_id: Optional[str] = None
    heartbeat_task: Optional[asyncio.Task] = None

    async def send_heartbeat() -> None:
        """Send periodic pings to the host client."""
        try:
            while True:
                await asyncio.sleep(30)
                try:
                    await websocket.send_text(encode(Heartbeat(type="ping")))
                except Exception:
                    # The socket is gone; the receive loop handles cleanup.
                    break
        except asyncio.CancelledError:
            pass

    async def release(registered_id: str) -> None:
        """Unregister ``registered_id`` only if this socket still owns it.

        A node that reconnects is re-registered under the same id on a new
        socket. When the stale socket finally closes, it must not remove the
        newer registration.
        """
        current = registry.get_node(registered_id)
        if current is not None and current.ws is websocket:
            await registry.unregister(registered_id)

    try:
        async for data in websocket.iter_text():
            try:
                msg = decode(data)
            except Exception as e:
                logger.error("Failed to decode message: %s", e)
                continue

            if isinstance(msg, RegisterMessage):
                if node_id is not None and node_id != msg.node_id:
                    # Same socket announcing a new identity: drop the old id
                    # so it does not stay registered forever.
                    await release(node_id)
                node_id = msg.node_id
                await registry.register(websocket, msg)
                # A repeated register frame must not stack heartbeat tasks.
                if heartbeat_task is None or heartbeat_task.done():
                    heartbeat_task = asyncio.create_task(send_heartbeat())

            elif isinstance(msg, ForwardResponse):
                await resolve_response(msg)

            elif isinstance(msg, TaskComplete):
                await handle_task_complete(msg)

            elif isinstance(msg, Heartbeat):
                if msg.type == "pong" and node_id:
                    await registry.update_heartbeat(node_id)

            elif isinstance(msg, NodeStatus):
                await registry.update_status(msg)

            else:
                logger.debug("Received unhandled message type: %s", type(msg).__name__)

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        if heartbeat_task:
            heartbeat_task.cancel()
        if node_id:
            await release(node_id)
|
||||||
23
shared/__init__.py
Normal file
23
shared/__init__.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
"""Shared module for Router <-> Host Client communication."""
|
||||||
|
|
||||||
|
from shared.protocol import (
|
||||||
|
RegisterMessage,
|
||||||
|
ForwardRequest,
|
||||||
|
ForwardResponse,
|
||||||
|
TaskComplete,
|
||||||
|
Heartbeat,
|
||||||
|
NodeStatus,
|
||||||
|
encode,
|
||||||
|
decode,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RegisterMessage",
|
||||||
|
"ForwardRequest",
|
||||||
|
"ForwardResponse",
|
||||||
|
"TaskComplete",
|
||||||
|
"Heartbeat",
|
||||||
|
"NodeStatus",
|
||||||
|
"encode",
|
||||||
|
"decode",
|
||||||
|
]
|
||||||
94
shared/protocol.py
Normal file
94
shared/protocol.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
"""Shared protocol module for Router <-> Host Client communication.
|
||||||
|
|
||||||
|
All message types are dataclasses that serialize to/from JSON.
|
||||||
|
Both router and host client import from this module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RegisterMessage:
    """Host client -> Router: announce this node and what it serves."""

    # Wire discriminator used by decode() to pick this class.
    type: str = "register"
    node_id: str = ""
    # User ids this node serves.
    serves_users: List[str] = field(default_factory=list)
    working_dir: str = ""
    capabilities: List[str] = field(default_factory=list)
    display_name: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ForwardRequest:
    """Router -> Host client: a user message to handle on the node."""

    # Wire discriminator used by decode() to pick this class.
    type: str = "forward"
    # Correlation id; the matching ForwardResponse carries the same value.
    id: str = ""
    user_id: str = ""
    chat_id: str = ""
    text: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ForwardResponse:
    """Host client -> Router: the node's answer to a ForwardRequest."""

    # Wire discriminator used by decode() to pick this class.
    type: str = "forward_response"
    # Correlation id copied from the request being answered.
    id: str = ""
    reply: str = ""
    # Error text; empty by default.
    error: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TaskComplete:
    """Host client -> Router: a background task on the node has finished."""

    # Wire discriminator used by decode() to pick this class.
    type: str = "task_complete"
    task_id: str = ""
    user_id: str = ""
    chat_id: str = ""
    result: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Heartbeat:
    """Keep-alive frame exchanged in both directions."""

    # Defaults to "ping"; the protocol also maps "pong" to this class.
    type: str = "ping"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class NodeStatus:
    """Host client -> Router: periodic report of the node's sessions."""

    # Wire discriminator used by decode() to pick this class.
    type: str = "node_status"
    node_id: str = ""
    sessions: int = 0
    # One dict per active session; the schema is defined by the host client.
    active_sessions: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# Wire "type" value -> dataclass that decode() instantiates for it.
# "ping" and "pong" share Heartbeat; the decoded object keeps the received
# type string.
MESSAGE_TYPES: Dict[str, type] = {
    "register": RegisterMessage,
    "forward": ForwardRequest,
    "forward_response": ForwardResponse,
    "task_complete": TaskComplete,
    "ping": Heartbeat,
    "pong": Heartbeat,
    "node_status": NodeStatus,
}
|
||||||
|
|
||||||
|
|
||||||
|
def encode(msg: Any) -> str:
    """Encode a message to JSON string.

    Args:
        msg: A protocol dataclass instance carrying a ``type`` field.

    Returns:
        The message's fields as a JSON object string (non-ASCII characters
        are kept as-is).

    Raises:
        ValueError: If ``msg`` is not a dataclass instance with a ``type``
            attribute.
    """
    from dataclasses import is_dataclass

    # asdict() only accepts dataclass *instances*; checking up front keeps the
    # failure a ValueError instead of a TypeError leaking from asdict().
    if is_dataclass(msg) and not isinstance(msg, type) and hasattr(msg, "type"):
        return json.dumps(asdict(msg), ensure_ascii=False)
    raise ValueError(f"Invalid message type: {type(msg)}")
|
||||||
|
|
||||||
|
|
||||||
|
def decode(data: str) -> Any:
    """Decode a JSON string to a message object.

    Keys that are not fields of the target dataclass are ignored.

    Args:
        data: JSON text of one protocol message.

    Returns:
        An instance of the dataclass registered in ``MESSAGE_TYPES`` for the
        message's ``type`` value.

    Raises:
        ValueError: If ``data`` is not valid JSON, is not a JSON object, or
            its ``type`` is missing or unknown.
    """
    obj = json.loads(data)  # json.JSONDecodeError is a ValueError subclass
    if not isinstance(obj, dict):
        raise ValueError(f"Message must be a JSON object, got {type(obj).__name__}")

    msg_type = obj.get("type")
    # Guard the lookup: a non-string type (e.g. a list) is unhashable and
    # would raise TypeError instead of the documented ValueError.
    cls = MESSAGE_TYPES.get(msg_type) if isinstance(msg_type, str) else None
    if cls is None:
        raise ValueError(f"Unknown message type: {msg_type}")

    return cls(**{k: v for k, v in obj.items() if k in cls.__dataclass_fields__})
|
||||||
72
standalone.py
Normal file
72
standalone.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
"""Run router + host client in a single process (localhost mode).
|
||||||
|
|
||||||
|
Equivalent to the pre-M3 single-machine setup.
|
||||||
|
Users run `python standalone.py` and get the exact same experience as `python main.py`,
|
||||||
|
but the code paths use the multi-host architecture internally.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_standalone() -> None:
    """Run router + host client in a single process.

    Generates a one-off shared secret, builds the router app with it, points
    the host client at the local router URL, and runs the uvicorn server and
    the client concurrently until both finish.
    """
    secret = secrets.token_hex(16)
    router_url = "ws://127.0.0.1:8000/ws/node"

    from router.main import create_app
    from host_client.main import NodeClient
    from host_client.config import HostConfig

    config = HostConfig.from_keyring()
    config.router_url = router_url
    config.router_secret = secret

    app = create_app(router_secret=secret)

    config_obj = uvicorn.Config(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info",
    )
    server = uvicorn.Server(config_obj)

    async def run_server() -> None:
        await server.serve()

    async def run_client() -> None:
        # Wait until uvicorn reports it has started instead of guessing with a
        # fixed delay; give up if the server is already shutting down.
        while not server.started:
            if server.should_exit:
                return
            await asyncio.sleep(0.1)

        client = NodeClient(config)
        try:
            await client.run()
        except Exception as e:
            logger.exception("Host client error: %s", e)
        finally:
            # Standalone mode has exactly one client; once it stops (cleanly,
            # with an error, or by cancellation) the router is stopped too.
            server.should_exit = True

    await asyncio.gather(run_server(), run_client())
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Entry point: run standalone mode until it ends or Ctrl-C is pressed."""
    logger.info("Starting PhoneWork in standalone mode...")
    standalone = run_standalone()
    try:
        asyncio.run(standalone)
    except KeyboardInterrupt:
        logger.info("Shutting down...")


if __name__ == "__main__":
    main()
|
||||||
Loading…
x
Reference in New Issue
Block a user