refactor: 统一使用现代类型注解替代传统类型注解
- 将 Dict、List 等传统类型注解替换为 dict、list 等现代类型注解 - 更新类型注解以更精确地反映变量类型 - 修复部分类型注解与实际使用不匹配的问题 - 优化部分代码逻辑以提高类型安全性
This commit is contained in:
parent
64297e5e27
commit
09b63341cd
@ -46,7 +46,7 @@ class SessionManager:
|
|||||||
"""Registry of active Claude Code project sessions with persistence and user isolation."""
|
"""Registry of active Claude Code project sessions with persistence and user isolation."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._sessions: Dict[str, Session] = {}
|
self._sessions: dict[str, Session] = {}
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._reaper_task: Optional[asyncio.Task] = None
|
self._reaper_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
|||||||
@ -53,8 +53,8 @@ class Scheduler:
|
|||||||
"""Singleton that manages scheduled jobs with Feishu notifications."""
|
"""Singleton that manages scheduled jobs with Feishu notifications."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._jobs: Dict[str, ScheduledJob] = {}
|
self._jobs: dict[str, ScheduledJob] = {}
|
||||||
self._tasks: Dict[str, asyncio.Task] = {}
|
self._tasks: dict[str, asyncio.Task] = {}
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._started = False
|
self._started = False
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -43,17 +43,17 @@ class TaskRunner:
|
|||||||
"""Singleton that manages background tasks with Feishu notifications."""
|
"""Singleton that manages background tasks with Feishu notifications."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._tasks: Dict[str, BackgroundTask] = {}
|
self._tasks: dict[str, BackgroundTask] = {}
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._notification_handler: Optional[Callable] = None
|
self._notification_handler: Optional[Callable[[BackgroundTask], Awaitable[None]]] = None
|
||||||
|
|
||||||
def set_notification_handler(self, handler: Optional[Callable]) -> None:
|
def set_notification_handler(self, handler: Optional[Callable[[BackgroundTask], Awaitable[None]]]) -> None:
|
||||||
"""Set custom notification handler for M3 mode (host client -> router)."""
|
"""Set custom notification handler for M3 mode (host client -> router)."""
|
||||||
self._notification_handler = handler
|
self._notification_handler = handler
|
||||||
|
|
||||||
async def submit(
|
async def submit(
|
||||||
self,
|
self,
|
||||||
coro: Callable[[], Any],
|
coro: Awaitable[Any],
|
||||||
description: str,
|
description: str,
|
||||||
notify_chat_id: Optional[str] = None,
|
notify_chat_id: Optional[str] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
@ -76,7 +76,7 @@ class TaskRunner:
|
|||||||
logger.info("Submitted background task %s: %s", task_id, description)
|
logger.info("Submitted background task %s: %s", task_id, description)
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
async def _run_task(self, task_id: str, coro: Callable[[], Any]) -> None:
|
async def _run_task(self, task_id: str, coro: Awaitable[Any]) -> None:
|
||||||
"""Execute a task and send notification on completion."""
|
"""Execute a task and send notification on completion."""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
task = self._tasks.get(task_id)
|
task = self._tasks.get(task_id)
|
||||||
@ -130,14 +130,15 @@ class TaskRunner:
|
|||||||
msg += f"\n\n**Error:** {task.error}"
|
msg += f"\n\n**Error:** {task.error}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await send_text(task.notify_chat_id, "chat_id", msg)
|
if task.notify_chat_id:
|
||||||
|
await send_text(task.notify_chat_id, "chat_id", msg)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to send notification for task %s", task.task_id)
|
logger.exception("Failed to send notification for task %s", task.task_id)
|
||||||
|
|
||||||
def get_task(self, task_id: str) -> Optional[BackgroundTask]:
|
def get_task(self, task_id: str) -> Optional[BackgroundTask]:
|
||||||
return self._tasks.get(task_id)
|
return self._tasks.get(task_id)
|
||||||
|
|
||||||
def list_tasks(self, limit: int = 20) -> list[dict]:
|
def list_tasks(self, limit: int = 20) -> list[dict[str, Any]]:
|
||||||
tasks = sorted(
|
tasks = sorted(
|
||||||
self._tasks.values(),
|
self._tasks.values(),
|
||||||
key=lambda t: t.started_at,
|
key=lambda t: t.started_at,
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
import lark_oapi as lark
|
import lark_oapi as lark
|
||||||
from lark_oapi.api.im.v1 import P2ImMessageReceiveV1
|
from lark_oapi.api.im.v1 import P2ImMessageReceiveV1
|
||||||
@ -25,7 +26,7 @@ _last_message_time: float = 0.0
|
|||||||
_reconnect_count: int = 0
|
_reconnect_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
def get_ws_status() -> dict:
|
def get_ws_status() -> dict[str, Any]:
|
||||||
"""Return WebSocket connection status."""
|
"""Return WebSocket connection status."""
|
||||||
return {
|
return {
|
||||||
"connected": _ws_connected,
|
"connected": _ws_connected,
|
||||||
@ -39,8 +40,17 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None:
|
|||||||
_last_message_time = time.time()
|
_last_message_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message = data.event.message
|
event = data.event
|
||||||
sender = data.event.sender
|
if event is None:
|
||||||
|
logger.warning("Received event with no data")
|
||||||
|
return
|
||||||
|
|
||||||
|
message = event.message
|
||||||
|
sender = event.sender
|
||||||
|
|
||||||
|
if message is None:
|
||||||
|
logger.warning("Received event with no message")
|
||||||
|
return
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"event type=%r chat_type=%r content=%r",
|
"event type=%r chat_type=%r content=%r",
|
||||||
@ -53,7 +63,7 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None:
|
|||||||
logger.info("Skipping non-text message_type=%r", message.message_type)
|
logger.info("Skipping non-text message_type=%r", message.message_type)
|
||||||
return
|
return
|
||||||
|
|
||||||
chat_id: str = message.chat_id
|
chat_id: str = message.chat_id or ""
|
||||||
raw_content: str = message.content or "{}"
|
raw_content: str = message.content or "{}"
|
||||||
content_obj = json.loads(raw_content)
|
content_obj = json.loads(raw_content)
|
||||||
text: str = content_obj.get("text", "").strip()
|
text: str = content_obj.get("text", "").strip()
|
||||||
@ -119,8 +129,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None:
|
|||||||
if nodes:
|
if nodes:
|
||||||
lines = ["Connected Nodes:"]
|
lines = ["Connected Nodes:"]
|
||||||
for n in nodes:
|
for n in nodes:
|
||||||
marker = " → " if n.get("node_id") == registry.get_active_node(user_id) else " "
|
active_node = registry.get_active_node(user_id)
|
||||||
lines.append(f"{marker}{n['display_name']} sessions={n['sessions']} {n['status']}")
|
marker = " → " if n.get("node_id") == (active_node.node_id if active_node else None) else " "
|
||||||
|
lines.append(f"{marker}{n.get('display_name', 'unknown')} sessions={n.get('sessions', 0)} {n.get('status', 'unknown')}")
|
||||||
lines.append("\nUse \"/node <name>\" to switch active node.")
|
lines.append("\nUse \"/node <name>\" to switch active node.")
|
||||||
await send_text(chat_id, "chat_id", "\n".join(lines))
|
await send_text(chat_id, "chat_id", "\n".join(lines))
|
||||||
else:
|
else:
|
||||||
@ -144,7 +155,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None:
|
|||||||
|
|
||||||
def _handle_any(data: lark.CustomizedEvent) -> None:
|
def _handle_any(data: lark.CustomizedEvent) -> None:
|
||||||
"""Catch-all handler to log any event the SDK receives."""
|
"""Catch-all handler to log any event the SDK receives."""
|
||||||
logger.info("RAW CustomizedEvent: %s", lark.JSON.marshal(data)[:500])
|
marshaled = lark.JSON.marshal(data)
|
||||||
|
if marshaled:
|
||||||
|
logger.info("RAW CustomizedEvent: %s", marshaled[:500])
|
||||||
|
|
||||||
|
|
||||||
def build_event_handler() -> lark.EventDispatcherHandler:
|
def build_event_handler() -> lark.EventDispatcherHandler:
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
_CONFIG_PATH = Path(__file__).parent / "keyring.yaml"
|
_CONFIG_PATH = Path(__file__).parent / "keyring.yaml"
|
||||||
|
|
||||||
|
|
||||||
def _load() -> dict:
|
def _load() -> dict[str, Any]:
|
||||||
with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
|
with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||||
return yaml.safe_load(f) or {}
|
return yaml.safe_load(f) or {}
|
||||||
|
|
||||||
@ -23,9 +23,8 @@ METASO_API_KEY: str = _cfg.get("METASO_API_KEY", "")
|
|||||||
ROUTER_MODE: bool = _cfg.get("ROUTER_MODE", False)
|
ROUTER_MODE: bool = _cfg.get("ROUTER_MODE", False)
|
||||||
ROUTER_SECRET: str = _cfg.get("ROUTER_SECRET", "")
|
ROUTER_SECRET: str = _cfg.get("ROUTER_SECRET", "")
|
||||||
|
|
||||||
ALLOWED_OPEN_IDS: List[str] = _cfg.get("ALLOWED_OPEN_IDS", [])
|
_allowed_open_ids_raw = _cfg.get("ALLOWED_OPEN_IDS", [])
|
||||||
if ALLOWED_OPEN_IDS and not isinstance(ALLOWED_OPEN_IDS, list):
|
ALLOWED_OPEN_IDS: list[str] = _allowed_open_ids_raw if isinstance(_allowed_open_ids_raw, list) else [str(_allowed_open_ids_raw)]
|
||||||
ALLOWED_OPEN_IDS = [str(ALLOWED_OPEN_IDS)]
|
|
||||||
|
|
||||||
|
|
||||||
def is_user_allowed(open_id: str) -> bool:
|
def is_user_allowed(open_id: str) -> bool:
|
||||||
|
|||||||
@ -46,9 +46,9 @@ class HostConfig:
|
|||||||
self.metaso_api_key: Optional[str] = data.get("METASO_API_KEY")
|
self.metaso_api_key: Optional[str] = data.get("METASO_API_KEY")
|
||||||
|
|
||||||
serves_users = data.get("SERVES_USERS", [])
|
serves_users = data.get("SERVES_USERS", [])
|
||||||
self.serves_users: List[str] = serves_users if isinstance(serves_users, list) else []
|
self.serves_users: list[str] = serves_users if isinstance(serves_users, list) else []
|
||||||
|
|
||||||
self.capabilities: List[str] = data.get(
|
self.capabilities: list[str] = data.get(
|
||||||
"CAPABILITIES",
|
"CAPABILITIES",
|
||||||
["claude_code", "shell", "file_ops", "web", "scheduler"],
|
["claude_code", "shell", "file_ops", "web", "scheduler"],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -10,16 +10,15 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
from websockets.client import WebSocketClientProtocol
|
|
||||||
|
|
||||||
from agent.manager import manager
|
from agent.manager import manager
|
||||||
from agent.scheduler import scheduler
|
from agent.scheduler import scheduler
|
||||||
from agent.task_runner import task_runner
|
from agent.task_runner import task_runner
|
||||||
from host_client.config import HostConfig, get_host_config
|
from host_client.config import HostConfig, get_host_config
|
||||||
from orchestrator.agent import run as run_mailboy
|
from orchestrator.agent import agent
|
||||||
from orchestrator.tools import set_current_user, set_current_chat
|
from orchestrator.tools import set_current_user, set_current_chat
|
||||||
from shared import (
|
from shared import (
|
||||||
RegisterMessage,
|
RegisterMessage,
|
||||||
@ -40,7 +39,7 @@ class NodeClient:
|
|||||||
|
|
||||||
def __init__(self, config: HostConfig):
|
def __init__(self, config: HostConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ws: Optional[WebSocketClientProtocol] = None
|
self.ws: Any = None
|
||||||
self._running = False
|
self._running = False
|
||||||
self._last_heartbeat = time.time()
|
self._last_heartbeat = time.time()
|
||||||
self._reconnect_delay = 1.0
|
self._reconnect_delay = 1.0
|
||||||
@ -94,7 +93,7 @@ class NodeClient:
|
|||||||
set_current_chat(request.chat_id)
|
set_current_chat(request.chat_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
reply = await run_mailboy(request.user_id, request.text)
|
reply = await agent.run(request.user_id, request.text)
|
||||||
|
|
||||||
response = ForwardResponse(
|
response = ForwardResponse(
|
||||||
id=request.id,
|
id=request.id,
|
||||||
|
|||||||
@ -97,18 +97,18 @@ class OrchestrationAgent:
|
|||||||
base_url=OPENAI_BASE_URL,
|
base_url=OPENAI_BASE_URL,
|
||||||
api_key=OPENAI_API_KEY,
|
api_key=OPENAI_API_KEY,
|
||||||
model=OPENAI_MODEL,
|
model=OPENAI_MODEL,
|
||||||
temperature=0.0,
|
temperature=0.1,
|
||||||
)
|
)
|
||||||
self._llm_with_tools = llm.bind_tools(TOOLS)
|
self._llm_with_tools = llm.bind_tools(TOOLS)
|
||||||
|
|
||||||
# user_id -> list[BaseMessage]
|
# user_id -> list[BaseMessage]
|
||||||
self._history: Dict[str, List[BaseMessage]] = defaultdict(list)
|
self._history: dict[str, list[BaseMessage]] = defaultdict(list)
|
||||||
# user_id -> most recently active conv_id
|
# user_id -> most recently active conv_id
|
||||||
self._active_conv: Dict[str, Optional[str]] = defaultdict(lambda: None)
|
self._active_conv: dict[str, Optional[str]] = defaultdict(lambda: None)
|
||||||
# user_id -> asyncio.Lock (prevents concurrent processing per user)
|
# user_id -> asyncio.Lock (prevents concurrent processing per user)
|
||||||
self._user_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
self._user_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||||
# user_id -> passthrough mode enabled
|
# user_id -> passthrough mode enabled
|
||||||
self._passthrough: Dict[str, bool] = defaultdict(lambda: False)
|
self._passthrough: dict[str, bool] = defaultdict(lambda: False)
|
||||||
|
|
||||||
def _build_system_prompt(self, user_id: str) -> str:
|
def _build_system_prompt(self, user_id: str) -> str:
|
||||||
conv_id = self._active_conv[user_id]
|
conv_id = self._active_conv[user_id]
|
||||||
@ -173,7 +173,7 @@ class OrchestrationAgent:
|
|||||||
response = await llm_no_tools.ainvoke([HumanMessage(content=qa_prompt)])
|
response = await llm_no_tools.ainvoke([HumanMessage(content=qa_prompt)])
|
||||||
return response.content or ""
|
return response.content or ""
|
||||||
|
|
||||||
messages: List[BaseMessage] = (
|
messages: list[BaseMessage] = (
|
||||||
[SystemMessage(content=self._build_system_prompt(user_id))]
|
[SystemMessage(content=self._build_system_prompt(user_id))]
|
||||||
+ self._history[user_id]
|
+ self._history[user_id]
|
||||||
+ [HumanMessage(content=text)]
|
+ [HumanMessage(content=text)]
|
||||||
|
|||||||
@ -27,18 +27,18 @@ class NodeConnection:
|
|||||||
display_name: str = ""
|
display_name: str = ""
|
||||||
serves_users: Set[str] = field(default_factory=set)
|
serves_users: Set[str] = field(default_factory=set)
|
||||||
working_dir: str = ""
|
working_dir: str = ""
|
||||||
capabilities: List[str] = field(default_factory=list)
|
capabilities: list[str] = field(default_factory=list)
|
||||||
connected_at: float = field(default_factory=time.time)
|
connected_at: float = field(default_factory=time.time)
|
||||||
last_heartbeat: float = field(default_factory=time.time)
|
last_heartbeat: float = field(default_factory=time.time)
|
||||||
sessions: int = 0
|
sessions: int = 0
|
||||||
active_sessions: List[Dict[str, Any]] = field(default_factory=list)
|
active_sessions: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_online(self) -> bool:
|
def is_online(self) -> bool:
|
||||||
"""Check if node is still considered online (heartbeat within 60s)."""
|
"""Check if node is still considered online (heartbeat within 60s)."""
|
||||||
return time.time() - self.last_heartbeat < 60
|
return time.time() - self.last_heartbeat < 60
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""Serialize for API responses."""
|
"""Serialize for API responses."""
|
||||||
return {
|
return {
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
@ -56,9 +56,9 @@ class NodeRegistry:
|
|||||||
"""Registry of connected host clients."""
|
"""Registry of connected host clients."""
|
||||||
|
|
||||||
def __init__(self, router_secret: str = ""):
|
def __init__(self, router_secret: str = ""):
|
||||||
self._nodes: Dict[str, NodeConnection] = {}
|
self._nodes: dict[str, NodeConnection] = {}
|
||||||
self._user_nodes: Dict[str, Set[str]] = {}
|
self._user_nodes: dict[str, Set[str]] = {}
|
||||||
self._active_node: Dict[str, str] = {}
|
self._active_node: dict[str, str] = {}
|
||||||
self._secret = router_secret
|
self._secret = router_secret
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
@ -159,7 +159,7 @@ class NodeRegistry:
|
|||||||
"""Get a node by ID."""
|
"""Get a node by ID."""
|
||||||
return self._nodes.get(node_id)
|
return self._nodes.get(node_id)
|
||||||
|
|
||||||
def get_nodes_for_user(self, user_id: str) -> List[NodeConnection]:
|
def get_nodes_for_user(self, user_id: str) -> list[NodeConnection]:
|
||||||
"""Get all nodes that serve a user."""
|
"""Get all nodes that serve a user."""
|
||||||
node_ids = self._user_nodes.get(user_id, set())
|
node_ids = self._user_nodes.get(user_id, set())
|
||||||
return [self._nodes[nid] for nid in node_ids if nid in self._nodes]
|
return [self._nodes[nid] for nid in node_ids if nid in self._nodes]
|
||||||
@ -186,11 +186,11 @@ class NodeRegistry:
|
|||||||
logger.info("Active node for user %s set to %s", user_id, node_id)
|
logger.info("Active node for user %s set to %s", user_id, node_id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def list_nodes(self) -> List[Dict[str, Any]]:
|
def list_nodes(self) -> list[dict[str, Any]]:
|
||||||
"""List all nodes with their status."""
|
"""List all nodes with their status."""
|
||||||
return [node.to_dict() for node in self._nodes.values()]
|
return [node.to_dict() for node in self._nodes.values()]
|
||||||
|
|
||||||
def get_affected_users(self, node_id: str) -> List[str]:
|
def get_affected_users(self, node_id: str) -> list[str]:
|
||||||
"""Get users affected by a node disconnect."""
|
"""Get users affected by a node disconnect."""
|
||||||
node = self._nodes.get(node_id)
|
node = self._nodes.get(node_id)
|
||||||
if node:
|
if node:
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from config import OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL
|
from config import OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL
|
||||||
from router.nodes import NodeConnection, get_node_registry
|
from router.nodes import NodeConnection, get_node_registry
|
||||||
@ -36,7 +37,7 @@ Respond with a JSON object:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _format_nodes_info(nodes: List[NodeConnection], active_node_id: Optional[str] = None) -> str:
|
def _format_nodes_info(nodes: list[NodeConnection], active_node_id: Optional[str] = None) -> str:
|
||||||
"""Format node information for the routing prompt."""
|
"""Format node information for the routing prompt."""
|
||||||
lines = []
|
lines = []
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
@ -86,8 +87,8 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s
|
|||||||
try:
|
try:
|
||||||
llm = ChatOpenAI(
|
llm = ChatOpenAI(
|
||||||
model=OPENAI_MODEL,
|
model=OPENAI_MODEL,
|
||||||
openai_api_key=OPENAI_API_KEY,
|
api_key=SecretStr(OPENAI_API_KEY),
|
||||||
openai_api_base=OPENAI_BASE_URL,
|
base_url=OPENAI_BASE_URL,
|
||||||
temperature=0,
|
temperature=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -98,7 +99,11 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s
|
|||||||
]
|
]
|
||||||
|
|
||||||
response = await llm.ainvoke(messages)
|
response = await llm.ainvoke(messages)
|
||||||
content = response.content.strip()
|
content = response.content
|
||||||
|
if isinstance(content, str):
|
||||||
|
content = content.strip()
|
||||||
|
else:
|
||||||
|
content = str(content).strip()
|
||||||
|
|
||||||
if content.startswith("```"):
|
if content.startswith("```"):
|
||||||
content = content.split("\n", 1)[1]
|
content = content.split("\n", 1)[1]
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from router.nodes import get_node_registry
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_pending_requests: Dict[str, asyncio.Future] = {}
|
_pending_requests: dict[str, asyncio.Future[str]] = {}
|
||||||
_default_timeout = 600.0
|
_default_timeout = 600.0
|
||||||
|
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ async def forward(
|
|||||||
raise RuntimeError(f"Node not connected: {node_id}")
|
raise RuntimeError(f"Node not connected: {node_id}")
|
||||||
|
|
||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
future: asyncio.Future = asyncio.get_event_loop().create_future()
|
future: asyncio.Future[str] = asyncio.get_event_loop().create_future()
|
||||||
_pending_requests[request_id] = future
|
_pending_requests[request_id] = future
|
||||||
|
|
||||||
request = ForwardRequest(
|
request = ForwardRequest(
|
||||||
|
|||||||
@ -47,7 +47,7 @@ async def ws_node_endpoint(websocket: WebSocket) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
node_id: Optional[str] = None
|
node_id: Optional[str] = None
|
||||||
heartbeat_task: Optional[asyncio.Task] = None
|
heartbeat_task: Optional[asyncio.Task[None]] = None
|
||||||
|
|
||||||
async def send_heartbeat():
|
async def send_heartbeat():
|
||||||
"""Send periodic pings to the host client."""
|
"""Send periodic pings to the host client."""
|
||||||
|
|||||||
@ -16,9 +16,9 @@ class RegisterMessage:
|
|||||||
"""Host client -> Router: Register this node."""
|
"""Host client -> Router: Register this node."""
|
||||||
type: str = "register"
|
type: str = "register"
|
||||||
node_id: str = ""
|
node_id: str = ""
|
||||||
serves_users: List[str] = field(default_factory=list)
|
serves_users: list[str] = field(default_factory=list)
|
||||||
working_dir: str = ""
|
working_dir: str = ""
|
||||||
capabilities: List[str] = field(default_factory=list)
|
capabilities: list[str] = field(default_factory=list)
|
||||||
display_name: str = ""
|
display_name: str = ""
|
||||||
|
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ class NodeStatus:
|
|||||||
type: str = "node_status"
|
type: str = "node_status"
|
||||||
node_id: str = ""
|
node_id: str = ""
|
||||||
sessions: int = 0
|
sessions: int = 0
|
||||||
active_sessions: List[Dict[str, Any]] = field(default_factory=list)
|
active_sessions: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
MESSAGE_TYPES = {
|
MESSAGE_TYPES = {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user