Yuyao Huang (Sam) 09b63341cd refactor: 统一使用现代类型注解替代传统类型注解
- 将 Dict、List 等传统类型注解替换为 dict、list 等现代类型注解
- 更新类型注解以更精确地反映变量类型
- 修复部分类型注解与实际使用不匹配的问题
- 优化部分代码逻辑以提高类型安全性
2026-03-28 14:27:21 +08:00

250 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LangChain orchestration agent backed by ZhipuAI (OpenAI-compatible API).
Uses LangChain 1.x tool-calling pattern: bind_tools + manual agentic loop.
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
from collections import defaultdict
from typing import Dict, List, Optional
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_openai import ChatOpenAI
from agent.manager import manager
from config import OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL, WORKING_DIR
from orchestrator.tools import TOOLS, set_current_user
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# System prompt sent to the tool-calling LLM. It is a str.format template with
# two fields, {working_dir} and {active_session_line}, which
# OrchestrationAgent._build_system_prompt fills in per user.
SYSTEM_PROMPT_TEMPLATE = """You are PhoneWork, an AI assistant that helps users control Claude Code \
from their phone via Feishu (飞书).
You manage Claude Code sessions. Each session has a conv_id and runs in a project directory.
Base working directory: {working_dir}
Users refer to projects by subfolder name (e.g. "todo_app") or relative path. \
Pass these names directly to `create_conversation` — the tool resolves them automatically.
{active_session_line}
Your responsibilities:
1. NEW session: call `create_conversation` with the project name/path. \
If the user's message also contains a task, pass it as `initial_message` too.
2. Follow-up to ACTIVE session: call `send_to_conversation` with the active conv_id shown above.
3. List sessions: call `list_conversations`.
4. Close session: call `close_conversation`.
5. GENERAL QUESTIONS: If the user asks a general question (not about a specific project or file), \
answer directly using your own knowledge. Do NOT create a session for simple Q&A.
Guidelines:
- Relay Claude Code's output verbatim.
- If no active session and the user sends a task without naming a directory, ask them which project.
- For general knowledge questions (e.g., "what is a Python generator?", "explain async/await"), \
answer directly without creating a session.
- Keep your own words brief — let Claude Code's output speak.
- Reply in the same language the user uses (Chinese or English).
"""
# Upper bound on LLM calls made by the tool-calling loop for one user message.
MAX_ITERATIONS = 10
# Tool name -> tool object, used to dispatch the tool calls the LLM requests.
_TOOL_MAP = {t.name: t for t in TOOLS}
# Regular expressions that mark a message as question-like. They are searched
# (case-insensitively) against the lower-cased, stripped message text by
# _is_general_question; a single match is enough.
QUESTION_PATTERNS = [
    r'\?$',  # ends with an ASCII question mark
    # Ends with the full-width (Chinese) question mark U+FF1F. Written as an
    # escape so the character cannot be silently lost or confused with "?";
    # a bare r'$' here would match every string.
    r'\uff1f$',
    r'\b(what|how|why|when|where|who|which|explain|describe|tell me|can you|could you|is there|are there|do you know)\b',
    r'(什么|怎么|为什么|何时|哪里|谁|哪个|解释|描述|告诉我|能否|可以|有没有|是不是)',
]
def _is_general_question(text: str) -> bool:
    """Return True when *text* reads like a general question, not a project task.

    The text is lower-cased and stripped first. If it contains any of the
    task keywords below (plain substring match), the result is False.
    Otherwise the result is True as soon as one of QUESTION_PATTERNS matches.
    """
    task_keywords = (
        'create', 'make', 'build', 'fix', 'update', 'delete', 'remove', 'add',
        'implement', 'refactor', 'test', 'run', 'execute', 'start', 'stop',
        'project', 'folder', 'directory', 'file', 'code', 'session',
        '创建', '制作', '构建', '修复', '更新', '删除', '添加', '实现', '重构', '测试', '运行', '项目', '文件夹', '文件', '代码',
    )
    normalized = text.lower().strip()

    # A task keyword anywhere in the message rules out the direct-answer path.
    if any(keyword in normalized for keyword in task_keywords):
        return False

    return any(
        re.search(regex, normalized, re.IGNORECASE) is not None
        for regex in QUESTION_PATTERNS
    )
