Yuyao Huang (Sam) 80e4953cf9 feat: 优化WebSocket连接和心跳机制
- 在main.py和standalone.py中添加ws_ping_interval和ws_ping_timeout配置
- 调整ws.py中的心跳发送逻辑,先发送ping再等待
- 在host_client中优化消息处理,使用任务队列处理转发请求
- 更新WebTool以适配新的API格式并增加搜索结果限制
- 在agent.py中添加日期显示和web调用次数限制
- 修复bot/handler.py中的事件循环问题
2026-03-28 15:53:44 +08:00

735 lines
28 KiB
Python

"""LangChain tools that bridge the orchestration agent to Claude Code PTY sessions."""
from __future__ import annotations
import json
import uuid
from contextvars import ContextVar
from pathlib import Path
from typing import Optional, Type
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from agent.manager import manager
from config import WORKING_DIR
# Context-local identity of the requester. The tools below read these values
# through the accessor functions instead of taking them as arguments.
_current_user_id: ContextVar[Optional[str]] = ContextVar("current_user_id", default=None)
_current_chat_id: ContextVar[Optional[str]] = ContextVar("current_chat_id", default=None)


def set_current_user(user_id: Optional[str]) -> None:
    """Record *user_id* as the user of the current context."""
    _current_user_id.set(user_id)


def get_current_user() -> Optional[str]:
    """Return the user id recorded for the current context, or None if unset."""
    return _current_user_id.get()


def set_current_chat(chat_id: Optional[str]) -> None:
    """Record *chat_id* as the chat of the current context."""
    _current_chat_id.set(chat_id)


def get_current_chat() -> Optional[str]:
    """Return the chat id recorded for the current context, or None if unset."""
    return _current_chat_id.get()
def _resolve_dir(working_dir: str) -> Path:
    """
    Resolve working_dir to an absolute path under WORKING_DIR.

    Rules:
    - Absolute paths are used as-is (but must stay within WORKING_DIR for safety).
    - Relative paths / bare names are joined onto WORKING_DIR.
    - Path traversal attempts (..) are blocked.
    - The resolved directory is created if it doesn't exist.

    Raises:
        ValueError: if the path contains a ``..`` segment or resolves to a
            location outside WORKING_DIR.
    """
    working_dir = working_dir.strip()
    # Normalise both separator styles before looking for ".." so that mixed
    # forms such as "a\\../b" are caught as well.
    segments = working_dir.replace("\\", "/").split("/")
    if ".." in segments:
        raise ValueError(
            "Path traversal not allowed. Use a subfolder name or path inside the working directory."
        )
    # Resolve the base too: the candidate below is fully resolved, so comparing
    # it against an unresolved (relative or symlinked) base would reject every
    # path.
    base = Path(WORKING_DIR).resolve()
    candidate = Path(working_dir)
    if not candidate.is_absolute():
        candidate = base / candidate
    candidate = candidate.resolve()
    try:
        candidate.relative_to(base)
    except ValueError:
        raise ValueError(
            f"Directory {candidate} is outside the allowed base {WORKING_DIR}. "
            "Please use a subfolder name or a path inside the working directory."
        ) from None
    candidate.mkdir(parents=True, exist_ok=True)
    return candidate
# ---------------------------------------------------------------------------
# Input schemas
# ---------------------------------------------------------------------------
class CreateConversationInput(BaseModel):
    # Argument schema for the create_conversation tool.
    working_dir: str = Field(
        ...,
        description=(
            "The project directory. Can be a subfolder name (e.g. 'todo_app'), "
            "a relative path (e.g. 'projects/todo_app'), or a full absolute path. "
            "Relative names are resolved under the configured base working directory."
        ),
    )
    initial_message: Optional[str] = Field(
        default=None,
        description="Optional first message to send after spawning",
    )
    idle_timeout: Optional[int] = Field(
        default=None,
        description="Idle timeout in seconds (default 1800)",
    )
    cc_timeout: Optional[float] = Field(
        default=None,
        description="Claude Code execution timeout in seconds (default 300)",
    )
class SendToConversationInput(BaseModel):
    # Argument schema for the send_to_conversation tool.
    conv_id: str = Field(
        ...,
        description="Conversation ID returned by create_conversation",
    )
    message: str = Field(
        ...,
        description="Message / instruction to send to Claude Code",
    )
