refactor: 统一使用现代类型注解替代传统类型注解

- 将 Dict、List 等传统类型注解替换为 dict、list 等现代类型注解
- 更新类型注解以更精确地反映变量类型
- 修复部分类型注解与实际使用不匹配的问题
- 优化部分代码逻辑以提高类型安全性
This commit is contained in:
Yuyao Huang (Sam) 2026-03-28 14:27:21 +08:00
parent 64297e5e27
commit 09b63341cd
13 changed files with 72 additions and 55 deletions

View File

@ -46,7 +46,7 @@ class SessionManager:
"""Registry of active Claude Code project sessions with persistence and user isolation."""
def __init__(self) -> None:
self._sessions: Dict[str, Session] = {}
self._sessions: dict[str, Session] = {}
self._lock = asyncio.Lock()
self._reaper_task: Optional[asyncio.Task] = None

View File

@ -53,8 +53,8 @@ class Scheduler:
"""Singleton that manages scheduled jobs with Feishu notifications."""
def __init__(self) -> None:
self._jobs: Dict[str, ScheduledJob] = {}
self._tasks: Dict[str, asyncio.Task] = {}
self._jobs: dict[str, ScheduledJob] = {}
self._tasks: dict[str, asyncio.Task] = {}
self._lock = asyncio.Lock()
self._started = False

View File

