"""LangChain tools that bridge the orchestration agent to Claude Code PTY sessions.""" from __future__ import annotations import json import uuid from contextvars import ContextVar from pathlib import Path from typing import Optional, Type from langchain_core.tools import BaseTool from pydantic import BaseModel, Field from agent.manager import manager from config import WORKING_DIR _current_user_id: ContextVar[Optional[str]] = ContextVar("current_user_id", default=None) _current_chat_id: ContextVar[Optional[str]] = ContextVar("current_chat_id", default=None) def set_current_user(user_id: Optional[str]) -> None: _current_user_id.set(user_id) def get_current_user() -> Optional[str]: return _current_user_id.get() def set_current_chat(chat_id: Optional[str]) -> None: _current_chat_id.set(chat_id) def get_current_chat() -> Optional[str]: return _current_chat_id.get() def _resolve_dir(working_dir: str) -> Path: """ Resolve working_dir to an absolute path under WORKING_DIR. Rules: - Absolute paths are used as-is (but must stay within WORKING_DIR for safety). - Relative paths / bare names are joined onto WORKING_DIR. - Path traversal attempts (..) are blocked. - The resolved directory is created if it doesn't exist. """ working_dir = working_dir.strip() if ".." in working_dir.split("/") or ".." in working_dir.split("\\"): raise ValueError( "Path traversal not allowed. Use a subfolder name or path inside the working directory." ) p = Path(working_dir) if not p.is_absolute(): p = WORKING_DIR / p p = p.resolve() try: p.relative_to(WORKING_DIR) except ValueError: raise ValueError( f"Directory {p} is outside the allowed base {WORKING_DIR}. " "Please use a subfolder name or a path inside the working directory." ) p.mkdir(parents=True, exist_ok=True) return p # --------------------------------------------------------------------------- # Input schemas # --------------------------------------------------------------------------- class CreateConversationInput(BaseModel): working_dir: str = Field( ..., description=( "The project directory. Can be a subfolder name (e.g. 'todo_app'), " "a relative path (e.g. 'projects/todo_app'), or a full absolute path. " "Relative names are resolved under the configured base working directory." ), ) initial_message: Optional[str] = Field(None, description="Optional first message to send after spawning") idle_timeout: Optional[int] = Field(None, description="Idle timeout in seconds (default 1800)") cc_timeout: Optional[float] = Field(None, description="Claude Code execution timeout in seconds (default 300)") class SendToConversationInput(BaseModel): conv_id: str = Field(..., description="Conversation ID returned by create_conversation") message: str = Field(..., description="Message / instruction to send to Claude Code") class CloseConversationInput(BaseModel): conv_id: str = Field(..., description="Conversation ID to close") # --------------------------------------------------------------------------- # Tools # --------------------------------------------------------------------------- class CreateConversationTool(BaseTool): name: str = "create_conversation" description: str = ( "Spawn a new Claude Code session in the given working directory. " "Returns a conv_id that must be used for subsequent messages. " "Use this when the user wants to start a new task in a specific directory." ) args_schema: Type[BaseModel] = CreateConversationInput def _run(self, working_dir: str, initial_message: Optional[str] = None, idle_timeout: Optional[int] = None, cc_timeout: Optional[float] = None) -> str: raise NotImplementedError("Use async version") async def _arun(self, working_dir: str, initial_message: Optional[str] = None, idle_timeout: Optional[int] = None, cc_timeout: Optional[float] = None) -> str: try: resolved = _resolve_dir(working_dir) except ValueError as exc: return json.dumps({"error": str(exc)}) user_id = get_current_user() conv_id = str(uuid.uuid4())[:8] await manager.create( conv_id, str(resolved), owner_id=user_id or "", idle_timeout=idle_timeout or 1800, cc_timeout=cc_timeout or 300.0, ) result: dict = { "conv_id": conv_id, "working_dir": str(resolved), } if initial_message: output = await manager.send(conv_id, initial_message, user_id=user_id) result["response"] = output else: result["status"] = "Session created. Send a message to start working." return json.dumps(result, ensure_ascii=False) class SendToConversationTool(BaseTool): name: str = "send_to_conversation" description: str = ( "Send a message to an existing Claude Code session and return its response. " "Use this for follow-up messages in an ongoing session." ) args_schema: Type[BaseModel] = SendToConversationInput def _run(self, conv_id: str, message: str) -> str: raise NotImplementedError("Use async version") async def _arun(self, conv_id: str, message: str) -> str: user_id = get_current_user() try: output = await manager.send(conv_id, message, user_id=user_id) return json.dumps({"conv_id": conv_id, "response": output}, ensure_ascii=False) except KeyError: return json.dumps({"error": f"No active session for conv_id={conv_id!r}"}) except PermissionError as e: return json.dumps({"error": str(e)}) class ListConversationsTool(BaseTool): name: str = "list_conversations" description: str = "List all currently active Claude Code sessions." def _run(self) -> str: raise NotImplementedError("Use async version") async def _arun(self) -> str: user_id = get_current_user() sessions = manager.list_sessions(user_id=user_id) if not sessions: return "No active sessions." return json.dumps(sessions, ensure_ascii=False, indent=2) class CloseConversationTool(BaseTool): name: str = "close_conversation" description: str = "Close and terminate an active Claude Code session." args_schema: Type[BaseModel] = CloseConversationInput def _run(self, conv_id: str) -> str: raise NotImplementedError("Use async version") async def _arun(self, conv_id: str) -> str: user_id = get_current_user() try: closed = await manager.close(conv_id, user_id=user_id) if closed: return f"Session {conv_id!r} closed." return f"Session {conv_id!r} not found." except PermissionError as e: return str(e) BLOCKED_PATTERNS = [ r'\brm\s+-rf\s+/', r'\brm\s+-rf\s+~', r'\bformat\s+', r'\bmkfs\b', r'\bshutdown\b', r'\breboot\b', r'\bdd\s+if=', r':\(\)\{:\|:&\};:', r'\bchmod\s+777\s+/', r'\bchown\s+.*\s+/', r'\b>\s*/dev/sd', r'\bkill\s+-9\s+1\b', r'\bsudo\s+rm\b', r'\bsu\s+-c\b', r'\bsudo\s+chmod\b', r'\bsudo\s+chown\b', r'\bsudo\s+dd\b', r'\b>\s*/dev/null\s+2>&1\s*&\s*;', # fork via backgrounding ] def _is_command_safe(command: str) -> tuple[bool, str]: """Check if command is safe to execute.""" import re for pattern in BLOCKED_PATTERNS: if re.search(pattern, command, re.IGNORECASE): return False, f"Blocked: command matches dangerous pattern" return True, "" class ShellInput(BaseModel): command: str = Field(..., description="Shell command to execute") cwd: Optional[str] = Field(None, description="Working directory (default: WORKING_DIR)") timeout: Optional[int] = Field(30, description="Timeout in seconds (max 120)") class ShellTool(BaseTool): name: str = "run_shell" description: str = ( "Execute a shell command on the host machine and return stdout/stderr. " "Use for: git status, ls, cat, grep, pip list, etc. " "Destructive commands (rm -rf /, format, shutdown) are blocked." ) args_schema: Type[BaseModel] = ShellInput def _run(self, command: str, cwd: Optional[str] = None, timeout: Optional[int] = 30) -> str: raise NotImplementedError("Use async version") async def _arun(self, command: str, cwd: Optional[str] = None, timeout: Optional[int] = 30) -> str: import asyncio import shutil is_safe, reason = _is_command_safe(command) if not is_safe: return json.dumps({"error": reason}, ensure_ascii=False) timeout = min(timeout or 30, 120) work_dir = WORKING_DIR if cwd: try: work_dir = _resolve_dir(cwd) except ValueError as e: return json.dumps({"error": str(e)}, ensure_ascii=False) try: proc = await asyncio.create_subprocess_shell( command, cwd=str(work_dir), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await asyncio.wait_for( proc.communicate(), timeout=timeout, ) return json.dumps({ "stdout": stdout.decode("utf-8", errors="replace")[:4000], "stderr": stderr.decode("utf-8", errors="replace")[:1000], "exit_code": proc.returncode, "cwd": str(work_dir), }, ensure_ascii=False) except asyncio.TimeoutError: return json.dumps({"error": f"Command timed out after {timeout}s"}, ensure_ascii=False) except Exception as e: return json.dumps({"error": str(e)}, ensure_ascii=False) class FileReadInput(BaseModel): path: str = Field(..., description="File path relative to working directory") start_line: Optional[int] = Field(None, description="Start line (1-indexed)") end_line: Optional[int] = Field(None, description="End line (inclusive)") class FileReadTool(BaseTool): name: str = "read_file" description: str = "Read a file from the working directory. Returns file content with line numbers." args_schema: Type[BaseModel] = FileReadInput def _run(self, path: str, start_line: Optional[int] = None, end_line: Optional[int] = None) -> str: raise NotImplementedError("Use async version") async def _arun(self, path: str, start_line: Optional[int] = None, end_line: Optional[int] = None) -> str: try: file_path = _resolve_dir(path) if not file_path.is_file(): return json.dumps({"error": f"Not a file: {path}"}, ensure_ascii=False) with open(file_path, "r", encoding="utf-8", errors="replace") as f: lines = f.readlines() total_lines = len(lines) start = max(1, start_line or 1) - 1 end = min(total_lines, end_line or total_lines) result_lines = [] for i in range(start, end): result_lines.append(f"{i+1:4d} | {lines[i].rstrip()}") return json.dumps({ "path": str(file_path), "lines": f"{start+1}-{end}", "total_lines": total_lines, "content": "\n".join(result_lines[-500:]), }, ensure_ascii=False) except Exception as e: return json.dumps({"error": str(e)}, ensure_ascii=False) class FileWriteInput(BaseModel): path: str = Field(..., description="File path relative to working directory") content: str = Field(..., description="Content to write") mode: Optional[str] = Field("overwrite", description="Write mode: 'overwrite' or 'append'") class FileWriteTool(BaseTool): name: str = "write_file" description: str = "Write content to a file in the working directory. Use mode='append' to add to existing file." args_schema: Type[BaseModel] = FileWriteInput def _run(self, path: str, content: str, mode: Optional[str] = "overwrite") -> str: raise NotImplementedError("Use async version") async def _arun(self, path: str, content: str, mode: Optional[str] = "overwrite") -> str: try: file_path = _resolve_dir(path) file_path.parent.mkdir(parents=True, exist_ok=True) write_mode = "a" if mode == "append" else "w" with open(file_path, write_mode, encoding="utf-8") as f: f.write(content) return json.dumps({ "success": True, "path": str(file_path), "bytes_written": len(content.encode("utf-8")), }, ensure_ascii=False) except Exception as e: return json.dumps({"error": str(e)}, ensure_ascii=False) class FileListInput(BaseModel): path: Optional[str] = Field(None, description="Directory path (default: working directory)") pattern: Optional[str] = Field(None, description="Glob pattern (e.g. '*.py')") class FileListTool(BaseTool): name: str = "list_files" description: str = "List files in a directory. Use pattern to filter (e.g. '*.py')." args_schema: Type[BaseModel] = FileListInput def _run(self, path: Optional[str] = None, pattern: Optional[str] = None) -> str: raise NotImplementedError("Use async version") async def _arun(self, path: Optional[str] = None, pattern: Optional[str] = None) -> str: try: dir_path = _resolve_dir(path or ".") if not dir_path.is_dir(): return json.dumps({"error": f"Not a directory: {path}"}, ensure_ascii=False) if pattern: files = list(dir_path.glob(pattern))[:100] else: files = list(dir_path.iterdir())[:100] result = [] for f in sorted(files): result.append({ "name": f.name, "type": "dir" if f.is_dir() else "file", "size": f.stat().st_size if f.is_file() else None, }) return json.dumps({ "path": str(dir_path), "files": result, }, ensure_ascii=False) except Exception as e: return json.dumps({"error": str(e)}, ensure_ascii=False) class FileSearchInput(BaseModel): path: str = Field(..., description="Directory path to search in") pattern: str = Field(..., description="Search pattern (regex supported)") max_results: Optional[int] = Field(50, description="Max number of results") class FileSearchTool(BaseTool): name: str = "search_files" description: str = ( "Search for text pattern in files under a directory (grep-like). " "Returns matching lines with file paths and line numbers." ) args_schema: Type[BaseModel] = FileSearchInput def _run(self, path: str, pattern: str, max_results: Optional[int] = 50) -> str: raise NotImplementedError("Use async version") async def _arun(self, path: str, pattern: str, max_results: Optional[int] = 50) -> str: import re try: dir_path = _resolve_dir(path) if not dir_path.is_dir(): return json.dumps({"error": f"Not a directory: {path}"}, ensure_ascii=False) try: regex = re.compile(pattern, re.IGNORECASE) except re.error as e: return json.dumps({"error": f"Invalid regex pattern: {e}"}, ensure_ascii=False) results = [] text_extensions = {'.py', '.js', '.ts', '.tsx', '.jsx', '.java', '.c', '.cpp', '.h', '.go', '.rs', '.rb', '.php', '.cs', '.swift', '.kt', '.scala', '.txt', '.md', '.json', '.yaml', '.yml', '.toml', '.ini', '.cfg', '.html', '.css', '.scss', '.sass', '.less', '.xml', '.sql', '.sh', '.bash', '.zsh', '.ps1', '.bat', '.cmd'} for file_path in dir_path.rglob("*"): if not file_path.is_file(): continue if file_path.suffix.lower() not in text_extensions: continue if any(part.startswith('.') for part in file_path.parts): continue try: with open(file_path, "r", encoding="utf-8", errors="ignore") as f: for line_num, line in enumerate(f, 1): if regex.search(line): rel_path = file_path.relative_to(dir_path) results.append({ "file": str(rel_path), "line": line_num, "content": line.rstrip()[:200], }) if len(results) >= max_results: break if len(results) >= max_results: break except Exception: continue return json.dumps({ "path": str(dir_path), "pattern": pattern, "total_matches": len(results), "results": results, }, ensure_ascii=False) except Exception as e: return json.dumps({"error": str(e)}, ensure_ascii=False) class FileSendInput(BaseModel): path: str = Field(..., description="File path to send") class FileSendTool(BaseTool): name: str = "send_file" description: str = "Send a file to the user via Feishu. Returns confirmation message." args_schema: Type[BaseModel] = FileSendInput def _run(self, path: str) -> str: raise NotImplementedError("Use async version") async def _arun(self, path: str) -> str: try: file_path = _resolve_dir(path) if not file_path.is_file(): return json.dumps({"error": f"Not a file: {path}"}, ensure_ascii=False) chat_id = get_current_chat() if not chat_id: return json.dumps({"error": "No chat context available"}, ensure_ascii=False) from bot.feishu import send_file await send_file(chat_id, "chat_id", str(file_path)) return json.dumps({ "success": True, "path": str(file_path), "size": file_path.stat().st_size, "message": f"File sent: {file_path.name}", }, ensure_ascii=False) except Exception as e: return json.dumps({"error": str(e)}, ensure_ascii=False) class WebInput(BaseModel): action: str = Field(..., description="Action: 'search', 'fetch', or 'ask'") query: Optional[str] = Field(None, description="Search query or question") url: Optional[str] = Field(None, description="URL to fetch (for 'fetch' action)") scope: Optional[str] = Field("webpage", description="Search scope: webpage, paper, document, video, podcast") max_chars: Optional[int] = Field(2000, description="Max characters in response") class WebTool(BaseTool): name: str = "web" description: str = ( "Search the web, fetch URLs, or ask questions using 秘塔AI Search. " "Actions: 'search' (web search), 'fetch' (extract content from URL), 'ask' (RAG Q&A). " "Requires METASO_API_KEY in keyring.yaml." ) args_schema: Type[BaseModel] = WebInput def _run(self, action: str, query: Optional[str] = None, url: Optional[str] = None, scope: Optional[str] = "webpage", max_chars: Optional[int] = 2000) -> str: raise NotImplementedError("Use async version") async def _arun(self, action: str, query: Optional[str] = None, url: Optional[str] = None, scope: Optional[str] = "webpage", max_chars: Optional[int] = 2000) -> str: from config import METASO_API_KEY if not METASO_API_KEY: return json.dumps({"error": "METASO_API_KEY not configured. Add it to keyring.yaml."}, ensure_ascii=False) import httpx base_url = "https://metaso.cn/api/mcp" headers = { "Authorization": f"Bearer {METASO_API_KEY}", "Content-Type": "application/json", } try: async with httpx.AsyncClient(timeout=30.0) as client: if action == "search": if not query: return json.dumps({"error": "query required for search"}, ensure_ascii=False) payload = { "jsonrpc": "2.0", "id": 1, "method": "metaso_web_search", "params": {"query": query, "scope": scope or "webpage"}, } resp = await client.post(base_url, json=payload, headers=headers) data = resp.json() if "error" in data: return json.dumps({"error": data["error"]}, ensure_ascii=False) results = data.get("result", {}).get("results", [])[:5] output = [] for r in results: output.append(f"**{r.get('title', 'No title')}**\n{r.get('snippet', '')}\n{r.get('url', '')}") return json.dumps({"results": "\n\n".join(output)[:max_chars]}, ensure_ascii=False) elif action == "fetch": if not url: return json.dumps({"error": "url required for fetch"}, ensure_ascii=False) payload = { "jsonrpc": "2.0", "id": 1, "method": "metaso_web_reader", "params": {"url": url, "format": "markdown"}, } resp = await client.post(base_url, json=payload, headers=headers) data = resp.json() if "error" in data: return json.dumps({"error": data["error"]}, ensure_ascii=False) content = data.get("result", {}).get("content", "") return json.dumps({"content": content[:max_chars]}, ensure_ascii=False) elif action == "ask": if not query: return json.dumps({"error": "query required for ask"}, ensure_ascii=False) payload = { "jsonrpc": "2.0", "id": 1, "method": "metaso_chat", "params": {"query": query}, } resp = await client.post(base_url, json=payload, headers=headers) data = resp.json() if "error" in data: return json.dumps({"error": data["error"]}, ensure_ascii=False) answer = data.get("result", {}).get("answer", "") return json.dumps({"answer": answer[:max_chars]}, ensure_ascii=False) else: return json.dumps({"error": f"Unknown action: {action}"}, ensure_ascii=False) except httpx.TimeoutException: return json.dumps({"error": "Request timed out"}, ensure_ascii=False) except Exception as e: return json.dumps({"error": str(e)}, ensure_ascii=False) class SchedulerInput(BaseModel): action: str = Field(..., description="Action: 'remind' or 'repeat'") delay_seconds: Optional[int] = Field(None, description="Delay in seconds (for 'remind')") interval_seconds: Optional[int] = Field(None, description="Interval in seconds (for 'repeat')") message: str = Field(..., description="Reminder message") max_runs: Optional[int] = Field(5, description="Max runs for recurring (default 5)") class SchedulerTool(BaseTool): name: str = "scheduler" description: str = ( "Schedule reminders. Use 'remind' for one-time, 'repeat' for recurring. " "Notifications sent to current chat." ) args_schema: Type[BaseModel] = SchedulerInput def _run(self, action: str, message: str, delay_seconds: Optional[int] = None, interval_seconds: Optional[int] = None, max_runs: Optional[int] = 5) -> str: raise NotImplementedError("Use async version") async def _arun(self, action: str, message: str, delay_seconds: Optional[int] = None, interval_seconds: Optional[int] = None, max_runs: Optional[int] = 5) -> str: from agent.scheduler import scheduler chat_id = get_current_chat() if action == "remind": if not delay_seconds: return json.dumps({"error": "delay_seconds required for remind"}, ensure_ascii=False) job_id = await scheduler.schedule_once( delay_seconds=delay_seconds, message=message, notify_chat_id=chat_id, ) return json.dumps({ "success": True, "job_id": job_id, "message": f"Reminder set for {delay_seconds}s from now", }, ensure_ascii=False) elif action == "repeat": if not interval_seconds: return json.dumps({"error": "interval_seconds required for repeat"}, ensure_ascii=False) job_id = await scheduler.schedule_recurring( interval_seconds=interval_seconds, message=message, max_runs=max_runs or 5, notify_chat_id=chat_id, ) return json.dumps({ "success": True, "job_id": job_id, "message": f"Recurring reminder set every {interval_seconds}s ({max_runs} times)", }, ensure_ascii=False) else: return json.dumps({"error": f"Unknown action: {action}"}, ensure_ascii=False) class TaskStatusInput(BaseModel): task_id: str = Field(..., description="Task ID to check") class TaskStatusTool(BaseTool): name: str = "task_status" description: str = "Check the status of a background task. Returns current status and result if completed." args_schema: Type[BaseModel] = TaskStatusInput def _run(self, task_id: str) -> str: raise NotImplementedError("Use async version") async def _arun(self, task_id: str) -> str: from agent.task_runner import task_runner task = task_runner.get_task(task_id) if not task: return json.dumps({"error": f"Task {task_id} not found"}, ensure_ascii=False) return json.dumps({ "task_id": task.task_id, "description": task.description, "status": task.status.value, "elapsed": int(task.elapsed), "started_at": task.started_at, "completed_at": task.completed_at, "result": task.result[:500] if task.result else None, "error": task.error, }, ensure_ascii=False) # Module-level tool list for easy import TOOLS = [ CreateConversationTool(), SendToConversationTool(), ListConversationsTool(), CloseConversationTool(), ShellTool(), FileReadTool(), FileWriteTool(), FileListTool(), FileSearchTool(), FileSendTool(), WebTool(), SchedulerTool(), TaskStatusTool(), ]