class CloseConversationInput(BaseModel):
    # Argument schema for the close_conversation tool.
    conv_id: str = Field(
        ...,
        description="Conversation ID to close",
    )
# ---------------------------------------------------------------------------
# Tools
# ---------------------------------------------------------------------------
# Tool that spawns a session through `manager` and optionally sends a first message.
class CreateConversationTool(BaseTool):
    name: str = "create_conversation"
    description: str = (
        "Spawn a new Claude Code session in the given working directory. "
        "Returns a conv_id that must be used for subsequent messages. "
        "Use this when the user wants to start a new task in a specific directory."
    )
    args_schema: Type[BaseModel] = CreateConversationInput

    def _run(
        self,
        working_dir: str,
        initial_message: Optional[str] = None,
        idle_timeout: Optional[int] = None,
        cc_timeout: Optional[float] = None,
    ) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(
        self,
        working_dir: str,
        initial_message: Optional[str] = None,
        idle_timeout: Optional[int] = None,
        cc_timeout: Optional[float] = None,
    ) -> str:
        """Create a session and return a JSON string describing it.

        The JSON holds ``conv_id`` and ``working_dir`` plus either ``response``
        (when an initial message was sent) or ``status``. A directory that
        fails validation yields ``{"error": ...}`` instead.
        """
        try:
            session_dir = str(_resolve_dir(working_dir))
        except ValueError as exc:
            return json.dumps({"error": str(exc)})

        owner = get_current_user()
        # Short id: the first 8 hex digits of a random UUID.
        conv_id = uuid.uuid4().hex[:8]
        await manager.create(
            conv_id,
            session_dir,
            owner_id=owner or "",
            idle_timeout=idle_timeout or 1800,
            cc_timeout=cc_timeout or 300.0,
        )

        payload: dict = {"conv_id": conv_id, "working_dir": session_dir}
        if not initial_message:
            payload["status"] = "Session created. Send a message to start working."
        else:
            payload["response"] = await manager.send(conv_id, initial_message, user_id=owner)
        return json.dumps(payload, ensure_ascii=False)
# Tool that forwards a message to an existing session via `manager.send`.
class SendToConversationTool(BaseTool):
    name: str = "send_to_conversation"
    description: str = (
        "Send a message to an existing Claude Code session and return its response. "
        "Use this for follow-up messages in an ongoing session."
    )
    args_schema: Type[BaseModel] = SendToConversationInput

    def _run(self, conv_id: str, message: str) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self, conv_id: str, message: str) -> str:
        """Send *message* to session *conv_id* and return a JSON string.

        On success the JSON holds ``conv_id`` and ``response``. A KeyError or
        PermissionError from the manager is reported as ``{"error": ...}``.
        """
        sender = get_current_user()
        try:
            reply = await manager.send(conv_id, message, user_id=sender)
        except KeyError:
            return json.dumps({"error": f"No active session for conv_id={conv_id!r}"})
        except PermissionError as e:
            return json.dumps({"error": str(e)})
        else:
            return json.dumps({"conv_id": conv_id, "response": reply}, ensure_ascii=False)
# Tool that reports the sessions `manager` lists for the current user.
class ListConversationsTool(BaseTool):
    name: str = "list_conversations"
    description: str = "List all currently active Claude Code sessions."

    def _run(self) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self) -> str:
        """Return the session list as indented JSON, or a plain notice if empty."""
        active = manager.list_sessions(user_id=get_current_user())
        if active:
            return json.dumps(active, ensure_ascii=False, indent=2)
        return "No active sessions."
# Tool that closes a session through `manager.close`.
class CloseConversationTool(BaseTool):
    name: str = "close_conversation"
    description: str = "Close and terminate an active Claude Code session."
    args_schema: Type[BaseModel] = CloseConversationInput

    def _run(self, conv_id: str) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self, conv_id: str) -> str:
        """Close session *conv_id* and return a plain-text outcome message.

        A PermissionError from the manager is returned as its message text.
        """
        requester = get_current_user()
        try:
            was_closed = await manager.close(conv_id, user_id=requester)
        except PermissionError as e:
            return str(e)
        if not was_closed:
            return f"Session {conv_id!r} not found."
        return f"Session {conv_id!r} closed."
