Compare commits
2 Commits
64297e5e27
...
a3622ce26d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a3622ce26d | ||
|
|
09b63341cd |
@ -46,7 +46,7 @@ class SessionManager:
|
||||
"""Registry of active Claude Code project sessions with persistence and user isolation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sessions: Dict[str, Session] = {}
|
||||
self._sessions: dict[str, Session] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._reaper_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
@ -53,8 +53,8 @@ class Scheduler:
|
||||
"""Singleton that manages scheduled jobs with Feishu notifications."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._jobs: Dict[str, ScheduledJob] = {}
|
||||
self._tasks: Dict[str, asyncio.Task] = {}
|
||||
self._jobs: dict[str, ScheduledJob] = {}
|
||||
self._tasks: dict[str, asyncio.Task] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._started = False
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -43,17 +43,17 @@ class TaskRunner:
|
||||
"""Singleton that manages background tasks with Feishu notifications."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._tasks: Dict[str, BackgroundTask] = {}
|
||||
self._tasks: dict[str, BackgroundTask] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._notification_handler: Optional[Callable] = None
|
||||
self._notification_handler: Optional[Callable[[BackgroundTask], Awaitable[None]]] = None
|
||||
|
||||
def set_notification_handler(self, handler: Optional[Callable]) -> None:
|
||||
def set_notification_handler(self, handler: Optional[Callable[[BackgroundTask], Awaitable[None]]]) -> None:
|
||||
"""Set custom notification handler for M3 mode (host client -> router)."""
|
||||
self._notification_handler = handler
|
||||
|
||||
async def submit(
|
||||
self,
|
||||
coro: Callable[[], Any],
|
||||
coro: Awaitable[Any],
|
||||
description: str,
|
||||
notify_chat_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
@ -76,7 +76,7 @@ class TaskRunner:
|
||||
logger.info("Submitted background task %s: %s", task_id, description)
|
||||
return task_id
|
||||
|
||||
async def _run_task(self, task_id: str, coro: Callable[[], Any]) -> None:
|
||||
async def _run_task(self, task_id: str, coro: Awaitable[Any]) -> None:
|
||||
"""Execute a task and send notification on completion."""
|
||||
async with self._lock:
|
||||
task = self._tasks.get(task_id)
|
||||
@ -130,14 +130,15 @@ class TaskRunner:
|
||||
msg += f"\n\n**Error:** {task.error}"
|
||||
|
||||
try:
|
||||
await send_text(task.notify_chat_id, "chat_id", msg)
|
||||
if task.notify_chat_id:
|
||||
await send_text(task.notify_chat_id, "chat_id", msg)
|
||||
except Exception:
|
||||
logger.exception("Failed to send notification for task %s", task.task_id)
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[BackgroundTask]:
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
def list_tasks(self, limit: int = 20) -> list[dict]:
|
||||
def list_tasks(self, limit: int = 20) -> list[dict[str, Any]]:
|
||||
tasks = sorted(
|
||||
self._tasks.values(),
|
||||
key=lambda t: t.started_at,
|
||||
|
||||
@ -13,7 +13,7 @@ from agent.manager import manager
|
||||
from agent.scheduler import scheduler
|
||||
from agent.task_runner import task_runner
|
||||
from orchestrator.agent import agent
|
||||
from orchestrator.tools import set_current_user, get_current_user, get_current_chat
|
||||
from orchestrator.tools import set_current_user, get_current_chat
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -111,6 +111,18 @@ async def _cmd_new(user_id: str, args: str) -> str:
|
||||
conv_id = data.get("conv_id", "")
|
||||
agent._active_conv[user_id] = conv_id
|
||||
cwd = data.get("working_dir", working_dir)
|
||||
|
||||
chat_id = get_current_chat()
|
||||
if chat_id:
|
||||
from bot.feishu import send_card, send_text, build_sessions_card
|
||||
sessions = manager.list_sessions(user_id=user_id)
|
||||
mode = "Direct 🟢" if agent.get_passthrough(user_id) else "Smart ⚪"
|
||||
card = build_sessions_card(sessions, conv_id, mode)
|
||||
await send_card(chat_id, "chat_id", card)
|
||||
if initial_msg and data.get("response"):
|
||||
await send_text(chat_id, "chat_id", data["response"])
|
||||
return ""
|
||||
|
||||
reply = f"✓ Created session `{conv_id}` in `{cwd}`"
|
||||
if parsed.timeout:
|
||||
reply += f" (timeout: {parsed.timeout}s)"
|
||||
@ -124,17 +136,24 @@ async def _cmd_new(user_id: str, args: str) -> str:
|
||||
async def _cmd_status(user_id: str) -> str:
|
||||
"""Show status: sessions and current mode."""
|
||||
sessions = manager.list_sessions(user_id=user_id)
|
||||
if not sessions:
|
||||
return "No active sessions."
|
||||
|
||||
active = agent.get_active_conv(user_id)
|
||||
passthrough = agent.get_passthrough(user_id)
|
||||
mode = "Direct 🟢" if passthrough else "Smart ⚪"
|
||||
|
||||
chat_id = get_current_chat()
|
||||
if chat_id:
|
||||
from bot.feishu import send_card, build_sessions_card
|
||||
card = build_sessions_card(sessions, active, mode)
|
||||
await send_card(chat_id, "chat_id", card)
|
||||
return ""
|
||||
|
||||
if not sessions:
|
||||
return "No active sessions."
|
||||
lines = ["**Your Sessions:**\n"]
|
||||
for i, s in enumerate(sessions, 1):
|
||||
marker = "→ " if s["conv_id"] == active else " "
|
||||
lines.append(f"{marker}{i}. `{s['conv_id']}` - `{s['cwd']}`")
|
||||
status = "Direct 🟢" if passthrough else "Smart ⚪"
|
||||
lines.append(f"\n**Mode:** {status}")
|
||||
lines.append(f"\n**Mode:** {mode}")
|
||||
lines.append("Use `/switch <n>` to activate a session.")
|
||||
lines.append("Use `/direct` or `/smart` to change mode.")
|
||||
return "\n".join(lines)
|
||||
|
||||
109
bot/feishu.py
109
bot/feishu.py
@ -67,7 +67,7 @@ async def send_text(receive_id: str, receive_id_type: str, text: str) -> None:
|
||||
text: message content.
|
||||
"""
|
||||
parts = _split_message(text)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
logger.debug(
|
||||
@ -108,59 +108,16 @@ async def send_text(receive_id: str, receive_id_type: str, text: str) -> None:
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
|
||||
async def send_card(receive_id: str, receive_id_type: str, title: str, content: str, buttons: list[dict] | None = None) -> None:
|
||||
async def send_card(receive_id: str, receive_id_type: str, card: dict) -> None:
|
||||
"""
|
||||
Send an interactive card message.
|
||||
|
||||
Args:
|
||||
receive_id: chat_id or open_id
|
||||
receive_id_type: "chat_id" | "open_id" | "user_id" | "union_id"
|
||||
title: Card title
|
||||
content: Card content (markdown supported)
|
||||
buttons: List of button dicts with "text" and "value" keys
|
||||
receive_id: chat_id or open_id depending on receive_id_type.
|
||||
receive_id_type: "chat_id" | "open_id" | "user_id" | "union_id".
|
||||
card: Card content dict (Feishu card JSON schema).
|
||||
"""
|
||||
elements = [
|
||||
{
|
||||
"tag": "div",
|
||||
"text": {
|
||||
"tag": "lark_md",
|
||||
"content": content,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if buttons:
|
||||
actions = []
|
||||
for btn in buttons:
|
||||
actions.append({
|
||||
"tag": "button",
|
||||
"text": {"tag": "plain_text", "content": btn.get("text", "Button")},
|
||||
"type": "primary",
|
||||
"value": btn.get("value", {}),
|
||||
})
|
||||
elements.append({"tag": "action", "actions": actions})
|
||||
|
||||
card = {
|
||||
"type": "template",
|
||||
"data": {
|
||||
"template_id": "AAqkz9****",
|
||||
"template_variable": {
|
||||
"title": title,
|
||||
"elements": elements,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
card_content = {
|
||||
"config": {"wide_screen_mode": True},
|
||||
"header": {
|
||||
"title": {"tag": "plain_text", "content": title},
|
||||
"template": "blue",
|
||||
},
|
||||
"elements": elements,
|
||||
}
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
request = (
|
||||
CreateMessageRequest.builder()
|
||||
.receive_id_type(receive_id_type)
|
||||
@ -168,7 +125,7 @@ async def send_card(receive_id: str, receive_id_type: str, title: str, content:
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(receive_id)
|
||||
.msg_type("interactive")
|
||||
.content(json.dumps(card_content, ensure_ascii=False))
|
||||
.content(json.dumps(card, ensure_ascii=False))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
@ -185,6 +142,32 @@ async def send_card(receive_id: str, receive_id_type: str, title: str, content:
|
||||
logger.debug("Sent card to %s (%s)", receive_id, receive_id_type)
|
||||
|
||||
|
||||
def build_sessions_card(sessions: list[dict], active_conv_id: str | None, mode: str) -> dict:
|
||||
"""Build a card showing all sessions with active marker and mode info."""
|
||||
if sessions:
|
||||
lines = []
|
||||
for i, s in enumerate(sessions, 1):
|
||||
marker = "→" if s["conv_id"] == active_conv_id else " "
|
||||
started = "🟢" if s["started"] else "🟡"
|
||||
lines.append(f"{marker} {i}. {started} `{s['conv_id']}` — `{s['cwd']}`")
|
||||
sessions_md = "\n".join(lines)
|
||||
else:
|
||||
sessions_md = "_No active sessions_"
|
||||
|
||||
content = f"{sessions_md}\n\n**Mode:** {mode}"
|
||||
|
||||
return {
|
||||
"config": {"wide_screen_mode": True},
|
||||
"header": {
|
||||
"title": {"tag": "plain_text", "content": "Claude Code Sessions"},
|
||||
"template": "turquoise",
|
||||
},
|
||||
"elements": [
|
||||
{"tag": "div", "text": {"tag": "lark_md", "content": content}},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def send_file(receive_id: str, receive_id_type: str, file_path: str, file_type: str = "stream") -> None:
|
||||
"""
|
||||
Upload a local file to Feishu and send it as a file message.
|
||||
@ -198,7 +181,7 @@ async def send_file(receive_id: str, receive_id_type: str, file_path: str, file_
|
||||
import os
|
||||
path = os.path.abspath(file_path)
|
||||
file_name = os.path.basename(path)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Step 1: Upload file → get file_key
|
||||
with open(path, "rb") as f:
|
||||
@ -259,27 +242,3 @@ async def send_file(receive_id: str, receive_id_type: str, file_path: str, file_
|
||||
)
|
||||
else:
|
||||
logger.debug("Sent file %r to %s (%s)", file_name, receive_id, receive_id_type)
|
||||
|
||||
|
||||
def build_session_card(conv_id: str, cwd: str, started: bool) -> dict:
|
||||
"""Build a session status card."""
|
||||
status = "🟢 Active" if started else "🟡 Ready"
|
||||
content = f"**Session ID:** `{conv_id}`\n**Directory:** `{cwd}`\n**Status:** {status}"
|
||||
return {
|
||||
"config": {"wide_screen_mode": True},
|
||||
"header": {
|
||||
"title": {"tag": "plain_text", "content": "Claude Code Session"},
|
||||
"template": "turquoise",
|
||||
},
|
||||
"elements": [
|
||||
{"tag": "div", "text": {"tag": "lark_md", "content": content}},
|
||||
{"tag": "hr"},
|
||||
{
|
||||
"tag": "action",
|
||||
"actions": [
|
||||
{"tag": "button", "text": {"tag": "plain_text", "content": "Continue"}, "type": "primary", "value": {"action": "continue", "conv_id": conv_id}},
|
||||
{"tag": "button", "text": {"tag": "plain_text", "content": "Close"}, "type": "default", "value": {"action": "close", "conv_id": conv_id}},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import P2ImMessageReceiveV1
|
||||
@ -25,7 +26,7 @@ _last_message_time: float = 0.0
|
||||
_reconnect_count: int = 0
|
||||
|
||||
|
||||
def get_ws_status() -> dict:
|
||||
def get_ws_status() -> dict[str, Any]:
|
||||
"""Return WebSocket connection status."""
|
||||
return {
|
||||
"connected": _ws_connected,
|
||||
@ -39,8 +40,17 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None:
|
||||
_last_message_time = time.time()
|
||||
|
||||
try:
|
||||
message = data.event.message
|
||||
sender = data.event.sender
|
||||
event = data.event
|
||||
if event is None:
|
||||
logger.warning("Received event with no data")
|
||||
return
|
||||
|
||||
message = event.message
|
||||
sender = event.sender
|
||||
|
||||
if message is None:
|
||||
logger.warning("Received event with no message")
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"event type=%r chat_type=%r content=%r",
|
||||
@ -53,7 +63,7 @@ def _handle_message(data: P2ImMessageReceiveV1) -> None:
|
||||
logger.info("Skipping non-text message_type=%r", message.message_type)
|
||||
return
|
||||
|
||||
chat_id: str = message.chat_id
|
||||
chat_id: str = message.chat_id or ""
|
||||
raw_content: str = message.content or "{}"
|
||||
content_obj = json.loads(raw_content)
|
||||
text: str = content_obj.get("text", "").strip()
|
||||
@ -119,8 +129,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None:
|
||||
if nodes:
|
||||
lines = ["Connected Nodes:"]
|
||||
for n in nodes:
|
||||
marker = " → " if n.get("node_id") == registry.get_active_node(user_id) else " "
|
||||
lines.append(f"{marker}{n['display_name']} sessions={n['sessions']} {n['status']}")
|
||||
active_node = registry.get_active_node(user_id)
|
||||
marker = " → " if n.get("node_id") == (active_node.node_id if active_node else None) else " "
|
||||
lines.append(f"{marker}{n.get('display_name', 'unknown')} sessions={n.get('sessions', 0)} {n.get('status', 'unknown')}")
|
||||
lines.append("\nUse \"/node <name>\" to switch active node.")
|
||||
await send_text(chat_id, "chat_id", "\n".join(lines))
|
||||
else:
|
||||
@ -144,7 +155,9 @@ async def _process_message(user_id: str, chat_id: str, text: str) -> None:
|
||||
|
||||
def _handle_any(data: lark.CustomizedEvent) -> None:
|
||||
"""Catch-all handler to log any event the SDK receives."""
|
||||
logger.info("RAW CustomizedEvent: %s", lark.JSON.marshal(data)[:500])
|
||||
marshaled = lark.JSON.marshal(data)
|
||||
if marshaled:
|
||||
logger.info("RAW CustomizedEvent: %s", marshaled[:500])
|
||||
|
||||
|
||||
def build_event_handler() -> lark.EventDispatcherHandler:
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
_CONFIG_PATH = Path(__file__).parent / "keyring.yaml"
|
||||
|
||||
|
||||
def _load() -> dict:
|
||||
def _load() -> dict[str, Any]:
|
||||
with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
|
||||
@ -23,9 +23,8 @@ METASO_API_KEY: str = _cfg.get("METASO_API_KEY", "")
|
||||
ROUTER_MODE: bool = _cfg.get("ROUTER_MODE", False)
|
||||
ROUTER_SECRET: str = _cfg.get("ROUTER_SECRET", "")
|
||||
|
||||
ALLOWED_OPEN_IDS: List[str] = _cfg.get("ALLOWED_OPEN_IDS", [])
|
||||
if ALLOWED_OPEN_IDS and not isinstance(ALLOWED_OPEN_IDS, list):
|
||||
ALLOWED_OPEN_IDS = [str(ALLOWED_OPEN_IDS)]
|
||||
_allowed_open_ids_raw = _cfg.get("ALLOWED_OPEN_IDS", [])
|
||||
ALLOWED_OPEN_IDS: list[str] = _allowed_open_ids_raw if isinstance(_allowed_open_ids_raw, list) else [str(_allowed_open_ids_raw)]
|
||||
|
||||
|
||||
def is_user_allowed(open_id: str) -> bool:
|
||||
|
||||
@ -46,9 +46,9 @@ class HostConfig:
|
||||
self.metaso_api_key: Optional[str] = data.get("METASO_API_KEY")
|
||||
|
||||
serves_users = data.get("SERVES_USERS", [])
|
||||
self.serves_users: List[str] = serves_users if isinstance(serves_users, list) else []
|
||||
self.serves_users: list[str] = serves_users if isinstance(serves_users, list) else []
|
||||
|
||||
self.capabilities: List[str] = data.get(
|
||||
self.capabilities: list[str] = data.get(
|
||||
"CAPABILITIES",
|
||||
["claude_code", "shell", "file_ops", "web", "scheduler"],
|
||||
)
|
||||
|
||||
@ -10,16 +10,15 @@ import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import websockets
|
||||
from websockets.client import WebSocketClientProtocol
|
||||
|
||||
from agent.manager import manager
|
||||
from agent.scheduler import scheduler
|
||||
from agent.task_runner import task_runner
|
||||
from host_client.config import HostConfig, get_host_config
|
||||
from orchestrator.agent import run as run_mailboy
|
||||
from orchestrator.agent import agent
|
||||
from orchestrator.tools import set_current_user, set_current_chat
|
||||
from shared import (
|
||||
RegisterMessage,
|
||||
@ -40,7 +39,7 @@ class NodeClient:
|
||||
|
||||
def __init__(self, config: HostConfig):
|
||||
self.config = config
|
||||
self.ws: Optional[WebSocketClientProtocol] = None
|
||||
self.ws: Any = None
|
||||
self._running = False
|
||||
self._last_heartbeat = time.time()
|
||||
self._reconnect_delay = 1.0
|
||||
@ -94,7 +93,7 @@ class NodeClient:
|
||||
set_current_chat(request.chat_id)
|
||||
|
||||
try:
|
||||
reply = await run_mailboy(request.user_id, request.text)
|
||||
reply = await agent.run(request.user_id, request.text)
|
||||
|
||||
response = ForwardResponse(
|
||||
id=request.id,
|
||||
@ -131,7 +130,7 @@ class NodeClient:
|
||||
|
||||
sessions = manager.list_sessions()
|
||||
active_sessions = [
|
||||
{"conv_id": s["conv_id"], "working_dir": s["working_dir"]}
|
||||
{"conv_id": s["conv_id"], "working_dir": s["cwd"]}
|
||||
for s in sessions
|
||||
]
|
||||
|
||||
|
||||
2
main.py
2
main.py
@ -91,7 +91,7 @@ async def startup_event() -> None:
|
||||
await manager.start()
|
||||
from agent.scheduler import scheduler
|
||||
await scheduler.start()
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
start_websocket_client(loop)
|
||||
logger.info("PhoneWork started")
|
||||
|
||||
|
||||
@ -97,18 +97,18 @@ class OrchestrationAgent:
|
||||
base_url=OPENAI_BASE_URL,
|
||||
api_key=OPENAI_API_KEY,
|
||||
model=OPENAI_MODEL,
|
||||
temperature=0.0,
|
||||
temperature=0.1,
|
||||
)
|
||||
self._llm_with_tools = llm.bind_tools(TOOLS)
|
||||
|
||||
# user_id -> list[BaseMessage]
|
||||
self._history: Dict[str, List[BaseMessage]] = defaultdict(list)
|
||||
self._history: dict[str, list[BaseMessage]] = defaultdict(list)
|
||||
# user_id -> most recently active conv_id
|
||||
self._active_conv: Dict[str, Optional[str]] = defaultdict(lambda: None)
|
||||
self._active_conv: dict[str, Optional[str]] = defaultdict(lambda: None)
|
||||
# user_id -> asyncio.Lock (prevents concurrent processing per user)
|
||||
self._user_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self._user_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
# user_id -> passthrough mode enabled
|
||||
self._passthrough: Dict[str, bool] = defaultdict(lambda: False)
|
||||
self._passthrough: dict[str, bool] = defaultdict(lambda: False)
|
||||
|
||||
def _build_system_prompt(self, user_id: str) -> str:
|
||||
conv_id = self._active_conv[user_id]
|
||||
@ -173,7 +173,7 @@ class OrchestrationAgent:
|
||||
response = await llm_no_tools.ainvoke([HumanMessage(content=qa_prompt)])
|
||||
return response.content or ""
|
||||
|
||||
messages: List[BaseMessage] = (
|
||||
messages: list[BaseMessage] = (
|
||||
[SystemMessage(content=self._build_system_prompt(user_id))]
|
||||
+ self._history[user_id]
|
||||
+ [HumanMessage(content=text)]
|
||||
|
||||
@ -301,7 +301,8 @@ class FileReadTool(BaseTool):
|
||||
|
||||
async def _arun(self, path: str, start_line: Optional[int] = None, end_line: Optional[int] = None) -> str:
|
||||
try:
|
||||
file_path = _resolve_dir(path)
|
||||
p = Path(path.strip())
|
||||
file_path = _resolve_dir(str(p.parent)) / p.name if not p.is_absolute() else p.resolve()
|
||||
if not file_path.is_file():
|
||||
return json.dumps({"error": f"Not a file: {path}"}, ensure_ascii=False)
|
||||
|
||||
@ -342,7 +343,8 @@ class FileWriteTool(BaseTool):
|
||||
|
||||
async def _arun(self, path: str, content: str, mode: Optional[str] = "overwrite") -> str:
|
||||
try:
|
||||
file_path = _resolve_dir(path)
|
||||
p = Path(path.strip())
|
||||
file_path = (_resolve_dir(str(p.parent)) / p.name) if not p.is_absolute() else p.resolve()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
write_mode = "a" if mode == "append" else "w"
|
||||
@ -484,7 +486,8 @@ class FileSendTool(BaseTool):
|
||||
|
||||
async def _arun(self, path: str) -> str:
|
||||
try:
|
||||
file_path = _resolve_dir(path)
|
||||
p = Path(path.strip())
|
||||
file_path = _resolve_dir(str(p.parent)) / p.name if not p.is_absolute() else p.resolve()
|
||||
if not file_path.is_file():
|
||||
return json.dumps({"error": f"Not a file: {path}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
@ -43,6 +43,7 @@ def create_app(router_secret: Optional[str] = None) -> FastAPI:
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
from router.rpc import get_pending_count
|
||||
nodes = registry.list_nodes()
|
||||
online_nodes = [n for n in nodes if n["status"] == "online"]
|
||||
return {
|
||||
@ -50,7 +51,7 @@ def create_app(router_secret: Optional[str] = None) -> FastAPI:
|
||||
"nodes": nodes,
|
||||
"online_nodes": len(online_nodes),
|
||||
"total_nodes": len(nodes),
|
||||
"pending_requests": 0,
|
||||
"pending_requests": get_pending_count(),
|
||||
}
|
||||
|
||||
@app.get("/nodes")
|
||||
@ -64,7 +65,7 @@ def create_app(router_secret: Optional[str] = None) -> FastAPI:
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
start_websocket_client(loop)
|
||||
logger.info("Router started")
|
||||
|
||||
|
||||
@ -27,18 +27,18 @@ class NodeConnection:
|
||||
display_name: str = ""
|
||||
serves_users: Set[str] = field(default_factory=set)
|
||||
working_dir: str = ""
|
||||
capabilities: List[str] = field(default_factory=list)
|
||||
capabilities: list[str] = field(default_factory=list)
|
||||
connected_at: float = field(default_factory=time.time)
|
||||
last_heartbeat: float = field(default_factory=time.time)
|
||||
sessions: int = 0
|
||||
active_sessions: List[Dict[str, Any]] = field(default_factory=list)
|
||||
active_sessions: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def is_online(self) -> bool:
|
||||
"""Check if node is still considered online (heartbeat within 60s)."""
|
||||
return time.time() - self.last_heartbeat < 60
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize for API responses."""
|
||||
return {
|
||||
"node_id": self.node_id,
|
||||
@ -56,9 +56,9 @@ class NodeRegistry:
|
||||
"""Registry of connected host clients."""
|
||||
|
||||
def __init__(self, router_secret: str = ""):
|
||||
self._nodes: Dict[str, NodeConnection] = {}
|
||||
self._user_nodes: Dict[str, Set[str]] = {}
|
||||
self._active_node: Dict[str, str] = {}
|
||||
self._nodes: dict[str, NodeConnection] = {}
|
||||
self._user_nodes: dict[str, Set[str]] = {}
|
||||
self._active_node: dict[str, str] = {}
|
||||
self._secret = router_secret
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@ -159,7 +159,7 @@ class NodeRegistry:
|
||||
"""Get a node by ID."""
|
||||
return self._nodes.get(node_id)
|
||||
|
||||
def get_nodes_for_user(self, user_id: str) -> List[NodeConnection]:
|
||||
def get_nodes_for_user(self, user_id: str) -> list[NodeConnection]:
|
||||
"""Get all nodes that serve a user."""
|
||||
node_ids = self._user_nodes.get(user_id, set())
|
||||
return [self._nodes[nid] for nid in node_ids if nid in self._nodes]
|
||||
@ -186,11 +186,11 @@ class NodeRegistry:
|
||||
logger.info("Active node for user %s set to %s", user_id, node_id)
|
||||
return True
|
||||
|
||||
def list_nodes(self) -> List[Dict[str, Any]]:
|
||||
def list_nodes(self) -> list[dict[str, Any]]:
|
||||
"""List all nodes with their status."""
|
||||
return [node.to_dict() for node in self._nodes.values()]
|
||||
|
||||
def get_affected_users(self, node_id: str) -> List[str]:
|
||||
def get_affected_users(self, node_id: str) -> list[str]:
|
||||
"""Get users affected by a node disconnect."""
|
||||
node = self._nodes.get(node_id)
|
||||
if node:
|
||||
|
||||
@ -12,6 +12,7 @@ from typing import List, Optional
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
from config import OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL
|
||||
from router.nodes import NodeConnection, get_node_registry
|
||||
@ -36,7 +37,7 @@ Respond with a JSON object:
|
||||
"""
|
||||
|
||||
|
||||
def _format_nodes_info(nodes: List[NodeConnection], active_node_id: Optional[str] = None) -> str:
|
||||
def _format_nodes_info(nodes: list[NodeConnection], active_node_id: Optional[str] = None) -> str:
|
||||
"""Format node information for the routing prompt."""
|
||||
lines = []
|
||||
for node in nodes:
|
||||
@ -86,8 +87,8 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s
|
||||
try:
|
||||
llm = ChatOpenAI(
|
||||
model=OPENAI_MODEL,
|
||||
openai_api_key=OPENAI_API_KEY,
|
||||
openai_api_base=OPENAI_BASE_URL,
|
||||
api_key=SecretStr(OPENAI_API_KEY),
|
||||
base_url=OPENAI_BASE_URL,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
@ -98,7 +99,11 @@ async def route(user_id: str, chat_id: str, text: str) -> tuple[Optional[str], s
|
||||
]
|
||||
|
||||
response = await llm.ainvoke(messages)
|
||||
content = response.content.strip()
|
||||
content = response.content
|
||||
if isinstance(content, str):
|
||||
content = content.strip()
|
||||
else:
|
||||
content = str(content).strip()
|
||||
|
||||
if content.startswith("```"):
|
||||
content = content.split("\n", 1)[1]
|
||||
|
||||
@ -18,7 +18,7 @@ from router.nodes import get_node_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_pending_requests: Dict[str, asyncio.Future] = {}
|
||||
_pending_requests: dict[str, asyncio.Future[str]] = {}
|
||||
_default_timeout = 600.0
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ async def forward(
|
||||
raise RuntimeError(f"Node not connected: {node_id}")
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
future: asyncio.Future = asyncio.get_event_loop().create_future()
|
||||
future: asyncio.Future[str] = asyncio.get_running_loop().create_future()
|
||||
_pending_requests[request_id] = future
|
||||
|
||||
request = ForwardRequest(
|
||||
|
||||
@ -47,7 +47,7 @@ async def ws_node_endpoint(websocket: WebSocket) -> None:
|
||||
return
|
||||
|
||||
node_id: Optional[str] = None
|
||||
heartbeat_task: Optional[asyncio.Task] = None
|
||||
heartbeat_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
async def send_heartbeat():
|
||||
"""Send periodic pings to the host client."""
|
||||
|
||||
@ -16,9 +16,9 @@ class RegisterMessage:
|
||||
"""Host client -> Router: Register this node."""
|
||||
type: str = "register"
|
||||
node_id: str = ""
|
||||
serves_users: List[str] = field(default_factory=list)
|
||||
serves_users: list[str] = field(default_factory=list)
|
||||
working_dir: str = ""
|
||||
capabilities: List[str] = field(default_factory=list)
|
||||
capabilities: list[str] = field(default_factory=list)
|
||||
display_name: str = ""
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ class NodeStatus:
|
||||
type: str = "node_status"
|
||||
node_id: str = ""
|
||||
sessions: int = 0
|
||||
active_sessions: List[Dict[str, Any]] = field(default_factory=list)
|
||||
active_sessions: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
MESSAGE_TYPES = {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user