refactor: 统一使用现代类型注解替代传统类型注解
- 将 Dict、List 等传统类型注解替换为 dict、list 等现代类型注解 - 更新类型注解以更精确地反映变量类型 - 修复部分类型注解与实际使用不匹配的问题 - 优化部分代码逻辑以提高类型安全性
This commit is contained in:
parent
64297e5e27
commit
09b63341cd
@ -46,7 +46,7 @@ class SessionManager:
|
||||
"""Registry of active Claude Code project sessions with persistence and user isolation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sessions: Dict[str, Session] = {}
|
||||
self._sessions: dict[str, Session] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._reaper_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
@ -53,8 +53,8 @@ class Scheduler:
|
||||
"""Singleton that manages scheduled jobs with Feishu notifications."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._jobs: Dict[str, ScheduledJob] = {}
|
||||
self._tasks: Dict[str, asyncio.Task] = {}
|
||||
self._jobs: dict[str, ScheduledJob] = {}
|
||||
self._tasks: dict[str, asyncio.Task] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._started = False
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -43,17 +43,17 @@ class TaskRunner:
|
||||
"""Singleton that manages background tasks with Feishu notifications."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._tasks: Dict[str, BackgroundTask] = {}
|
||||
self._tasks: dict[str, BackgroundTask] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._notification_handler: Optional[Callable] = None
|
||||
self._notification_handler: Optional[Callable[[BackgroundTask], Awaitable[None]]] = None
|
||||
|
||||
def set_notification_handler(self, handler: Optional[Callable]) -> None:
|
||||
def set_notification_handler(self, handler: Optional[Callable[[BackgroundTask], Awaitable[None]]]) -> None:
|
||||
"""Set custom notification handler for M3 mode (host client -> router)."""
|
||||
self._notification_handler = handler
|
||||
|
||||
async def submit(
|
||||
self,
|
||||
coro: Callable[[], Any],
|
||||
coro: Awaitable[Any],
|
||||
description: str,
|
||||
notify_chat_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
@ -76,7 +76,7 @@ class TaskRunner:
|
||||
logger.info("Submitted background task %s: %s", task_id, description)
|
||||
return task_id
|
||||
|
||||
async def _run_task(self, task_id: str, coro: Callable[[], Any]) -> None:
|
||||
async def _run_task(self, task_id: str, coro: Awaitable[Any]) -> None:
|
||||
"""Execute a task and send notification on completion."""
|
||||
async with self._lock:
|
||||
task = self._tasks.get(task_id)
|
||||
@ -130,14 +130,15 @@ class TaskRunner:
|
||||
msg += f"\n\n**Error:** {task.error}"
|
||||
|
||||
try:
|
||||
await send_text(task.notify_chat_id, "chat_id", msg)
|
||||
if task.notify_chat_id:
|
||||
await send_text(task.notify_chat_id, "chat_id", msg)
|
||||
except Exception:
|
||||
logger.exception("Failed to send notification for task %s", task.task_id)
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[BackgroundTask]:
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
def list_tasks(self, limit: int = 20) -> list[dict]:
|
||||
def list_tasks(self, limit: int = 20) -> list[dict[str, Any]]:
|
||||
tasks = sorted(
|
||||
self._tasks.values(),
|
||||
key=lambda t: t.started_at,
|
||||
|
||||
@ -7,6 +7,7 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import P2ImMessageReceiveV1
|
||||
@ -25,7 +26,7 @@ _last_message_time: float = 0.0
|
||||
_reconnect_count: int = 0
|
||||
|
||||
|
||||
def get_ws_status() -> dict:
|
||||
def get_ws_status() -> dict[str, Any]:
|
||||
"""Return WebSocket connection status."""
|
||||
return {
|
||||
"connected": _ws_connected,
|
||||
@ -39,8 +40,17 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None:
|
||||
_last_message_time = time.time()
|
||||
|
||||
try:
|
||||
message = data.event.message
|
||||
sender = data.event.sender
|
||||
event = data.event
|
||||
if event is None:
|
||||
logger.warning("Received event with no data")
|
||||
return
|
||||
|
||||
message = event.message
|
||||
sender = event.sender
|
||||
|
||||
if message is None:
|
||||
logger.warning("Received event with no message")
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"event type=%r chat_type=%r content=%r",
|
||||
@ -53,7 +63,7 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None:
|
||||
logger.info("Skipping non-text message_type=%r", message.message_type)
|
||||
return
|
||||
|
||||
chat_id: str = message.chat_id
|
||||
chat_id: str = message.chat_id or ""
|
||||
raw_content: str = message.content or "{}"
|
||||
content_obj = json.loads(raw_content)
|
||||
text: str = content_obj.get("text", "").strip()
|
||||
@ -119,8 +129,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None:
|
||||
if nodes:
|
||||
lines = ["Connected Nodes:"]
|
||||
for n in nodes:
|
||||
marker = " → " if n.get("node_id") == registry.get_active_node(user_id) else " "
|
||||
lines.append(f"{marker}{n['display_name']} sessions={n['sessions']} {n['status']}")
|
||||
active_node = registry.get_active_node(user_id)
|
||||
marker = " → " if n.get("node_id") == (active_node.node_id if active_node else None) else " "
|
||||
lines.append(f"{marker}{n.get('display_name', 'unknown')} sessions={n.get('sessions', 0)} {n.get('status', 'unknown')}")
|
||||
lines.append("\nUse \"/node <name>\" to switch active node.")
|
||||
await send_text(chat_id, "chat_id", "\n".join(lines))
|
||||
else:
|
||||
@ -144,7 +155,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None:
|
||||
|
||||
def _handle_any(data: lark.CustomizedEvent) -> None:
|
||||
"""Catch-all handler to log any event the SDK receives."""
|
||||
logger.info("RAW CustomizedEvent: %s", lark.JSON.marshal(data)[:500])
|
||||
marshaled = lark.JSON.marshal(data)
|
||||
if marshaled:
|
||||
logger.info("RAW CustomizedEvent: %s", marshaled[:500])
|
||||
|
||||
|
||||
def build_event_handler() -> lark.EventDispatcherHandler:
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
_CONFIG_PATH = Path(__file__).parent / "keyring.yaml"
|
||||
|
||||
|
||||
def _load() -> dict:
|
||||
def _load() -> dict[str, Any]:
|
||||
with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
|
||||
@ -23,9 +23,8 @@ METASO_API_KEY: str = _cfg.get("METASO_API_KEY", "")
|
||||
ROUTER_MODE: bool = _cfg.get("ROUTER_MODE", False)
|
||||
ROUTER_SECRET: str = _cfg.get("ROUTER_SECRET", "")
|
||||
|
||||
ALLOWED_OPEN_IDS: List[str] = _cfg.get("ALLOWED_OPEN_IDS", [])
|
||||
if ALLOWED_OPEN_IDS and not isinstance(ALLOWED_OPEN_IDS, list):
|
||||
ALLOWED_OPEN_IDS = [str(ALLOWED_OPEN_IDS)]
|
||||
_allowed_open_ids_raw = _cfg.get("ALLOWED_OPEN_IDS", [])
|
||||
ALLOWED_OPEN_IDS: list[str] = _allowed_open_ids_raw if isinstance(_allowed_open_ids_raw, list) else [str(_allowed_open_ids_raw)]
|
||||
|
||||
|
||||
def is_user_allowed(open_id: str) -> bool:
|
||||
|
||||
@ -46,9 +46,9 @@ class HostConfig:
|
||||
self.metaso_api_key: Optional[str] = data.get("METASO_API_KEY")
|
||||
|
||||
serves_users = data.get("SERVES_USERS", [])
|
||||
self.serves_users: List[str] = serves_users if isinstance(serves_users, list) else []
|
||||
self.serves_users: list[str] = serves_users if isinstance(serves_users, list) else []
|
||||
|
||||
self.capabilities: List[str] = data.get(
|
||||
self.capabilities: list[str] = data.get(
|
||||
"CAPABILITIES",
|
||||
["claude_code", "shell", "file_ops", "web", "scheduler"],
|
||||
)
|
||||
|
||||
@ -10,16 +10,15 @@ import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import websockets
|
||||
from websockets.client import WebSocketClientProtocol
|
||||
|
||||
from agent.manager import manager
|
||||
from agent.scheduler import scheduler
|
||||
from agent.task_runner import task_runner
|
||||
from host_client.config import HostConfig, get_host_config
|
||||
from orchestrator.agent import run as run_mailboy
|
||||
from orchestrator.agent import agent
|
||||
from orchestrator.tools import set_current_user, set_current_chat
|
||||
from shared import (
|
||||
RegisterMessage,
|
||||
@ -40,7 +39,7 @@ class NodeClient:
|
||||
|
||||
def __init__(self, config: HostConfig):
|
||||
self.config = config
|
||||
self.ws: Optional[WebSocketClientProtocol] = None
|
||||
self.ws: Any = None
|
||||
self._running = False
|
||||
self._last_heartbeat = time.time()
|
||||
self._reconnect_delay = 1.0
|
||||
@ -94,7 +93,7 @@ class NodeClient:
|
||||
set_current_chat(request.chat_id)
|
||||
|
||||
try:
|
||||
reply = await run_mailboy(request.user_id, request.text)
|
||||
reply = await agent.run(request.user_id, request.text)
|
||||
|
||||
response = ForwardResponse(
|
||||
id=request.id,
|
||||
|
||||
@ -97,18 +97,18 @@ class OrchestrationAgent:
|
||||
base_url=OPENAI_BASE_URL,
|
||||
api_key=OPENAI_API_KEY,
|
||||
model=OPENAI_MODEL,
|
||||
temperature=0.0,
|
||||
temperature=0.1,
|
||||
)
|
||||
self._llm_with_tools = llm.bind_tools(TOOLS)
|
||||
|
||||
# user_id -> list[BaseMessage]
|
||||
self._history: Dict[str, List[BaseMessage]] = defaultdict(list)
|
||||
self._history: dict[str, list[BaseMessage]] = defaultdict(list)
|
||||
# user_id -> most recently active conv_id
|
||||
self._active_conv: Dict[str, Optional[str]] = defaultdict(lambda: None)
|
||||
self._active_conv: dict[str, Optional[str]] = defaultdict(lambda: None)
|
||||
# user_id -> asyncio.Lock (prevents concurrent processing per user)
|
||||
self._user_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self._user_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
# user_id -> passthrough mode enabled
|
||||
self._passthrough: Dict[str, bool] = defaultdict(lambda: False)
|
||||
self._passthrough: dict[str, bool] = defaultdict(lambda: False)
|
||||
|
||||
def _build_system_prompt(self, user_id: str) -> str:
|
||||
conv_id = self._active_conv[user_id]
|
||||
@ -173,7 +173,7 @@ class OrchestrationAgent:
|
||||
response = await llm_no_tools.ainvoke([HumanMessage(content=qa_prompt)])
|
||||
return response.content or ""
|
||||
|
||||
messages: List[BaseMessage] = (
|
||||
messages: list[BaseMessage] = (
|
||||
[SystemMessage(content=self._build_system_prompt(user_id))]
|
||||
+ self._history[user_id]
|
||||
+ [HumanMessage(content=text)]
|
||||
|
||||
@ -27,18 +27,18 @@ class NodeConnection:
|
||||
display_name: str = ""
|
||||
serves_users: Set[str] = field(default_factory=set)
|
||||
working_dir: str = ""
|
||||
capabilities: List[str] = field(default_factory=list)
|
||||
capabilities: list[str] = field(default_factory=list)
|
||||
connected_at: float = field(default_factory=time.time)
|
||||
last_heartbeat: float = field(default_factory=time.time)
|
||||
sessions: int = 0
|
||||
active_sessions: List[Dict[str, Any]] = field(default_factory=list)
|
||||
active_sessions: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def is_online(self) -> bool:
|
||||
"""Check if node is still considered online (heartbeat within 60s)."""
|
||||
return time.time() - self.last_heartbeat < 60
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize for API responses."""
|
||||
return {
|
||||
"node_id": self.node_id,
|
||||
@ -56,9 +56,9 @@ class NodeRegistry:
|
||||
"""Registry of connected host clients."""
|
||||
|
||||
def __init__(self, router_secret: str = ""):
|
||||
self._nodes: Dict[str, NodeConnection] = {}
|
||||
self._user_nodes: Dict[str, Set[str]] = {}
|
||||
self._active_node: Dict[str, str] = {}
|
||||
self._nodes: dict[str, NodeConnection] = {}
|
||||
self._user_nodes: dict[str, Set[str]] = {}
|
||||
self._active_node: dict[str, str] = {}
|
||||
self._secret = router_secret
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@ -159,7 +159,7 @@ class NodeRegistry:
|
||||
"""Get a node by ID."""
|
||||
return self._nodes.get(node_id)
|
||||
|
||||
def get_nodes_for_user(self, user_id: str) -> List[NodeConnection]:
|
||||
def get_nodes_for_user(self, user_id: str) -> list[NodeConnection]:
|
||||
"""Get all nodes that serve a user."""
|
||||
node_ids = self._user_nodes.get(user_id, set())
|
||||
return [self._nodes[nid] for nid in node_ids if nid in self._nodes]
|
||||
@ -186,11 +186,11 @@ class NodeRegistry:
|
||||
logger.info("Active node for user %s set to %s", user_id, node_id)
|
||||
return True
|
||||
|
||||
def list_nodes(self) -> List[Dict[str, Any]]:
|
||||
def list_nodes(self) -> list[dict[str, Any]]:
|
||||
"""List all nodes with their status."""
|
||||
return [node.to_dict() for node in self._nodes.values()]
|
||||
|
||||
def get_affected_users(self, node_id: str) -> List[str]:
|
||||
def get_affected_users(self, node_id: str) -> list[str]:
|
||||
"""Get users affected by a node disconnect."""
|
||||
node = self._nodes.get(node_id)
|
||||
if node:
|
||||
|
||||
@ -12,6 +12,7 @@ from typing import List, Optional
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
from config import OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL
|
||||
from router.nodes import NodeConnection, get_node_registry
|
||||
@ -36,7 +37,7 @@ Respond with a JSON object:
|
||||
"""
|
||||
|
||||
|
||||
def _format_nodes_info(nodes: List[NodeConnection], active_node_id: Optional[str] = None) -> str:
|
||||
def _format_nodes_info(nodes: list[NodeConnection], active_node_id: Optional[str] = None) -> str:
|
||||
"""Format node information for the routing prompt."""
|
||||
lines = []
|
||||
for node in nodes:
|
||||
@ -86,8 +87,8 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s
|
||||
try:
|
||||
llm = ChatOpenAI(
|
||||
model=OPENAI_MODEL,
|
||||
openai_api_key=OPENAI_API_KEY,
|
||||
openai_api_base=OPENAI_BASE_URL,
|
||||
api_key=SecretStr(OPENAI_API_KEY),
|
||||
base_url=OPENAI_BASE_URL,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
@ -98,7 +99,11 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s
|
||||
]
|
||||
|
||||
response = await llm.ainvoke(messages)
|
||||
content = response.content.strip()
|
||||
content = response.content
|
||||
if isinstance(content, str):
|
||||
content = content.strip()
|
||||
else:
|
||||
content = str(content).strip()
|
||||
|
||||
if content.startswith("```"):
|
||||
content = content.split("\n", 1)[1]
|
||||
|
||||
@ -18,7 +18,7 @@ from router.nodes import get_node_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_pending_requests: Dict[str, asyncio.Future] = {}
|
||||
_pending_requests: dict[str, asyncio.Future[str]] = {}
|
||||
_default_timeout = 600.0
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ async def forward(
|
||||
raise RuntimeError(f"Node not connected: {node_id}")
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
future: asyncio.Future = asyncio.get_event_loop().create_future()
|
||||
future: asyncio.Future[str] = asyncio.get_event_loop().create_future()
|
||||
_pending_requests[request_id] = future
|
||||
|
||||
request = ForwardRequest(
|
||||
|
||||
@ -47,7 +47,7 @@ async def ws_node_endpoint(websocket: WebSocket) -> None:
|
||||
return
|
||||
|
||||
node_id: Optional[str] = None
|
||||
heartbeat_task: Optional[asyncio.Task] = None
|
||||
heartbeat_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
async def send_heartbeat():
|
||||
"""Send periodic pings to the host client."""
|
||||
|
||||
@ -16,9 +16,9 @@ class RegisterMessage:
|
||||
"""Host client -> Router: Register this node."""
|
||||
type: str = "register"
|
||||
node_id: str = ""
|
||||
serves_users: List[str] = field(default_factory=list)
|
||||
serves_users: list[str] = field(default_factory=list)
|
||||
working_dir: str = ""
|
||||
capabilities: List[str] = field(default_factory=list)
|
||||
capabilities: list[str] = field(default_factory=list)
|
||||
display_name: str = ""
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ class NodeStatus:
|
||||
type: str = "node_status"
|
||||
node_id: str = ""
|
||||
sessions: int = 0
|
||||
active_sessions: List[Dict[str, Any]] = field(default_factory=list)
|
||||
active_sessions: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
MESSAGE_TYPES = {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user