BLOCKED_PATTERNS = [
r'\brm\s+-rf\s+/',
r'\brm\s+-rf\s+~',
r'\bformat\s+',
r'\bmkfs\b',
r'\bshutdown\b',
r'\breboot\b',
r'\bdd\s+if=',
r':\(\)\{:\|:&\};:',
r'\bchmod\s+777\s+/',
r'\bchown\s+.*\s+/',
r'\b>\s*/dev/sd',
r'\bkill\s+-9\s+1\b',
r'\bsudo\s+rm\b',
r'\bsu\s+-c\b',
r'\bsudo\s+chmod\b',
r'\bsudo\s+chown\b',
r'\bsudo\s+dd\b',
r'\b>\s*/dev/null\s+2>&1\s*&\s*;', # fork via backgrounding
]
def _is_command_safe(command: str) -> tuple[bool, str]:
"""Check if command is safe to execute."""
import re
for pattern in BLOCKED_PATTERNS:
if re.search(pattern, command, re.IGNORECASE):
return False, f"Blocked: command matches dangerous pattern"
return True, ""
class ShellInput(BaseModel):
    # Argument schema for the run_shell tool.
    command: str = Field(
        ...,
        description="Shell command to execute",
    )
    cwd: Optional[str] = Field(
        default=None,
        description="Working directory (default: WORKING_DIR)",
    )
    timeout: Optional[int] = Field(
        default=30,
        description="Timeout in seconds (max 120)",
    )
# Tool that runs a shell command under WORKING_DIR and captures its output.
class ShellTool(BaseTool):
    name: str = "run_shell"
    description: str = (
        "Execute a shell command on the host machine and return stdout/stderr. "
        "Use for: git status, ls, cat, grep, pip list, etc. "
        "Destructive commands (rm -rf /, format, shutdown) are blocked."
    )
    args_schema: Type[BaseModel] = ShellInput

    def _run(self, command: str, cwd: Optional[str] = None, timeout: Optional[int] = 30) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self, command: str, cwd: Optional[str] = None, timeout: Optional[int] = 30) -> str:
        """Run *command* through the shell and return a JSON string.

        Args:
            command: Shell command line; rejected if `_is_command_safe` flags it.
            cwd: Optional directory, validated by `_resolve_dir`; defaults to
                WORKING_DIR.
            timeout: Seconds to wait, clamped to the range 1..120 (default 30).

        Returns:
            JSON with ``stdout`` (first 4000 chars), ``stderr`` (first 1000
            chars), ``exit_code`` and ``cwd``, or ``{"error": ...}`` on a
            blocked command, invalid directory, timeout or other failure.
        """
        import asyncio

        is_safe, reason = _is_command_safe(command)
        if not is_safe:
            return json.dumps({"error": reason}, ensure_ascii=False)

        # Clamp to a sane window; a zero/None timeout falls back to 30s.
        timeout = max(1, min(timeout or 30, 120))

        work_dir = WORKING_DIR
        if cwd:
            try:
                work_dir = _resolve_dir(cwd)
            except ValueError as e:
                return json.dumps({"error": str(e)}, ensure_ascii=False)

        proc = None
        try:
            proc = await asyncio.create_subprocess_shell(
                command,
                cwd=str(work_dir),
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, stderr = await asyncio.wait_for(
                proc.communicate(),
                timeout=timeout,
            )
            return json.dumps({
                "stdout": stdout.decode("utf-8", errors="replace")[:4000],
                "stderr": stderr.decode("utf-8", errors="replace")[:1000],
                "exit_code": proc.returncode,
                "cwd": str(work_dir),
            }, ensure_ascii=False)
        except asyncio.TimeoutError:
            # wait_for only cancels communicate(); without an explicit kill the
            # command would keep running after we report the timeout.
            # NOTE(review): this kills the shell process itself; processes the
            # shell spawned may outlive it - confirm whether that matters here.
            if proc is not None and proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass  # exited between the timeout and the kill
                await proc.wait()
            return json.dumps({"error": f"Command timed out after {timeout}s"}, ensure_ascii=False)
        except Exception as e:
            return json.dumps({"error": str(e)}, ensure_ascii=False)
class FileReadInput(BaseModel):
    # Argument schema for the read_file tool.
    path: str = Field(
        ...,
        description="File path relative to working directory",
    )
    start_line: Optional[int] = Field(
        default=None,
        description="Start line (1-indexed)",
    )
    end_line: Optional[int] = Field(
        default=None,
        description="End line (inclusive)",
    )
# Tool that returns a line-numbered slice of a file located under WORKING_DIR.
class FileReadTool(BaseTool):
    name: str = "read_file"
    description: str = "Read a file from the working directory. Returns file content with line numbers."
    args_schema: Type[BaseModel] = FileReadInput

    def _run(self, path: str, start_line: Optional[int] = None, end_line: Optional[int] = None) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self, path: str, start_line: Optional[int] = None, end_line: Optional[int] = None) -> str:
        """Read *path* and return a JSON string with numbered lines.

        Args:
            path: File path; its parent directory is validated by
                `_resolve_dir`, for relative and absolute paths alike.
            start_line: First line to return (1-indexed, default 1).
            end_line: Last line to return (inclusive, default end of file).

        Returns:
            JSON with ``path``, ``lines`` (the requested range),
            ``total_lines`` and ``content``, or ``{"error": ...}``.
        """
        try:
            p = Path(path.strip())
            # Always validate the parent directory. Absolute paths used to skip
            # this check, which let the tool read files outside WORKING_DIR.
            file_path = _resolve_dir(str(p.parent)) / p.name
            if not file_path.is_file():
                return json.dumps({"error": f"Not a file: {path}"}, ensure_ascii=False)
            with open(file_path, "r", encoding="utf-8", errors="replace") as f:
                lines = f.readlines()
            total_lines = len(lines)
            start = max(1, start_line or 1) - 1
            end = min(total_lines, end_line or total_lines)
            result_lines = [
                f"{i+1:4d} | {lines[i].rstrip()}"
                for i in range(start, end)
            ]
            return json.dumps({
                "path": str(file_path),
                "lines": f"{start+1}-{end}",
                "total_lines": total_lines,
                # Only the last 500 numbered lines of the range are returned.
                "content": "\n".join(result_lines[-500:]),
            }, ensure_ascii=False)
        except Exception as e:
            return json.dumps({"error": str(e)}, ensure_ascii=False)
