Compare commits

...

2 Commits

Author SHA1 Message Date
Yuyao Huang (Sam)
a3622ce26d refactor: 替换 asyncio.get_event_loop 为 get_running_loop 并优化会话卡片
- 将多处 asyncio.get_event_loop() 替换为更安全的 asyncio.get_running_loop()
- 重构 Feishu 卡片功能,新增 build_sessions_card 方法显示所有会话
- 优化文件路径处理逻辑,支持绝对路径和相对路径
- 在健康检查接口中添加 pending_requests 计数
- 更新会话状态命令以支持卡片显示
2026-03-28 14:59:33 +08:00
Yuyao Huang (Sam)
09b63341cd refactor: 统一使用现代类型注解替代传统类型注解
- 将 Dict、List 等传统类型注解替换为 dict、list 等现代类型注解
- 更新类型注解以更精确地反映变量类型
- 修复部分类型注解与实际使用不匹配的问题
- 优化部分代码逻辑以提高类型安全性
2026-03-28 14:27:21 +08:00
18 changed files with 142 additions and 143 deletions

View File

@ -46,7 +46,7 @@ class SessionManager:
"""Registry of active Claude Code project sessions with persistence and user isolation.""" """Registry of active Claude Code project sessions with persistence and user isolation."""
def __init__(self) -> None: def __init__(self) -> None:
self._sessions: Dict[str, Session] = {} self._sessions: dict[str, Session] = {}
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._reaper_task: Optional[asyncio.Task] = None self._reaper_task: Optional[asyncio.Task] = None

View File

@ -53,8 +53,8 @@ class Scheduler:
"""Singleton that manages scheduled jobs with Feishu notifications.""" """Singleton that manages scheduled jobs with Feishu notifications."""
def __init__(self) -> None: def __init__(self) -> None:
self._jobs: Dict[str, ScheduledJob] = {} self._jobs: dict[str, ScheduledJob] = {}
self._tasks: Dict[str, asyncio.Task] = {} self._tasks: dict[str, asyncio.Task] = {}
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._started = False self._started = False

View File

