diff --git a/agent/manager.py b/agent/manager.py index 851767e..6d805ad 100644 --- a/agent/manager.py +++ b/agent/manager.py @@ -46,7 +46,7 @@ class SessionManager: """Registry of active Claude Code project sessions with persistence and user isolation.""" def __init__(self) -> None: - self._sessions: Dict[str, Session] = {} + self._sessions: dict[str, Session] = {} self._lock = asyncio.Lock() self._reaper_task: Optional[asyncio.Task] = None diff --git a/agent/scheduler.py b/agent/scheduler.py index 6b9affe..a997166 100644 --- a/agent/scheduler.py +++ b/agent/scheduler.py @@ -53,8 +53,8 @@ class Scheduler: """Singleton that manages scheduled jobs with Feishu notifications.""" def __init__(self) -> None: - self._jobs: Dict[str, ScheduledJob] = {} - self._tasks: Dict[str, asyncio.Task] = {} + self._jobs: dict[str, ScheduledJob] = {} + self._tasks: dict[str, asyncio.Task] = {} self._lock = asyncio.Lock() self._started = False diff --git a/agent/task_runner.py b/agent/task_runner.py index 58b3fcf..460cb07 100644 --- a/agent/task_runner.py +++ b/agent/task_runner.py @@ -8,7 +8,7 @@ import time import uuid from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Dict, Optional +from typing import Any, Awaitable, Callable, Dict, Optional logger = logging.getLogger(__name__) @@ -43,17 +43,17 @@ class TaskRunner: """Singleton that manages background tasks with Feishu notifications.""" def __init__(self) -> None: - self._tasks: Dict[str, BackgroundTask] = {} + self._tasks: dict[str, BackgroundTask] = {} self._lock = asyncio.Lock() - self._notification_handler: Optional[Callable] = None + self._notification_handler: Optional[Callable[[BackgroundTask], Awaitable[None]]] = None - def set_notification_handler(self, handler: Optional[Callable]) -> None: + def set_notification_handler(self, handler: Optional[Callable[[BackgroundTask], Awaitable[None]]]) -> None: """Set custom notification handler for M3 mode (host client -> router).""" self._notification_handler = handler async def submit( self, - coro: Callable[[], Any], + coro: Awaitable[Any], description: str, notify_chat_id: Optional[str] = None, user_id: Optional[str] = None, @@ -76,7 +76,7 @@ class TaskRunner: logger.info("Submitted background task %s: %s", task_id, description) return task_id - async def _run_task(self, task_id: str, coro: Callable[[], Any]) -> None: + async def _run_task(self, task_id: str, coro: Awaitable[Any]) -> None: """Execute a task and send notification on completion.""" async with self._lock: task = self._tasks.get(task_id) @@ -130,14 +130,15 @@ class TaskRunner: msg += f"\n\n**Error:** {task.error}" try: - await send_text(task.notify_chat_id, "chat_id", msg) + if task.notify_chat_id: + await send_text(task.notify_chat_id, "chat_id", msg) except Exception: logger.exception("Failed to send notification for task %s", task.task_id) def get_task(self, task_id: str) -> Optional[BackgroundTask]: return self._tasks.get(task_id) - def list_tasks(self, limit: int = 20) -> list[dict]: + def list_tasks(self, limit: int = 20) -> list[dict[str, Any]]: tasks = sorted( self._tasks.values(), key=lambda t: t.started_at, diff --git a/bot/handler.py b/bot/handler.py index ae34c4c..0d142d6 100644 --- a/bot/handler.py +++ b/bot/handler.py @@ -7,6 +7,7 @@ import json import logging import threading import time +from typing import Any, Dict import lark_oapi as lark from lark_oapi.api.im.v1 import P2ImMessageReceiveV1 @@ -25,7 +26,7 @@ _last_message_time: float = 0.0 _reconnect_count: int = 0 -def get_ws_status() -> dict: +def get_ws_status() -> dict[str, Any]: """Return WebSocket connection status.""" return { "connected": _ws_connected, @@ -39,8 +40,17 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None: _last_message_time = time.time() try: - message = data.event.message - sender = data.event.sender + event = data.event + if event is None: + logger.warning("Received event with no data") + return + + message = event.message + sender = event.sender + + if message is None: + logger.warning("Received event with no message") + return logger.debug( "event type=%r chat_type=%r content=%r", @@ -53,7 +63,7 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None: logger.info("Skipping non-text message_type=%r", message.message_type) return - chat_id: str = message.chat_id + chat_id: str = message.chat_id or "" raw_content: str = message.content or "{}" content_obj = json.loads(raw_content) text: str = content_obj.get("text", "").strip() @@ -119,8 +129,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None: if nodes: lines = ["Connected Nodes:"] for n in nodes: - marker = " → " if n.get("node_id") == registry.get_active_node(user_id) else " " - lines.append(f"{marker}{n['display_name']} sessions={n['sessions']} {n['status']}") + active_node = registry.get_active_node(user_id) + marker = " → " if n.get("node_id") == (active_node.node_id if active_node else None) else " " + lines.append(f"{marker}{n.get('display_name', 'unknown')} sessions={n.get('sessions', 0)} {n.get('status', 'unknown')}") lines.append("\nUse \"/node \" to switch active node.") await send_text(chat_id, "chat_id", "\n".join(lines)) else: @@ -144,7 +155,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None: def _handle_any(data: lark.CustomizedEvent) -> None: """Catch-all handler to log any event the SDK receives.""" - logger.info("RAW CustomizedEvent: %s", lark.JSON.marshal(data)[:500]) + marshaled = lark.JSON.marshal(data) + if marshaled: + logger.info("RAW CustomizedEvent: %s", marshaled[:500]) def build_event_handler() -> lark.EventDispatcherHandler: diff --git a/config.py b/config.py index bf97777..67e717c 100644 --- a/config.py +++ b/config.py @@ -1,11 +1,11 @@ import yaml from pathlib import Path -from typing import List, Optional +from typing import Any, Dict, List, Optional _CONFIG_PATH = Path(__file__).parent / "keyring.yaml" -def _load() -> dict: +def _load() -> dict[str, Any]: with open(_CONFIG_PATH, "r", encoding="utf-8") as f: return yaml.safe_load(f) or {} @@ -23,9 +23,8 @@ METASO_API_KEY: str = _cfg.get("METASO_API_KEY", "") ROUTER_MODE: bool = _cfg.get("ROUTER_MODE", False) ROUTER_SECRET: str = _cfg.get("ROUTER_SECRET", "") -ALLOWED_OPEN_IDS: List[str] = _cfg.get("ALLOWED_OPEN_IDS", []) -if ALLOWED_OPEN_IDS and not isinstance(ALLOWED_OPEN_IDS, list): - ALLOWED_OPEN_IDS = [str(ALLOWED_OPEN_IDS)] +_allowed_open_ids_raw = _cfg.get("ALLOWED_OPEN_IDS", []) +ALLOWED_OPEN_IDS: list[str] = _allowed_open_ids_raw if isinstance(_allowed_open_ids_raw, list) else [str(_allowed_open_ids_raw)] def is_user_allowed(open_id: str) -> bool: diff --git a/host_client/config.py b/host_client/config.py index 63bfeb4..1b4d863 100644 --- a/host_client/config.py +++ b/host_client/config.py @@ -46,9 +46,9 @@ class HostConfig: self.metaso_api_key: Optional[str] = data.get("METASO_API_KEY") serves_users = data.get("SERVES_USERS", []) - self.serves_users: List[str] = serves_users if isinstance(serves_users, list) else [] + self.serves_users: list[str] = serves_users if isinstance(serves_users, list) else [] - self.capabilities: List[str] = data.get( + self.capabilities: list[str] = data.get( "CAPABILITIES", ["claude_code", "shell", "file_ops", "web", "scheduler"], ) diff --git a/host_client/main.py b/host_client/main.py index 3f4548f..94e7075 100644 --- a/host_client/main.py +++ b/host_client/main.py @@ -10,16 +10,15 @@ import asyncio import logging import secrets import time -from typing import Optional +from typing import Any, Optional import websockets -from websockets.client import WebSocketClientProtocol from agent.manager import manager from agent.scheduler import scheduler from agent.task_runner import task_runner from host_client.config import HostConfig, get_host_config -from orchestrator.agent import run as run_mailboy +from orchestrator.agent import agent from orchestrator.tools import set_current_user, set_current_chat from shared import ( RegisterMessage, @@ -40,7 +39,7 @@ class NodeClient: def __init__(self, config: HostConfig): self.config = config - self.ws: Optional[WebSocketClientProtocol] = None + self.ws: Any = None self._running = False self._last_heartbeat = time.time() self._reconnect_delay = 1.0 @@ -94,7 +93,7 @@ class NodeClient: set_current_chat(request.chat_id) try: - reply = await run_mailboy(request.user_id, request.text) + reply = await agent.run(request.user_id, request.text) response = ForwardResponse( id=request.id, diff --git a/orchestrator/agent.py b/orchestrator/agent.py index d276398..7619e5d 100644 --- a/orchestrator/agent.py +++ b/orchestrator/agent.py @@ -97,18 +97,18 @@ class OrchestrationAgent: base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY, model=OPENAI_MODEL, - temperature=0.0, + temperature=0.1, ) self._llm_with_tools = llm.bind_tools(TOOLS) # user_id -> list[BaseMessage] - self._history: Dict[str, List[BaseMessage]] = defaultdict(list) + self._history: dict[str, list[BaseMessage]] = defaultdict(list) # user_id -> most recently active conv_id - self._active_conv: Dict[str, Optional[str]] = defaultdict(lambda: None) + self._active_conv: dict[str, Optional[str]] = defaultdict(lambda: None) # user_id -> asyncio.Lock (prevents concurrent processing per user) - self._user_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + self._user_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) # user_id -> passthrough mode enabled - self._passthrough: Dict[str, bool] = defaultdict(lambda: False) + self._passthrough: dict[str, bool] = defaultdict(lambda: False) def _build_system_prompt(self, user_id: str) -> str: conv_id = self._active_conv[user_id] @@ -173,7 +173,7 @@ class OrchestrationAgent: response = await llm_no_tools.ainvoke([HumanMessage(content=qa_prompt)]) return response.content or "" - messages: List[BaseMessage] = ( + messages: list[BaseMessage] = ( [SystemMessage(content=self._build_system_prompt(user_id))] + self._history[user_id] + [HumanMessage(content=text)] diff --git a/router/nodes.py b/router/nodes.py index aba26d2..783a4ab 100644 --- a/router/nodes.py +++ b/router/nodes.py @@ -27,18 +27,18 @@ class NodeConnection: display_name: str = "" serves_users: Set[str] = field(default_factory=set) working_dir: str = "" - capabilities: List[str] = field(default_factory=list) + capabilities: list[str] = field(default_factory=list) connected_at: float = field(default_factory=time.time) last_heartbeat: float = field(default_factory=time.time) sessions: int = 0 - active_sessions: List[Dict[str, Any]] = field(default_factory=list) + active_sessions: list[dict[str, Any]] = field(default_factory=list) @property def is_online(self) -> bool: """Check if node is still considered online (heartbeat within 60s).""" return time.time() - self.last_heartbeat < 60 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Serialize for API responses.""" return { "node_id": self.node_id, @@ -56,9 +56,9 @@ class NodeRegistry: """Registry of connected host clients.""" def __init__(self, router_secret: str = ""): - self._nodes: Dict[str, NodeConnection] = {} - self._user_nodes: Dict[str, Set[str]] = {} - self._active_node: Dict[str, str] = {} + self._nodes: dict[str, NodeConnection] = {} + self._user_nodes: dict[str, Set[str]] = {} + self._active_node: dict[str, str] = {} self._secret = router_secret self._lock = asyncio.Lock() @@ -159,7 +159,7 @@ class NodeRegistry: """Get a node by ID.""" return self._nodes.get(node_id) - def get_nodes_for_user(self, user_id: str) -> List[NodeConnection]: + def get_nodes_for_user(self, user_id: str) -> list[NodeConnection]: """Get all nodes that serve a user.""" node_ids = self._user_nodes.get(user_id, set()) return [self._nodes[nid] for nid in node_ids if nid in self._nodes] @@ -186,11 +186,11 @@ class NodeRegistry: logger.info("Active node for user %s set to %s", user_id, node_id) return True - def list_nodes(self) -> List[Dict[str, Any]]: + def list_nodes(self) -> list[dict[str, Any]]: """List all nodes with their status.""" return [node.to_dict() for node in self._nodes.values()] - def get_affected_users(self, node_id: str) -> List[str]: + def get_affected_users(self, node_id: str) -> list[str]: """Get users affected by a node disconnect.""" node = self._nodes.get(node_id) if node: diff --git a/router/routing_agent.py b/router/routing_agent.py index d812569..8fa7314 100644 --- a/router/routing_agent.py +++ b/router/routing_agent.py @@ -12,6 +12,7 @@ from typing import List, Optional from langchain_core.messages import HumanMessage, SystemMessage from langchain_openai import ChatOpenAI +from pydantic import SecretStr from config import OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL from router.nodes import NodeConnection, get_node_registry @@ -36,7 +37,7 @@ Respond with a JSON object: """ -def _format_nodes_info(nodes: List[NodeConnection], active_node_id: Optional[str] = None) -> str: +def _format_nodes_info(nodes: list[NodeConnection], active_node_id: Optional[str] = None) -> str: """Format node information for the routing prompt.""" lines = [] for node in nodes: @@ -86,8 +87,8 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s try: llm = ChatOpenAI( model=OPENAI_MODEL, - openai_api_key=OPENAI_API_KEY, - openai_api_base=OPENAI_BASE_URL, + api_key=SecretStr(OPENAI_API_KEY), + base_url=OPENAI_BASE_URL, temperature=0, ) @@ -98,7 +99,11 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s ] response = await llm.ainvoke(messages) - content = response.content.strip() + content = response.content + if isinstance(content, str): + content = content.strip() + else: + content = str(content).strip() if content.startswith("```"): content = content.split("\n", 1)[1] diff --git a/router/rpc.py b/router/rpc.py index b8e2473..9dc793b 100644 --- a/router/rpc.py +++ b/router/rpc.py @@ -18,7 +18,7 @@ from router.nodes import get_node_registry logger = logging.getLogger(__name__) -_pending_requests: Dict[str, asyncio.Future] = {} +_pending_requests: dict[str, asyncio.Future[str]] = {} _default_timeout = 600.0 @@ -52,7 +52,7 @@ async def forward( raise RuntimeError(f"Node not connected: {node_id}") request_id = str(uuid.uuid4()) - future: asyncio.Future = asyncio.get_event_loop().create_future() + future: asyncio.Future[str] = asyncio.get_event_loop().create_future() _pending_requests[request_id] = future request = ForwardRequest( diff --git a/router/ws.py b/router/ws.py index 52bb7ab..7b688c5 100644 --- a/router/ws.py +++ b/router/ws.py @@ -47,7 +47,7 @@ async def ws_node_endpoint(websocket: WebSocket) -> None: return node_id: Optional[str] = None - heartbeat_task: Optional[asyncio.Task] = None + heartbeat_task: Optional[asyncio.Task[None]] = None async def send_heartbeat(): """Send periodic pings to the host client.""" diff --git a/shared/protocol.py b/shared/protocol.py index 8de2225..4aab117 100644 --- a/shared/protocol.py +++ b/shared/protocol.py @@ -16,9 +16,9 @@ class RegisterMessage: """Host client -> Router: Register this node.""" type: str = "register" node_id: str = "" - serves_users: List[str] = field(default_factory=list) + serves_users: list[str] = field(default_factory=list) working_dir: str = "" - capabilities: List[str] = field(default_factory=list) + capabilities: list[str] = field(default_factory=list) display_name: str = "" @@ -63,7 +63,7 @@ class NodeStatus: type: str = "node_status" node_id: str = "" sessions: int = 0 - active_sessions: List[Dict[str, Any]] = field(default_factory=list) + active_sessions: list[dict[str, Any]] = field(default_factory=list) MESSAGE_TYPES = {