class FileWriteInput(BaseModel):
    # Argument schema for the write_file tool.
    path: str = Field(
        ...,
        description="File path relative to working directory",
    )
    content: str = Field(
        ...,
        description="Content to write",
    )
    mode: Optional[str] = Field(
        default="overwrite",
        description="Write mode: 'overwrite' or 'append'",
    )
# Tool that writes or appends text to a file located under WORKING_DIR.
class FileWriteTool(BaseTool):
    name: str = "write_file"
    description: str = "Write content to a file in the working directory. Use mode='append' to add to existing file."
    args_schema: Type[BaseModel] = FileWriteInput

    def _run(self, path: str, content: str, mode: Optional[str] = "overwrite") -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self, path: str, content: str, mode: Optional[str] = "overwrite") -> str:
        """Write *content* to *path* and return a JSON string.

        Args:
            path: File path; its parent directory is validated by
                `_resolve_dir`, for relative and absolute paths alike.
            content: Text to write, encoded as UTF-8.
            mode: ``"append"`` appends; any other value overwrites.

        Returns:
            JSON with ``success``, ``path`` and ``bytes_written``, or
            ``{"error": ...}``.
        """
        try:
            p = Path(path.strip())
            # Always validate the parent directory. Absolute paths used to skip
            # this check, which let the tool write files outside WORKING_DIR.
            file_path = _resolve_dir(str(p.parent)) / p.name
            file_path.parent.mkdir(parents=True, exist_ok=True)
            write_mode = "a" if mode == "append" else "w"
            with open(file_path, write_mode, encoding="utf-8") as f:
                f.write(content)
            return json.dumps({
                "success": True,
                "path": str(file_path),
                "bytes_written": len(content.encode("utf-8")),
            }, ensure_ascii=False)
        except Exception as e:
            return json.dumps({"error": str(e)}, ensure_ascii=False)
class FileListInput(BaseModel):
    # Argument schema for the list_files tool.
    path: Optional[str] = Field(
        default=None,
        description="Directory path (default: working directory)",
    )
    pattern: Optional[str] = Field(
        default=None,
        description="Glob pattern (e.g. '*.py')",
    )