@ -8,7 +8,7 @@ import time
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, Optional from typing import Any, Awaitable, Callable, Dict, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,17 +43,17 @@ class TaskRunner:
"""Singleton that manages background tasks with Feishu notifications.""" """Singleton that manages background tasks with Feishu notifications."""
def __init__(self) -> None: def __init__(self) -> None:
self._tasks: Dict[str, BackgroundTask] = {} self._tasks: dict[str, BackgroundTask] = {}
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._notification_handler: Optional[Callable] = None self._notification_handler: Optional[Callable[[BackgroundTask], Awaitable[None]]] = None
def set_notification_handler(self, handler: Optional[Callable]) -> None: def set_notification_handler(self, handler: Optional[Callable[[BackgroundTask], Awaitable[None]]]) -> None:
"""Set custom notification handler for M3 mode (host client -> router).""" """Set custom notification handler for M3 mode (host client -> router)."""
self._notification_handler = handler self._notification_handler = handler
async def submit( async def submit(
self, self,
coro: Callable[[], Any], coro: Awaitable[Any],
description: str, description: str,
notify_chat_id: Optional[str] = None, notify_chat_id: Optional[str] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
@ -76,7 +76,7 @@ class TaskRunner:
logger.info("Submitted background task %s: %s", task_id, description) logger.info("Submitted background task %s: %s", task_id, description)
return task_id return task_id
async def _run_task(self, task_id: str, coro: Callable[[], Any]) -> None: async def _run_task(self, task_id: str, coro: Awaitable[Any]) -> None:
"""Execute a task and send notification on completion.""" """Execute a task and send notification on completion."""
async with self._lock: async with self._lock:
task = self._tasks.get(task_id) task = self._tasks.get(task_id)
@ -130,14 +130,15 @@ class TaskRunner:
msg += f"\n\n**Error:** {task.error}" msg += f"\n\n**Error:** {task.error}"
try: try:
await send_text(task.notify_chat_id, "chat_id", msg) if task.notify_chat_id:
await send_text(task.notify_chat_id, "chat_id", msg)
except Exception: except Exception:
logger.exception("Failed to send notification for task %s", task.task_id) logger.exception("Failed to send notification for task %s", task.task_id)
def get_task(self, task_id: str) -> Optional[BackgroundTask]: def get_task(self, task_id: str) -> Optional[BackgroundTask]:
return self._tasks.get(task_id) return self._tasks.get(task_id)
def list_tasks(self, limit: int = 20) -> list[dict]: def list_tasks(self, limit: int = 20) -> list[dict[str, Any]]:
tasks = sorted( tasks = sorted(
self._tasks.values(), self._tasks.values(),
key=lambda t: t.started_at, key=lambda t: t.started_at,

View File

@ -13,7 +13,7 @@ from agent.manager import manager
from agent.scheduler import scheduler from agent.scheduler import scheduler
from agent.task_runner import task_runner from agent.task_runner import task_runner
from orchestrator.agent import agent from orchestrator.agent import agent
from orchestrator.tools import set_current_user, get_current_user, get_current_chat from orchestrator.tools import set_current_user, get_current_chat
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -111,6 +111,18 @@ async def _cmd_new(user_id: str, args: str) -> str:
conv_id = data.get("conv_id", "") conv_id = data.get("conv_id", "")
agent._active_conv[user_id] = conv_id agent._active_conv[user_id] = conv_id
cwd = data.get("working_dir", working_dir) cwd = data.get("working_dir", working_dir)
chat_id = get_current_chat()
if chat_id:
from bot.feishu import send_card, send_text, build_sessions_card
sessions = manager.list_sessions(user_id=user_id)
mode = "Direct 🟢" if agent.get_passthrough(user_id) else "Smart ⚪"
card = build_sessions_card(sessions, conv_id, mode)
await send_card(chat_id, "chat_id", card)
if initial_msg and data.get("response"):
await send_text(chat_id, "chat_id", data["response"])
return ""
reply = f"✓ Created session `{conv_id}` in `{cwd}`" reply = f"✓ Created session `{conv_id}` in `{cwd}`"
if parsed.timeout: if parsed.timeout:
reply += f" (timeout: {parsed.timeout}s)" reply += f" (timeout: {parsed.timeout}s)"
@ -124,17 +136,24 @@ async def _cmd_new(user_id: str, args: str) -> str:
async def _cmd_status(user_id: str) -> str: async def _cmd_status(user_id: str) -> str:
"""Show status: sessions and current mode.""" """Show status: sessions and current mode."""
sessions = manager.list_sessions(user_id=user_id) sessions = manager.list_sessions(user_id=user_id)
if not sessions:
return "No active sessions."
active = agent.get_active_conv(user_id) active = agent.get_active_conv(user_id)
passthrough = agent.get_passthrough(user_id) passthrough = agent.get_passthrough(user_id)
mode = "Direct 🟢" if passthrough else "Smart ⚪"
chat_id = get_current_chat()
if chat_id:
from bot.feishu import send_card, build_sessions_card
card = build_sessions_card(sessions, active, mode)
await send_card(chat_id, "chat_id", card)
return ""
if not sessions:
return "No active sessions."
lines = ["**Your Sessions:**\n"] lines = ["**Your Sessions:**\n"]
for i, s in enumerate(sessions, 1): for i, s in enumerate(sessions, 1):
marker = "" if s["conv_id"] == active else " " marker = "" if s["conv_id"] == active else " "
lines.append(f"{marker}{i}. `{s['conv_id']}` - `{s['cwd']}`") lines.append(f"{marker}{i}. `{s['conv_id']}` - `{s['cwd']}`")
status = "Direct 🟢" if passthrough else "Smart ⚪" lines.append(f"\n**Mode:** {mode}")
lines.append(f"\n**Mode:** {status}")
lines.append("Use `/switch <n>` to activate a session.") lines.append("Use `/switch <n>` to activate a session.")
lines.append("Use `/direct` or `/smart` to change mode.") lines.append("Use `/direct` or `/smart` to change mode.")
return "\n".join(lines) return "\n".join(lines)

View File

@ -67,7 +67,7 @@ async def send_text(receive_id: str, receive_id_type: str, text: str) -> None:
text: message content. text: message content.
""" """
parts = _split_message(text) parts = _split_message(text)
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
for i, part in enumerate(parts): for i, part in enumerate(parts):
logger.debug( logger.debug(
@ -108,59 +108,16 @@ async def send_text(receive_id: str, receive_id_type: str, text: str) -> None:
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
async def send_card(receive_id: str, receive_id_type: str, title: str, content: str, buttons: list[dict] | None = None) -> None: async def send_card(receive_id: str, receive_id_type: str, card: dict) -> None:
""" """
Send an interactive card message. Send an interactive card message.
Args: Args:
receive_id: chat_id or open_id receive_id: chat_id or open_id depending on receive_id_type.
receive_id_type: "chat_id" | "open_id" | "user_id" | "union_id" receive_id_type: "chat_id" | "open_id" | "user_id" | "union_id".
title: Card title card: Card content dict (Feishu card JSON schema).
content: Card content (markdown supported)
buttons: List of button dicts with "text" and "value" keys
""" """
elements = [ loop = asyncio.get_running_loop()
{
"tag": "div",
"text": {
"tag": "lark_md",
"content": content,
},
},
]
if buttons:
actions = []
for btn in buttons:
actions.append({
"tag": "button",
"text": {"tag": "plain_text", "content": btn.get("text", "Button")},
"type": "primary",
"value": btn.get("value", {}),
})
elements.append({"tag": "action", "actions": actions})
card = {
"type": "template",
"data": {
"template_id": "AAqkz9****",
"template_variable": {
"title": title,
"elements": elements,
},
},
}
card_content = {
"config": {"wide_screen_mode": True},
"header": {
"title": {"tag": "plain_text", "content": title},
"template": "blue",
},
"elements": elements,
}
loop = asyncio.get_event_loop()
request = ( request = (
CreateMessageRequest.builder() CreateMessageRequest.builder()
.receive_id_type(receive_id_type) .receive_id_type(receive_id_type)
@ -168,7 +125,7 @@ async def send_card(receive_id: str, receive_id_type: str, title: str, content:
CreateMessageRequestBody.builder() CreateMessageRequestBody.builder()
.receive_id(receive_id) .receive_id(receive_id)
.msg_type("interactive") .msg_type("interactive")
.content(json.dumps(card_content, ensure_ascii=False)) .content(json.dumps(card, ensure_ascii=False))
.build() .build()
) )
.build() .build()
@ -185,6 +142,32 @@ async def send_card(receive_id: str, receive_id_type: str, title: str, content:
logger.debug("Sent card to %s (%s)", receive_id, receive_id_type) logger.debug("Sent card to %s (%s)", receive_id, receive_id_type)
def build_sessions_card(sessions: list[dict], active_conv_id: str | None, mode: str) -> dict:
"""Build a card showing all sessions with active marker and mode info."""
if sessions:
lines = []
for i, s in enumerate(sessions, 1):
marker = "" if s["conv_id"] == active_conv_id else " "
started = "🟢" if s["started"] else "🟡"
lines.append(f"{marker} {i}. {started} `{s['conv_id']}` — `{s['cwd']}`")
sessions_md = "\n".join(lines)
else:
sessions_md = "_No active sessions_"
content = f"{sessions_md}\n\n**Mode:** {mode}"
return {
"config": {"wide_screen_mode": True},
"header": {
"title": {"tag": "plain_text", "content": "Claude Code Sessions"},
"template": "turquoise",
},
"elements": [
{"tag": "div", "text": {"tag": "lark_md", "content": content}},
],
}
async def send_file(receive_id: str, receive_id_type: str, file_path: str, file_type: str = "stream") -> None: async def send_file(receive_id: str, receive_id_type: str, file_path: str, file_type: str = "stream") -> None:
""" """
Upload a local file to Feishu and send it as a file message. Upload a local file to Feishu and send it as a file message.
@ -198,7 +181,7 @@ async def send_file(receive_id: str, receive_id_type: str, file_path: str, file_
import os import os
path = os.path.abspath(file_path) path = os.path.abspath(file_path)
file_name = os.path.basename(path) file_name = os.path.basename(path)
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
# Step 1: Upload file → get file_key # Step 1: Upload file → get file_key
with open(path, "rb") as f: with open(path, "rb") as f:
@ -259,27 +242,3 @@ async def send_file(receive_id: str, receive_id_type: str, file_path: str, file_
) )
else: else:
logger.debug("Sent file %r to %s (%s)", file_name, receive_id, receive_id_type) logger.debug("Sent file %r to %s (%s)", file_name, receive_id, receive_id_type)
def build_session_card(conv_id: str, cwd: str, started: bool) -> dict:
"""Build a session status card."""
status = "🟢 Active" if started else "🟡 Ready"
content = f"**Session ID:** `{conv_id}`\n**Directory:** `{cwd}`\n**Status:** {status}"
return {
"config": {"wide_screen_mode": True},
"header": {
"title": {"tag": "plain_text", "content": "Claude Code Session"},
"template": "turquoise",
},
"elements": [
{"tag": "div", "text": {"tag": "lark_md", "content": content}},
{"tag": "hr"},
{
"tag": "action",
"actions": [
{"tag": "button", "text": {"tag": "plain_text", "content": "Continue"}, "type": "primary", "value": {"action": "continue", "conv_id": conv_id}},
{"tag": "button", "text": {"tag": "plain_text", "content": "Close"}, "type": "default", "value": {"action": "close", "conv_id": conv_id}},
],
},
],
}

View File

@ -7,6 +7,7 @@ import json
import logging import logging
import threading import threading
import time import time
from typing import Any, Dict
import lark_oapi as lark import lark_oapi as lark
from lark_oapi.api.im.v1 import P2ImMessageReceiveV1 from lark_oapi.api.im.v1 import P2ImMessageReceiveV1
@ -25,7 +26,7 @@ _last_message_time: float = 0.0
_reconnect_count: int = 0 _reconnect_count: int = 0
def get_ws_status() -> dict: def get_ws_status() -> dict[str, Any]:
"""Return WebSocket connection status.""" """Return WebSocket connection status."""
return { return {
"connected": _ws_connected, "connected": _ws_connected,
@ -39,8 +40,17 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None:
_last_message_time = time.time() _last_message_time = time.time()
try: try:
message = data.event.message event = data.event
sender = data.event.sender if event is None:
logger.warning("Received event with no data")
return
message = event.message
sender = event.sender
if message is None:
logger.warning("Received event with no message")
return
logger.debug( logger.debug(
"event type=%r chat_type=%r content=%r", "event type=%r chat_type=%r content=%r",
@ -53,7 +63,7 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None:
logger.info("Skipping non-text message_type=%r", message.message_type) logger.info("Skipping non-text message_type=%r", message.message_type)
return return
chat_id: str = message.chat_id chat_id: str = message.chat_id or ""
raw_content: str = message.content or "{}" raw_content: str = message.content or "{}"
content_obj = json.loads(raw_content) content_obj = json.loads(raw_content)
text: str = content_obj.get("text", "").strip() text: str = content_obj.get("text", "").strip()
@ -119,8 +129,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None:
if nodes: if nodes:
lines = ["Connected Nodes:"] lines = ["Connected Nodes:"]
for n in nodes: for n in nodes:
marker = "" if n.get("node_id") == registry.get_active_node(user_id) else " " active_node = registry.get_active_node(user_id)
lines.append(f"{marker}{n['display_name']} sessions={n['sessions']} {n['status']}") marker = "" if n.get("node_id") == (active_node.node_id if active_node else None) else " "
lines.append(f"{marker}{n.get('display_name', 'unknown')} sessions={n.get('sessions', 0)} {n.get('status', 'unknown')}")
lines.append("\nUse \"/node <name>\" to switch active node.") lines.append("\nUse \"/node <name>\" to switch active node.")
await send_text(chat_id, "chat_id", "\n".join(lines)) await send_text(chat_id, "chat_id", "\n".join(lines))
else: else:
@ -144,7 +155,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None:
def _handle_any(data: lark.CustomizedEvent) -> None: def _handle_any(data: lark.CustomizedEvent) -> None:
"""Catch-all handler to log any event the SDK receives.""" """Catch-all handler to log any event the SDK receives."""
logger.info("RAW CustomizedEvent: %s", lark.JSON.marshal(data)[:500]) marshaled = lark.JSON.marshal(data)
if marshaled:
logger.info("RAW CustomizedEvent: %s", marshaled[:500])
def build_event_handler() -> lark.EventDispatcherHandler: def build_event_handler() -> lark.EventDispatcherHandler:

View File

@ -1,11 +1,11 @@
import yaml import yaml
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import Any, Dict, List, Optional
_CONFIG_PATH = Path(__file__).parent / "keyring.yaml" _CONFIG_PATH = Path(__file__).parent / "keyring.yaml"
def _load() -> dict: def _load() -> dict[str, Any]:
with open(_CONFIG_PATH, "r", encoding="utf-8") as f: with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
return yaml.safe_load(f) or {} return yaml.safe_load(f) or {}
@ -23,9 +23,8 @@ METASO_API_KEY: str = _cfg.get("METASO_API_KEY", "")
ROUTER_MODE: bool = _cfg.get("ROUTER_MODE", False) ROUTER_MODE: bool = _cfg.get("ROUTER_MODE", False)
ROUTER_SECRET: str = _cfg.get("ROUTER_SECRET", "") ROUTER_SECRET: str = _cfg.get("ROUTER_SECRET", "")
ALLOWED_OPEN_IDS: List[str] = _cfg.get("ALLOWED_OPEN_IDS", []) _allowed_open_ids_raw = _cfg.get("ALLOWED_OPEN_IDS", [])
if ALLOWED_OPEN_IDS and not isinstance(ALLOWED_OPEN_IDS, list): ALLOWED_OPEN_IDS: list[str] = _allowed_open_ids_raw if isinstance(_allowed_open_ids_raw, list) else [str(_allowed_open_ids_raw)]
ALLOWED_OPEN_IDS = [str(ALLOWED_OPEN_IDS)]
def is_user_allowed(open_id: str) -> bool: def is_user_allowed(open_id: str) -> bool:

View File

@ -46,9 +46,9 @@ class HostConfig:
self.metaso_api_key: Optional[str] = data.get("METASO_API_KEY") self.metaso_api_key: Optional[str] = data.get("METASO_API_KEY")
serves_users = data.get("SERVES_USERS", []) serves_users = data.get("SERVES_USERS", [])
self.serves_users: List[str] = serves_users if isinstance(serves_users, list) else [] self.serves_users: list[str] = serves_users if isinstance(serves_users, list) else []
self.capabilities: List[str] = data.get( self.capabilities: list[str] = data.get(
"CAPABILITIES", "CAPABILITIES",
["claude_code", "shell", "file_ops", "web", "scheduler"], ["claude_code", "shell", "file_ops", "web", "scheduler"],
) )

View File

@ -10,16 +10,15 @@ import asyncio
import logging import logging
import secrets import secrets
import time import time
from typing import Optional from typing import Any, Optional
import websockets import websockets
from websockets.client import WebSocketClientProtocol
from agent.manager import manager from agent.manager import manager
from agent.scheduler import scheduler from agent.scheduler import scheduler
from agent.task_runner import task_runner from agent.task_runner import task_runner
from host_client.config import HostConfig, get_host_config from host_client.config import HostConfig, get_host_config
from orchestrator.agent import run as run_mailboy from orchestrator.agent import agent
from orchestrator.tools import set_current_user, set_current_chat from orchestrator.tools import set_current_user, set_current_chat
from shared import ( from shared import (
RegisterMessage, RegisterMessage,
@ -40,7 +39,7 @@ class NodeClient:
def __init__(self, config: HostConfig): def __init__(self, config: HostConfig):
self.config = config self.config = config
self.ws: Optional[WebSocketClientProtocol] = None self.ws: Any = None
self._running = False self._running = False
self._last_heartbeat = time.time() self._last_heartbeat = time.time()
self._reconnect_delay = 1.0 self._reconnect_delay = 1.0
@ -94,7 +93,7 @@ class NodeClient:
set_current_chat(request.chat_id) set_current_chat(request.chat_id)
try: try:
reply = await run_mailboy(request.user_id, request.text) reply = await agent.run(request.user_id, request.text)
response = ForwardResponse( response = ForwardResponse(
id=request.id, id=request.id,
@ -131,7 +130,7 @@ class NodeClient:
sessions = manager.list_sessions() sessions = manager.list_sessions()
active_sessions = [ active_sessions = [
{"conv_id": s["conv_id"], "working_dir": s["working_dir"]} {"conv_id": s["conv_id"], "working_dir": s["cwd"]}
for s in sessions for s in sessions
] ]

View File

@ -91,7 +91,7 @@ async def startup_event() -> None:
await manager.start() await manager.start()
from agent.scheduler import scheduler from agent.scheduler import scheduler
await scheduler.start() await scheduler.start()
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
start_websocket_client(loop) start_websocket_client(loop)
logger.info("PhoneWork started") logger.info("PhoneWork started")

View File

@ -97,18 +97,18 @@ class OrchestrationAgent:
base_url=OPENAI_BASE_URL, base_url=OPENAI_BASE_URL,
api_key=OPENAI_API_KEY, api_key=OPENAI_API_KEY,
model=OPENAI_MODEL, model=OPENAI_MODEL,
temperature=0.0, temperature=0.1,
) )
self._llm_with_tools = llm.bind_tools(TOOLS) self._llm_with_tools = llm.bind_tools(TOOLS)
# user_id -> list[BaseMessage] # user_id -> list[BaseMessage]
self._history: Dict[str, List[BaseMessage]] = defaultdict(list) self._history: dict[str, list[BaseMessage]] = defaultdict(list)
# user_id -> most recently active conv_id # user_id -> most recently active conv_id
self._active_conv: Dict[str, Optional[str]] = defaultdict(lambda: None) self._active_conv: dict[str, Optional[str]] = defaultdict(lambda: None)
# user_id -> asyncio.Lock (prevents concurrent processing per user) # user_id -> asyncio.Lock (prevents concurrent processing per user)
self._user_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._user_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
# user_id -> passthrough mode enabled # user_id -> passthrough mode enabled
self._passthrough: Dict[str, bool] = defaultdict(lambda: False) self._passthrough: dict[str, bool] = defaultdict(lambda: False)
def _build_system_prompt(self, user_id: str) -> str: def _build_system_prompt(self, user_id: str) -> str:
conv_id = self._active_conv[user_id] conv_id = self._active_conv[user_id]
@ -173,7 +173,7 @@ class OrchestrationAgent:
response = await llm_no_tools.ainvoke([HumanMessage(content=qa_prompt)]) response = await llm_no_tools.ainvoke([HumanMessage(content=qa_prompt)])
return response.content or "" return response.content or ""
messages: List[BaseMessage] = ( messages: list[BaseMessage] = (
[SystemMessage(content=self._build_system_prompt(user_id))] [SystemMessage(content=self._build_system_prompt(user_id))]
+ self._history[user_id] + self._history[user_id]
+ [HumanMessage(content=text)] + [HumanMessage(content=text)]

View File

@ -301,7 +301,8 @@ class FileReadTool(BaseTool):
async def _arun(self, path: str, start_line: Optional[int] = None, end_line: Optional[int] = None) -> str: async def _arun(self, path: str, start_line: Optional[int] = None, end_line: Optional[int] = None) -> str:
try: try:
file_path = _resolve_dir(path) p = Path(path.strip())
file_path = _resolve_dir(str(p.parent)) / p.name if not p.is_absolute() else p.resolve()
if not file_path.is_file(): if not file_path.is_file():
return json.dumps({"error": f"Not a file: {path}"}, ensure_ascii=False) return json.dumps({"error": f"Not a file: {path}"}, ensure_ascii=False)
@ -342,7 +343,8 @@ class FileWriteTool(BaseTool):
async def _arun(self, path: str, content: str, mode: Optional[str] = "overwrite") -> str: async def _arun(self, path: str, content: str, mode: Optional[str] = "overwrite") -> str:
try: try:
file_path = _resolve_dir(path) p = Path(path.strip())
file_path = (_resolve_dir(str(p.parent)) / p.name) if not p.is_absolute() else p.resolve()
file_path.parent.mkdir(parents=True, exist_ok=True) file_path.parent.mkdir(parents=True, exist_ok=True)
write_mode = "a" if mode == "append" else "w" write_mode = "a" if mode == "append" else "w"
@ -484,7 +486,8 @@ class FileSendTool(BaseTool):
async def _arun(self, path: str) -> str: async def _arun(self, path: str) -> str:
try: try:
file_path = _resolve_dir(path) p = Path(path.strip())
file_path = _resolve_dir(str(p.parent)) / p.name if not p.is_absolute() else p.resolve()
if not file_path.is_file(): if not file_path.is_file():
return json.dumps({"error": f"Not a file: {path}"}, ensure_ascii=False) return json.dumps({"error": f"Not a file: {path}"}, ensure_ascii=False)

View File

@ -43,6 +43,7 @@ def create_app(router_secret: Optional[str] = None) -> FastAPI:
@app.get("/health") @app.get("/health")
async def health(): async def health():
from router.rpc import get_pending_count
nodes = registry.list_nodes() nodes = registry.list_nodes()
online_nodes = [n for n in nodes if n["status"] == "online"] online_nodes = [n for n in nodes if n["status"] == "online"]
return { return {
@ -50,7 +51,7 @@ def create_app(router_secret: Optional[str] = None) -> FastAPI:
"nodes": nodes, "nodes": nodes,
"online_nodes": len(online_nodes), "online_nodes": len(online_nodes),
"total_nodes": len(nodes), "total_nodes": len(nodes),
"pending_requests": 0, "pending_requests": get_pending_count(),
} }
@app.get("/nodes") @app.get("/nodes")
@ -64,7 +65,7 @@ def create_app(router_secret: Optional[str] = None) -> FastAPI:
@app.on_event("startup") @app.on_event("startup")
async def startup(): async def startup():
import asyncio import asyncio
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
start_websocket_client(loop) start_websocket_client(loop)
logger.info("Router started") logger.info("Router started")

View File

@ -27,18 +27,18 @@ class NodeConnection:
display_name: str = "" display_name: str = ""
serves_users: Set[str] = field(default_factory=set) serves_users: Set[str] = field(default_factory=set)
working_dir: str = "" working_dir: str = ""
capabilities: List[str] = field(default_factory=list) capabilities: list[str] = field(default_factory=list)
connected_at: float = field(default_factory=time.time) connected_at: float = field(default_factory=time.time)
last_heartbeat: float = field(default_factory=time.time) last_heartbeat: float = field(default_factory=time.time)
sessions: int = 0 sessions: int = 0
active_sessions: List[Dict[str, Any]] = field(default_factory=list) active_sessions: list[dict[str, Any]] = field(default_factory=list)
@property @property
def is_online(self) -> bool: def is_online(self) -> bool:
"""Check if node is still considered online (heartbeat within 60s).""" """Check if node is still considered online (heartbeat within 60s)."""
return time.time() - self.last_heartbeat < 60 return time.time() - self.last_heartbeat < 60
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""Serialize for API responses.""" """Serialize for API responses."""
return { return {
"node_id": self.node_id, "node_id": self.node_id,
@ -56,9 +56,9 @@ class NodeRegistry:
"""Registry of connected host clients.""" """Registry of connected host clients."""
def __init__(self, router_secret: str = ""): def __init__(self, router_secret: str = ""):
self._nodes: Dict[str, NodeConnection] = {} self._nodes: dict[str, NodeConnection] = {}
self._user_nodes: Dict[str, Set[str]] = {} self._user_nodes: dict[str, Set[str]] = {}
self._active_node: Dict[str, str] = {} self._active_node: dict[str, str] = {}
self._secret = router_secret self._secret = router_secret
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
@ -159,7 +159,7 @@ class NodeRegistry:
"""Get a node by ID.""" """Get a node by ID."""
return self._nodes.get(node_id) return self._nodes.get(node_id)
def get_nodes_for_user(self, user_id: str) -> List[NodeConnection]: def get_nodes_for_user(self, user_id: str) -> list[NodeConnection]:
"""Get all nodes that serve a user.""" """Get all nodes that serve a user."""
node_ids = self._user_nodes.get(user_id, set()) node_ids = self._user_nodes.get(user_id, set())
return [self._nodes[nid] for nid in node_ids if nid in self._nodes] return [self._nodes[nid] for nid in node_ids if nid in self._nodes]
@ -186,11 +186,11 @@ class NodeRegistry:
logger.info("Active node for user %s set to %s", user_id, node_id) logger.info("Active node for user %s set to %s", user_id, node_id)
return True return True
def list_nodes(self) -> List[Dict[str, Any]]: def list_nodes(self) -> list[dict[str, Any]]:
"""List all nodes with their status.""" """List all nodes with their status."""
return [node.to_dict() for node in self._nodes.values()] return [node.to_dict() for node in self._nodes.values()]
def get_affected_users(self, node_id: str) -> List[str]: def get_affected_users(self, node_id: str) -> list[str]:
"""Get users affected by a node disconnect.""" """Get users affected by a node disconnect."""
node = self._nodes.get(node_id) node = self._nodes.get(node_id)
if node: if node:

View File

@ -12,6 +12,7 @@ from typing import List, Optional
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from pydantic import SecretStr
from config import OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL from config import OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL
from router.nodes import NodeConnection, get_node_registry from router.nodes import NodeConnection, get_node_registry
@ -36,7 +37,7 @@ Respond with a JSON object:
""" """
def _format_nodes_info(nodes: List[NodeConnection], active_node_id: Optional[str] = None) -> str: def _format_nodes_info(nodes: list[NodeConnection], active_node_id: Optional[str] = None) -> str:
"""Format node information for the routing prompt.""" """Format node information for the routing prompt."""
lines = [] lines = []
for node in nodes: for node in nodes:
@ -86,8 +87,8 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s
try: try:
llm = ChatOpenAI( llm = ChatOpenAI(
model=OPENAI_MODEL, model=OPENAI_MODEL,
openai_api_key=OPENAI_API_KEY, api_key=SecretStr(OPENAI_API_KEY),
openai_api_base=OPENAI_BASE_URL, base_url=OPENAI_BASE_URL,
temperature=0, temperature=0,
) )
@ -98,7 +99,11 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s
] ]
response = await llm.ainvoke(messages) response = await llm.ainvoke(messages)
content = response.content.strip() content = response.content
if isinstance(content, str):
content = content.strip()
else:
content = str(content).strip()
if content.startswith("```"): if content.startswith("```"):
content = content.split("\n", 1)[1] content = content.split("\n", 1)[1]

View File

@ -18,7 +18,7 @@ from router.nodes import get_node_registry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_pending_requests: Dict[str, asyncio.Future] = {} _pending_requests: dict[str, asyncio.Future[str]] = {}
_default_timeout = 600.0 _default_timeout = 600.0
@ -52,7 +52,7 @@ async def forward(
raise RuntimeError(f"Node not connected: {node_id}") raise RuntimeError(f"Node not connected: {node_id}")
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
future: asyncio.Future = asyncio.get_event_loop().create_future() future: asyncio.Future[str] = asyncio.get_running_loop().create_future()
_pending_requests[request_id] = future _pending_requests[request_id] = future
request = ForwardRequest( request = ForwardRequest(

View File

@ -47,7 +47,7 @@ async def ws_node_endpoint(websocket: WebSocket) -> None:
return return
node_id: Optional[str] = None node_id: Optional[str] = None
heartbeat_task: Optional[asyncio.Task] = None heartbeat_task: Optional[asyncio.Task[None]] = None
async def send_heartbeat(): async def send_heartbeat():
"""Send periodic pings to the host client.""" """Send periodic pings to the host client."""

View File

@ -16,9 +16,9 @@ class RegisterMessage:
"""Host client -> Router: Register this node.""" """Host client -> Router: Register this node."""
type: str = "register" type: str = "register"
node_id: str = "" node_id: str = ""
serves_users: List[str] = field(default_factory=list) serves_users: list[str] = field(default_factory=list)
working_dir: str = "" working_dir: str = ""
capabilities: List[str] = field(default_factory=list) capabilities: list[str] = field(default_factory=list)
display_name: str = "" display_name: str = ""
@ -63,7 +63,7 @@ class NodeStatus:
type: str = "node_status" type: str = "node_status"
node_id: str = "" node_id: str = ""
sessions: int = 0 sessions: int = 0
active_sessions: List[Dict[str, Any]] = field(default_factory=list) active_sessions: list[dict[str, Any]] = field(default_factory=list)
MESSAGE_TYPES = { MESSAGE_TYPES = {