class OrchestrationAgent:
    """Per-user agent with conversation history and active session tracking.

    Each incoming message is handled in one of three ways by ``_run_locked``:
    forwarded verbatim to the active session (passthrough mode), answered
    directly by a tool-less LLM call, or processed by the tool-calling loop
    (at most ``MAX_ITERATIONS`` LLM calls). All per-user state is kept in
    in-memory dictionaries keyed by ``user_id``.
    """

    # Number of most recent history messages retained per user.
    _MAX_HISTORY = 40

    def __init__(self) -> None:
        """Create the LLM clients and the empty per-user state tables."""
        llm = ChatOpenAI(
            base_url=OPENAI_BASE_URL,
            api_key=OPENAI_API_KEY,
            model=OPENAI_MODEL,
            temperature=0.1,
        )
        self._llm_with_tools = llm.bind_tools(TOOLS)
        # Tool-less client for direct Q&A, built once rather than per message.
        self._llm_no_tools = ChatOpenAI(
            base_url=OPENAI_BASE_URL,
            api_key=OPENAI_API_KEY,
            model=OPENAI_MODEL,
            temperature=0.7,
        )
        # user_id -> list[BaseMessage]
        self._history: dict[str, list[BaseMessage]] = defaultdict(list)
        # user_id -> most recently active conv_id
        self._active_conv: dict[str, Optional[str]] = defaultdict(lambda: None)
        # user_id -> asyncio.Lock (prevents concurrent processing per user)
        self._user_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        # user_id -> passthrough mode enabled
        self._passthrough: dict[str, bool] = defaultdict(lambda: False)

    @staticmethod
    def _parse_json_object(result: object) -> Optional[dict]:
        """Parse a tool result as a JSON object; return None if it is not one.

        Tool results may be error strings or other non-JSON text, so parse
        failures are expected and are not treated as errors.
        """
        try:
            data = json.loads(result)
        except (TypeError, ValueError):
            # TypeError: result is not str/bytes; ValueError: invalid JSON.
            return None
        return data if isinstance(data, dict) else None

    def _build_system_prompt(self, user_id: str) -> str:
        """Render SYSTEM_PROMPT_TEMPLATE with the user's active session line."""
        conv_id = self._active_conv[user_id]
        if conv_id:
            active_line = f"ACTIVE SESSION: conv_id={conv_id!r} ← use this for all follow-up messages"
        else:
            active_line = "ACTIVE SESSION: none"
        return SYSTEM_PROMPT_TEMPLATE.format(
            working_dir=WORKING_DIR,
            active_session_line=active_line,
        )

    def get_active_conv(self, user_id: str) -> Optional[str]:
        """Return the user's active conv_id, or None if there is none."""
        # .get() avoids inserting a default entry for unknown users.
        return self._active_conv.get(user_id)

    def get_passthrough(self, user_id: str) -> bool:
        """Return whether passthrough mode is enabled for the user."""
        return self._passthrough.get(user_id, False)

    def set_passthrough(self, user_id: str, enabled: bool) -> None:
        """Enable or disable passthrough mode for the user."""
        self._passthrough[user_id] = enabled

    async def run(self, user_id: str, text: str) -> str:
        """Process a user message and return the agent's reply.

        Messages from the same user are processed one at a time: the call
        waits for that user's lock before doing any work.
        """
        async with self._user_locks[user_id]:
            return await self._run_locked(user_id, text)

    async def _run_locked(self, user_id: str, text: str) -> str:
        """Internal implementation, must be called with user lock held.

        Returns the reply text. Failures in any of the three paths are logged
        and reported as an ``"[Error] ..."`` reply instead of being raised.
        Only replies produced by the tool-calling loop are added to the
        user's history; passthrough and direct Q&A replies are not.
        """
        set_current_user(user_id)
        active_conv = self._active_conv[user_id]
        short_uid = user_id[-8:]
        logger.info(">>> user=...%s conv=%s msg=%r", short_uid, active_conv, text[:80])
        logger.debug(" history_len=%d", len(self._history[user_id]))

        # Passthrough mode: if enabled and active session, bypass LLM
        if self._passthrough[user_id] and active_conv:
            try:
                reply = await manager.send(active_conv, text, user_id=user_id)
                logger.info("<<< [passthrough] reply: %r", reply[:120])
                return reply
            except KeyError:
                logger.warning("Session %s no longer exists, clearing active_conv", active_conv)
                self._active_conv[user_id] = None
                # Keep the local copy in sync so the checks below see that
                # the user no longer has an active session.
                active_conv = None
            except Exception as exc:
                logger.exception("Passthrough error for user=%s", user_id)
                return f"[Error] {exc}"

        # Direct Q&A: if no active session and message looks like a general question, answer directly
        if not active_conv and _is_general_question(text):
            logger.debug(" → direct Q&A (no tools)")
            qa_prompt = (
                "You are a helpful assistant. Answer the user's question concisely and accurately. "
                "Reply in the same language the user uses.\n\n"
                f"Question: {text}"
            )
            try:
                response = await self._llm_no_tools.ainvoke([HumanMessage(content=qa_prompt)])
            except Exception as exc:
                # Same error contract as the passthrough and tool-loop paths.
                logger.exception("direct Q&A error for user=%s", user_id)
                return f"[Error] {exc}"
            return response.content or ""

        messages: list[BaseMessage] = (
            [SystemMessage(content=self._build_system_prompt(user_id))]
            + self._history[user_id]
            + [HumanMessage(content=text)]
        )
        reply = ""
        try:
            for iteration in range(MAX_ITERATIONS):
                logger.debug(" LLM call #%d", iteration)
                ai_msg: AIMessage = await self._llm_with_tools.ainvoke(messages)
                messages.append(ai_msg)
                if not ai_msg.tool_calls:
                    # No tool requested: the model's text is the final reply.
                    reply = ai_msg.content or ""
                    logger.debug(" → done (no tool call)")
                    break
                for tc in ai_msg.tool_calls:
                    tool_name = tc["name"]
                    tool_args = tc["args"]
                    tool_id = tc["id"]
                    args_summary = ", ".join(
                        f"{k}={str(v)[:50]!r}" for k, v in tool_args.items()
                    )
                    logger.info("%s(%s)", tool_name, args_summary)
                    tool_obj = _TOOL_MAP.get(tool_name)
                    if tool_obj is None:
                        result = f"Unknown tool: {tool_name}"
                        logger.warning(" unknown tool: %s", tool_name)
                    else:
                        try:
                            result = await tool_obj.arun(tool_args)
                        except Exception as exc:
                            # Report the failure to the model as the tool
                            # output so it can react, rather than aborting.
                            result = f"Tool error: {exc}"
                            logger.error(" tool %s error: %s", tool_name, exc)
                    logger.debug("%s: %r", tool_name, str(result)[:120])
                    if tool_name == "create_conversation":
                        # A successful creation reports its conv_id in a JSON
                        # object; remember it as the user's active session.
                        data = self._parse_json_object(result)
                        if data is not None and "conv_id" in data:
                            self._active_conv[user_id] = data["conv_id"]
                            logger.info(" ✓ active session → %s", data["conv_id"])
                    messages.append(
                        ToolMessage(content=str(result), tool_call_id=tool_id)
                    )
            else:
                # Loop ended without a break: every iteration asked for tools.
                reply = "[Max iterations reached]"
                logger.warning(" max iterations reached")
        except Exception as exc:
            logger.exception("agent error for user=%s", user_id)
            reply = f"[Error] {exc}"
        logger.info("<<< reply: %r", reply[:120])

        # Update history, keeping only the most recent messages.
        history = self._history[user_id]
        history.append(HumanMessage(content=text))
        history.append(AIMessage(content=reply))
        if len(history) > self._MAX_HISTORY:
            self._history[user_id] = history[-self._MAX_HISTORY:]
        return reply
# Module-level agent instance, constructed when this module is imported.
agent = OrchestrationAgent()