# Tool that lists up to 100 entries of a directory under WORKING_DIR.
class FileListTool(BaseTool):
    name: str = "list_files"
    description: str = "List files in a directory. Use pattern to filter (e.g. '*.py')."
    args_schema: Type[BaseModel] = FileListInput

    def _run(self, path: Optional[str] = None, pattern: Optional[str] = None) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self, path: Optional[str] = None, pattern: Optional[str] = None) -> str:
        """List a directory and return a JSON string.

        Args:
            path: Directory validated by `_resolve_dir` (default ".").
            pattern: Optional glob applied inside the directory.

        Returns:
            JSON with ``path`` and ``files`` (each entry has ``name``, ``type``
            and ``size``; size is None for non-files), or ``{"error": ...}``.
        """
        from itertools import islice

        try:
            dir_path = _resolve_dir(path or ".")
            if not dir_path.is_dir():
                return json.dumps({"error": f"Not a directory: {path}"}, ensure_ascii=False)
            entries = dir_path.glob(pattern) if pattern else dir_path.iterdir()
            # Take the first 100 lazily instead of materialising the whole
            # listing and then slicing it; the selected entries are the same.
            files = list(islice(entries, 100))
            result = [
                {
                    "name": f.name,
                    "type": "dir" if f.is_dir() else "file",
                    "size": f.stat().st_size if f.is_file() else None,
                }
                for f in sorted(files)
            ]
            return json.dumps({
                "path": str(dir_path),
                "files": result,
            }, ensure_ascii=False)
        except Exception as e:
            return json.dumps({"error": str(e)}, ensure_ascii=False)
class FileSearchInput(BaseModel):
    # Argument schema for the search_files tool.
    path: str = Field(
        ...,
        description="Directory path to search in",
    )
    pattern: str = Field(
        ...,
        description="Search pattern (regex supported)",
    )
    max_results: Optional[int] = Field(
        default=50,
        description="Max number of results",
    )
# Tool that greps text files below a directory under WORKING_DIR.
class FileSearchTool(BaseTool):
    name: str = "search_files"
    description: str = (
        "Search for text pattern in files under a directory (grep-like). "
        "Returns matching lines with file paths and line numbers."
    )
    args_schema: Type[BaseModel] = FileSearchInput

    def _run(self, path: str, pattern: str, max_results: Optional[int] = 50) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self, path: str, pattern: str, max_results: Optional[int] = 50) -> str:
        """Search files under *path* for *pattern* and return a JSON string.

        Args:
            path: Directory validated by `_resolve_dir`.
            pattern: Regular expression, matched case-insensitively per line.
            max_results: Maximum number of matches; None or a non-positive
                value falls back to 50.

        Returns:
            JSON with ``path``, ``pattern``, ``total_matches`` and ``results``
            (``file`` relative to the search root, ``line``, ``content``
            truncated to 200 chars), or ``{"error": ...}``.
        """
        import re

        try:
            dir_path = _resolve_dir(path)
            if not dir_path.is_dir():
                return json.dumps({"error": f"Not a directory: {path}"}, ensure_ascii=False)
            try:
                regex = re.compile(pattern, re.IGNORECASE)
            except re.error as e:
                return json.dumps({"error": f"Invalid regex pattern: {e}"}, ensure_ascii=False)

            # A None limit used to raise TypeError inside the per-file handler
            # below, where it was silently swallowed.
            limit = max_results if max_results and max_results > 0 else 50

            text_extensions = {'.py', '.js', '.ts', '.tsx', '.jsx', '.java', '.c', '.cpp', '.h',
                               '.go', '.rs', '.rb', '.php', '.cs', '.swift', '.kt', '.scala',
                               '.txt', '.md', '.json', '.yaml', '.yml', '.toml', '.ini', '.cfg',
                               '.html', '.css', '.scss', '.sass', '.less', '.xml', '.sql',
                               '.sh', '.bash', '.zsh', '.ps1', '.bat', '.cmd'}
            results = []
            for file_path in dir_path.rglob("*"):
                if len(results) >= limit:
                    break
                if not file_path.is_file():
                    continue
                if file_path.suffix.lower() not in text_extensions:
                    continue
                rel_path = file_path.relative_to(dir_path)
                # Judge "hidden" on the path below the search root only; looking
                # at the absolute path skipped every file whenever some parent
                # of the root itself started with a dot.
                if any(part.startswith('.') for part in rel_path.parts):
                    continue
                try:
                    with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
                        for line_num, line in enumerate(f, 1):
                            if not regex.search(line):
                                continue
                            results.append({
                                "file": str(rel_path),
                                "line": line_num,
                                "content": line.rstrip()[:200],
                            })
                            if len(results) >= limit:
                                break
                except OSError:
                    # Unreadable file: skip it and keep searching the rest.
                    continue
            return json.dumps({
                "path": str(dir_path),
                "pattern": pattern,
                "total_matches": len(results),
                "results": results,
            }, ensure_ascii=False)
        except Exception as e:
            return json.dumps({"error": str(e)}, ensure_ascii=False)