@ -8,7 +8,7 @@ import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, Optional
from typing import Any, Awaitable, Callable, Dict, Optional
logger = logging.getLogger(__name__)
@ -43,17 +43,17 @@ class TaskRunner:
"""Singleton that manages background tasks with Feishu notifications."""
def __init__(self) -> None:
self._tasks: Dict[str, BackgroundTask] = {}
self._tasks: dict[str, BackgroundTask] = {}
self._lock = asyncio.Lock()
self._notification_handler: Optional[Callable] = None
self._notification_handler: Optional[Callable[[BackgroundTask], Awaitable[None]]] = None
def set_notification_handler(self, handler: Optional[Callable]) -> None:
def set_notification_handler(self, handler: Optional[Callable[[BackgroundTask], Awaitable[None]]]) -> None:
"""Set custom notification handler for M3 mode (host client -> router)."""
self._notification_handler = handler
async def submit(
self,
coro: Callable[[], Any],
coro: Awaitable[Any],
description: str,
notify_chat_id: Optional[str] = None,
user_id: Optional[str] = None,
@ -76,7 +76,7 @@ class TaskRunner:
logger.info("Submitted background task %s: %s", task_id, description)
return task_id
async def _run_task(self, task_id: str, coro: Callable[[], Any]) -> None:
async def _run_task(self, task_id: str, coro: Awaitable[Any]) -> None:
"""Execute a task and send notification on completion."""
async with self._lock:
task = self._tasks.get(task_id)
@ -130,14 +130,15 @@ class TaskRunner:
msg += f"\n\n**Error:** {task.error}"
try:
await send_text(task.notify_chat_id, "chat_id", msg)
if task.notify_chat_id:
await send_text(task.notify_chat_id, "chat_id", msg)
except Exception:
logger.exception("Failed to send notification for task %s", task.task_id)
def get_task(self, task_id: str) -> Optional[BackgroundTask]:
return self._tasks.get(task_id)
def list_tasks(self, limit: int = 20) -> list[dict]:
def list_tasks(self, limit: int = 20) -> list[dict[str, Any]]:
tasks = sorted(
self._tasks.values(),
key=lambda t: t.started_at,

View File

@ -7,6 +7,7 @@ import json
import logging
import threading
import time
from typing import Any, Dict
import lark_oapi as lark
from lark_oapi.api.im.v1 import P2ImMessageReceiveV1
@ -25,7 +26,7 @@ _last_message_time: float = 0.0
_reconnect_count: int = 0
def get_ws_status() -> dict:
def get_ws_status() -> dict[str, Any]:
"""Return WebSocket connection status."""
return {
"connected": _ws_connected,
@ -39,8 +40,17 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None:
_last_message_time = time.time()
try:
message = data.event.message
sender = data.event.sender
event = data.event
if event is None:
logger.warning("Received event with no data")
return
message = event.message
sender = event.sender
if message is None:
logger.warning("Received event with no message")
return
logger.debug(
"event type=%r chat_type=%r content=%r",
@ -53,7 +63,7 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None:
logger.info("Skipping non-text message_type=%r", message.message_type)
return
chat_id: str = message.chat_id
chat_id: str = message.chat_id or ""
raw_content: str = message.content or "{}"
content_obj = json.loads(raw_content)
text: str = content_obj.get("text", "").strip()
@ -119,8 +129,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None:
if nodes:
lines = ["Connected Nodes:"]
for n in nodes:
marker = "" if n.get("node_id") == registry.get_active_node(user_id) else " "
lines.append(f"{marker}{n['display_name']} sessions={n['sessions']} {n['status']}")
active_node = registry.get_active_node(user_id)
marker = "" if n.get("node_id") == (active_node.node_id if active_node else None) else " "
lines.append(f"{marker}{n.get('display_name', 'unknown')} sessions={n.get('sessions', 0)} {n.get('status', 'unknown')}")
lines.append("\nUse \"/node <name>\" to switch active node.")
await send_text(chat_id, "chat_id", "\n".join(lines))
else:
@ -144,7 +155,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None:
def _handle_any(data: lark.CustomizedEvent) -> None:
"""Catch-all handler to log any event the SDK receives."""
logger.info("RAW CustomizedEvent: %s", lark.JSON.marshal(data)[:500])
marshaled = lark.JSON.marshal(data)
if marshaled:
logger.info("RAW CustomizedEvent: %s", marshaled[:500])
def build_event_handler() -> lark.EventDispatcherHandler:

View File

@ -1,11 +1,11 @@
import yaml
from pathlib import Path
from typing import List, Optional
from typing import Any, Dict, List, Optional
_CONFIG_PATH = Path(__file__).parent / "keyring.yaml"
def _load() -> dict:
def _load() -> dict[str, Any]:
with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
return yaml.safe_load(f) or {}
@ -23,9 +23,8 @@ METASO_API_KEY: str = _cfg.get("METASO_API_KEY", "")
ROUTER_MODE: bool = _cfg.get("ROUTER_MODE", False)
ROUTER_SECRET: str = _cfg.get("ROUTER_SECRET", "")
ALLOWED_OPEN_IDS: List[str] = _cfg.get("ALLOWED_OPEN_IDS", [])
if ALLOWED_OPEN_IDS and not isinstance(ALLOWED_OPEN_IDS, list):
ALLOWED_OPEN_IDS = [str(ALLOWED_OPEN_IDS)]
_allowed_open_ids_raw = _cfg.get("ALLOWED_OPEN_IDS", [])
ALLOWED_OPEN_IDS: list[str] = _allowed_open_ids_raw if isinstance(_allowed_open_ids_raw, list) else [str(_allowed_open_ids_raw)]
def is_user_allowed(open_id: str) -> bool:

View File

@ -46,9 +46,9 @@ class HostConfig:
self.metaso_api_key: Optional[str] = data.get("METASO_API_KEY")
serves_users = data.get("SERVES_USERS", [])
self.serves_users: List[str] = serves_users if isinstance(serves_users, list) else []
self.serves_users: list[str] = serves_users if isinstance(serves_users, list) else []
self.capabilities: List[str] = data.get(
self.capabilities: list[str] = data.get(
"CAPABILITIES",
["claude_code", "shell", "file_ops", "web", "scheduler"],
)

View File

@ -10,16 +10,15 @@ import asyncio
import logging
import secrets
import time
from typing import Optional
from typing import Any, Optional
import websockets
from websockets.client import WebSocketClientProtocol
from agent.manager import manager
from agent.scheduler import scheduler
from agent.task_runner import task_runner
from host_client.config import HostConfig, get_host_config
from orchestrator.agent import run as run_mailboy
from orchestrator.agent import agent
from orchestrator.tools import set_current_user, set_current_chat
from shared import (
RegisterMessage,
@ -40,7 +39,7 @@ class NodeClient:
def __init__(self, config: HostConfig):
self.config = config
self.ws: Optional[WebSocketClientProtocol] = None
self.ws: Any = None
self._running = False
self._last_heartbeat = time.time()
self._reconnect_delay = 1.0
@ -94,7 +93,7 @@ class NodeClient:
set_current_chat(request.chat_id)
try:
reply = await run_mailboy(request.user_id, request.text)
reply = await agent.run(request.user_id, request.text)
response = ForwardResponse(
id=request.id,

View File

@ -97,18 +97,18 @@ class OrchestrationAgent:
base_url=OPENAI_BASE_URL,
api_key=OPENAI_API_KEY,
model=OPENAI_MODEL,
temperature=0.0,
temperature=0.1,
)
self._llm_with_tools = llm.bind_tools(TOOLS)
# user_id -> list[BaseMessage]
self._history: Dict[str, List[BaseMessage]] = defaultdict(list)
self._history: dict[str, list[BaseMessage]] = defaultdict(list)
# user_id -> most recently active conv_id
self._active_conv: Dict[str, Optional[str]] = defaultdict(lambda: None)
self._active_conv: dict[str, Optional[str]] = defaultdict(lambda: None)
# user_id -> asyncio.Lock (prevents concurrent processing per user)
self._user_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._user_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
# user_id -> passthrough mode enabled
self._passthrough: Dict[str, bool] = defaultdict(lambda: False)
self._passthrough: dict[str, bool] = defaultdict(lambda: False)
def _build_system_prompt(self, user_id: str) -> str:
conv_id = self._active_conv[user_id]
@ -173,7 +173,7 @@ class OrchestrationAgent:
response = await llm_no_tools.ainvoke([HumanMessage(content=qa_prompt)])
return response.content or ""
messages: List[BaseMessage] = (
messages: list[BaseMessage] = (
[SystemMessage(content=self._build_system_prompt(user_id))]
+ self._history[user_id]
+ [HumanMessage(content=text)]

View File

@ -27,18 +27,18 @@ class NodeConnection:
display_name: str = ""
serves_users: Set[str] = field(default_factory=set)
working_dir: str = ""
capabilities: List[str] = field(default_factory=list)
capabilities: list[str] = field(default_factory=list)
connected_at: float = field(default_factory=time.time)
last_heartbeat: float = field(default_factory=time.time)
sessions: int = 0
active_sessions: List[Dict[str, Any]] = field(default_factory=list)
active_sessions: list[dict[str, Any]] = field(default_factory=list)
@property
def is_online(self) -> bool:
"""Check if node is still considered online (heartbeat within 60s)."""
return time.time() - self.last_heartbeat < 60
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Serialize for API responses."""
return {
"node_id": self.node_id,
@ -56,9 +56,9 @@ class NodeRegistry:
"""Registry of connected host clients."""
def __init__(self, router_secret: str = ""):
self._nodes: Dict[str, NodeConnection] = {}
self._user_nodes: Dict[str, Set[str]] = {}
self._active_node: Dict[str, str] = {}
self._nodes: dict[str, NodeConnection] = {}
self._user_nodes: dict[str, Set[str]] = {}
self._active_node: dict[str, str] = {}
self._secret = router_secret
self._lock = asyncio.Lock()
@ -159,7 +159,7 @@ class NodeRegistry:
"""Get a node by ID."""
return self._nodes.get(node_id)
def get_nodes_for_user(self, user_id: str) -> List[NodeConnection]:
def get_nodes_for_user(self, user_id: str) -> list[NodeConnection]:
"""Get all nodes that serve a user."""
node_ids = self._user_nodes.get(user_id, set())
return [self._nodes[nid] for nid in node_ids if nid in self._nodes]
@ -186,11 +186,11 @@ class NodeRegistry:
logger.info("Active node for user %s set to %s", user_id, node_id)
return True
def list_nodes(self) -> List[Dict[str, Any]]:
def list_nodes(self) -> list[dict[str, Any]]:
"""List all nodes with their status."""
return [node.to_dict() for node in self._nodes.values()]
def get_affected_users(self, node_id: str) -> List[str]:
def get_affected_users(self, node_id: str) -> list[str]:
"""Get users affected by a node disconnect."""
node = self._nodes.get(node_id)
if node:

View File

@ -12,6 +12,7 @@ from typing import List, Optional
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
from config import OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL
from router.nodes import NodeConnection, get_node_registry
@ -36,7 +37,7 @@ Respond with a JSON object:
"""
def _format_nodes_info(nodes: List[NodeConnection], active_node_id: Optional[str] = None) -> str:
def _format_nodes_info(nodes: list[NodeConnection], active_node_id: Optional[str] = None) -> str:
"""Format node information for the routing prompt."""
lines = []
for node in nodes:
@ -86,8 +87,8 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s
try:
llm = ChatOpenAI(
model=OPENAI_MODEL,
openai_api_key=OPENAI_API_KEY,
openai_api_base=OPENAI_BASE_URL,
api_key=SecretStr(OPENAI_API_KEY),
base_url=OPENAI_BASE_URL,
temperature=0,
)
@ -98,7 +99,11 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s
]
response = await llm.ainvoke(messages)
content = response.content.strip()
content = response.content
if isinstance(content, str):
content = content.strip()
else:
content = str(content).strip()
if content.startswith("```"):
content = content.split("\n", 1)[1]

View File

@ -18,7 +18,7 @@ from router.nodes import get_node_registry
logger = logging.getLogger(__name__)
_pending_requests: Dict[str, asyncio.Future] = {}
_pending_requests: dict[str, asyncio.Future[str]] = {}
_default_timeout = 600.0
@ -52,7 +52,7 @@ async def forward(
raise RuntimeError(f"Node not connected: {node_id}")
request_id = str(uuid.uuid4())
future: asyncio.Future = asyncio.get_event_loop().create_future()
future: asyncio.Future[str] = asyncio.get_event_loop().create_future()
_pending_requests[request_id] = future
request = ForwardRequest(

View File

@ -47,7 +47,7 @@ async def ws_node_endpoint(websocket: WebSocket) -> None:
return
node_id: Optional[str] = None
heartbeat_task: Optional[asyncio.Task] = None
heartbeat_task: Optional[asyncio.Task[None]] = None
async def send_heartbeat():
"""Send periodic pings to the host client."""

View File

@ -16,9 +16,9 @@ class RegisterMessage:
"""Host client -> Router: Register this node."""
type: str = "register"
node_id: str = ""
serves_users: List[str] = field(default_factory=list)
serves_users: list[str] = field(default_factory=list)
working_dir: str = ""
capabilities: List[str] = field(default_factory=list)
capabilities: list[str] = field(default_factory=list)
display_name: str = ""
@ -63,7 +63,7 @@ class NodeStatus:
type: str = "node_status"
node_id: str = ""
sessions: int = 0
active_sessions: List[Dict[str, Any]] = field(default_factory=list)
active_sessions: list[dict[str, Any]] = field(default_factory=list)
MESSAGE_TYPES = {