class FileSendInput(BaseModel):
    # Argument schema for the send_file tool.
    path: str = Field(
        ...,
        description="File path to send",
    )
# Tool that sends a file under WORKING_DIR to the current chat via bot.feishu.
class FileSendTool(BaseTool):
    name: str = "send_file"
    description: str = "Send a file to the user via Feishu. Returns confirmation message."
    args_schema: Type[BaseModel] = FileSendInput

    def _run(self, path: str) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self, path: str) -> str:
        """Send the file at *path* to the current chat and return a JSON string.

        Args:
            path: File path; its parent directory is validated by
                `_resolve_dir`, for relative and absolute paths alike.

        Returns:
            JSON with ``success``, ``path``, ``size`` and ``message``, or
            ``{"error": ...}`` when the path is not a file, no chat id is set
            in the current context, or sending fails.
        """
        try:
            p = Path(path.strip())
            # Always validate the parent directory. Absolute paths used to skip
            # this check, which let the tool send files outside WORKING_DIR.
            file_path = _resolve_dir(str(p.parent)) / p.name
            if not file_path.is_file():
                return json.dumps({"error": f"Not a file: {path}"}, ensure_ascii=False)
            chat_id = get_current_chat()
            if not chat_id:
                return json.dumps({"error": "No chat context available"}, ensure_ascii=False)
            from bot.feishu import send_file
            await send_file(chat_id, "chat_id", str(file_path))
            return json.dumps({
                "success": True,
                "path": str(file_path),
                "size": file_path.stat().st_size,
                "message": f"File sent: {file_path.name}",
            }, ensure_ascii=False)
        except Exception as e:
            return json.dumps({"error": str(e)}, ensure_ascii=False)
class WebInput(BaseModel):
    # Argument schema for the web tool.
    action: str = Field(
        ...,
        description="Action: 'search', 'fetch', or 'ask'",
    )
    query: Optional[str] = Field(
        default=None,
        description="Search query or question",
    )
    url: Optional[str] = Field(
        default=None,
        description="URL to fetch (for 'fetch' action)",
    )
    scope: Optional[str] = Field(
        default="webpage",
        description="Search scope: webpage, paper, document, video, podcast",
    )
    max_chars: Optional[int] = Field(
        default=2000,
        description="Max characters in response",
    )
def _metaso_first_text(data: dict) -> str:
    """Return the ``text`` of the first ``result.content`` item, or "" if absent."""
    result = data.get("result")
    content = result.get("content") if isinstance(result, dict) else None
    if not isinstance(content, list) or not content:
        return ""
    first = content[0]
    if not isinstance(first, dict):
        return ""
    return first.get("text") or ""


# Tool that posts JSON-RPC "tools/call" requests to the Metaso endpoint.
class WebTool(BaseTool):
    name: str = "web"
    description: str = (
        "Search the web, fetch URLs, or ask questions using 秘塔AI Search. "
        "Actions: 'search' (web search), 'fetch' (extract content from URL), 'ask' (RAG Q&A). "
        "Requires METASO_API_KEY in keyring.yaml."
    )
    args_schema: Type[BaseModel] = WebInput

    def _run(self, action: str, query: Optional[str] = None, url: Optional[str] = None,
             scope: Optional[str] = "webpage", max_chars: Optional[int] = 2000) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self, action: str, query: Optional[str] = None, url: Optional[str] = None,
                    scope: Optional[str] = "webpage", max_chars: Optional[int] = 2000) -> str:
        """Perform *action* against the Metaso endpoint and return a JSON string.

        Args:
            action: ``"search"``, ``"fetch"`` or ``"ask"``.
            query: Required for ``search`` and ``ask``.
            url: Required for ``fetch``.
            scope: Search scope passed through for ``search``.
            max_chars: Upper bound on the returned text length.

        Returns:
            ``{"total", "results"}`` for search, ``{"content"}`` for fetch,
            ``{"answer"}`` for ask, or ``{"error": ...}`` on missing
            configuration or arguments, an unknown action, an error reported
            in the response, a timeout, or any other failure.
        """
        from config import METASO_API_KEY
        if not METASO_API_KEY:
            return json.dumps({"error": "METASO_API_KEY not configured. Add it to keyring.yaml."}, ensure_ascii=False)
        import httpx

        base_url = "https://metaso.cn/api/mcp"
        headers = {
            "Authorization": f"Bearer {METASO_API_KEY}",
            "Content-Type": "application/json",
        }

        # Map the action to a remote tool name and its arguments; the three
        # actions differ only in this and in how the reply text is packaged.
        if action == "search":
            if not query:
                return json.dumps({"error": "query required for search"}, ensure_ascii=False)
            tool_name = "metaso_web_search"
            arguments = {"q": query, "scope": scope or "webpage", "size": 5, "includeSummary": True}
        elif action == "fetch":
            if not url:
                return json.dumps({"error": "url required for fetch"}, ensure_ascii=False)
            tool_name = "metaso_web_reader"
            arguments = {"url": url, "format": "markdown"}
        elif action == "ask":
            if not query:
                return json.dumps({"error": "query required for ask"}, ensure_ascii=False)
            tool_name = "metaso_chat"
            arguments = {"message": query}
        else:
            return json.dumps({"error": f"Unknown action: {action}"}, ensure_ascii=False)

        payload = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/call",
            "params": {"name": tool_name, "arguments": arguments},
        }
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                resp = await client.post(base_url, json=payload, headers=headers)
                data = resp.json()
            if "error" in data:
                return json.dumps({"error": data["error"]}, ensure_ascii=False)

            # An empty or malformed content list yields "" here instead of an
            # opaque "list index out of range" error.
            content_text = _metaso_first_text(data)
            if action == "fetch":
                return json.dumps({"content": content_text[:max_chars]}, ensure_ascii=False)
            if action == "ask":
                return json.dumps({"answer": content_text[:max_chars]}, ensure_ascii=False)

            # search: the text item is itself a JSON document with the hits.
            result_data = json.loads(content_text) if content_text else {}
            webpages = result_data.get("webpages", [])[:5]
            output = []
            for r in webpages:
                date = r.get("date", "")
                title = r.get("title", "No title")
                snippet = (r.get("snippet") or "")[:300]
                link = r.get("link", "")
                output.append(f"[{date}] **{title}**\n{snippet}\n{link}")
            total = result_data.get("total", 0)
            return json.dumps({
                "total": total,
                "results": "\n\n".join(output)[:max_chars],
            }, ensure_ascii=False)
        except httpx.TimeoutException:
            return json.dumps({"error": "Request timed out"}, ensure_ascii=False)
        except Exception as e:
            return json.dumps({"error": str(e)}, ensure_ascii=False)
class SchedulerInput(BaseModel):
    # Argument schema for the scheduler tool.
    action: str = Field(
        ...,
        description="Action: 'remind' or 'repeat'",
    )
    delay_seconds: Optional[int] = Field(
        default=None,
        description="Delay in seconds (for 'remind')",
    )
    interval_seconds: Optional[int] = Field(
        default=None,
        description="Interval in seconds (for 'repeat')",
    )
    message: str = Field(
        ...,
        description="Reminder message",
    )
    max_runs: Optional[int] = Field(
        default=5,
        description="Max runs for recurring (default 5)",
    )
# Tool that registers one-off or recurring reminders with agent.scheduler.
class SchedulerTool(BaseTool):
    name: str = "scheduler"
    description: str = (
        "Schedule reminders. Use 'remind' for one-time, 'repeat' for recurring. "
        "Notifications sent to current chat."
    )
    args_schema: Type[BaseModel] = SchedulerInput

    def _run(self, action: str, message: str, delay_seconds: Optional[int] = None,
             interval_seconds: Optional[int] = None, max_runs: Optional[int] = 5) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self, action: str, message: str, delay_seconds: Optional[int] = None,
                    interval_seconds: Optional[int] = None, max_runs: Optional[int] = 5) -> str:
        """Schedule a reminder and return a JSON string.

        Args:
            action: ``"remind"`` (one-time) or ``"repeat"`` (recurring).
            message: Reminder text handed to the scheduler.
            delay_seconds: Required, positive, for ``remind``.
            interval_seconds: Required, positive, for ``repeat``.
            max_runs: Number of runs for ``repeat``; None or 0 means 5.

        Returns:
            JSON with ``success``, ``job_id`` and ``message``, or
            ``{"error": ...}`` for missing or invalid arguments or an unknown
            action.
        """
        from agent.scheduler import scheduler
        chat_id = get_current_chat()

        if action == "remind":
            if not delay_seconds:
                return json.dumps({"error": "delay_seconds required for remind"}, ensure_ascii=False)
            if delay_seconds < 0:
                return json.dumps({"error": "delay_seconds must be positive"}, ensure_ascii=False)
            job_id = await scheduler.schedule_once(
                delay_seconds=delay_seconds,
                message=message,
                notify_chat_id=chat_id,
            )
            return json.dumps({
                "success": True,
                "job_id": job_id,
                "message": f"Reminder set for {delay_seconds}s from now",
            }, ensure_ascii=False)

        if action == "repeat":
            if not interval_seconds:
                return json.dumps({"error": "interval_seconds required for repeat"}, ensure_ascii=False)
            if interval_seconds < 0:
                return json.dumps({"error": "interval_seconds must be positive"}, ensure_ascii=False)
            # Report the value actually scheduled; the raw argument may be
            # None or 0, which used to produce "(None times)" in the message.
            runs = max_runs or 5
            job_id = await scheduler.schedule_recurring(
                interval_seconds=interval_seconds,
                message=message,
                max_runs=runs,
                notify_chat_id=chat_id,
            )
            return json.dumps({
                "success": True,
                "job_id": job_id,
                "message": f"Recurring reminder set every {interval_seconds}s ({runs} times)",
            }, ensure_ascii=False)

        return json.dumps({"error": f"Unknown action: {action}"}, ensure_ascii=False)
class TaskStatusInput(BaseModel):
    # Argument schema for the task_status tool.
    task_id: str = Field(
        ...,
        description="Task ID to check",
    )
# Tool that reports on a task looked up through agent.task_runner.
class TaskStatusTool(BaseTool):
    name: str = "task_status"
    description: str = "Check the status of a background task. Returns current status and result if completed."
    args_schema: Type[BaseModel] = TaskStatusInput

    def _run(self, task_id: str) -> str:
        """Synchronous entry point; not supported, always raises."""
        raise NotImplementedError("Use async version")

    async def _arun(self, task_id: str) -> str:
        """Return a JSON snapshot of task *task_id*, or ``{"error": ...}`` if unknown.

        The ``result`` field is cut to its first 500 characters.
        """
        from agent.task_runner import task_runner

        task = task_runner.get_task(task_id)
        if not task:
            return json.dumps({"error": f"Task {task_id} not found"}, ensure_ascii=False)

        snapshot = {
            "task_id": task.task_id,
            "description": task.description,
            "status": task.status.value,
            "elapsed": int(task.elapsed),
            "started_at": task.started_at,
            "completed_at": task.completed_at,
            "result": task.result[:500] if task.result else None,
            "error": task.error,
        }
        return json.dumps(snapshot, ensure_ascii=False)
# One instance of every tool defined above, in declaration order, exposed at
# module level so callers can import the whole set at once.
TOOLS = [
    tool_cls()
    for tool_cls in (
        CreateConversationTool,
        SendToConversationTool,
        ListConversationsTool,
        CloseConversationTool,
        ShellTool,
        FileReadTool,
        FileWriteTool,
        FileListTool,
        FileSearchTool,
        FileSendTool,
        WebTool,
        SchedulerTool,
        TaskStatusTool,
    )
]