PhoneWork/agent/sdk_session.py
Yuyao Huang 72ebf3b75d feat(question): implement AskUserQuestion tool support
- Add question card builder and answer handling in feishu.py
- Extend SDKSession with pending question state and answer method
- Update card callback handler to support question answers
- Add test cases for question flow and card responses
- Document usage with test_can_use_tool_ask.py example
2026-04-02 08:52:50 +08:00

415 lines
15 KiB
Python

"""SDK-based Claude Code session — secretary model.
Messages are buffered in memory, not pushed to Feishu in real-time.
Only key events (completion, error, approval) trigger notifications.
The secretary AI queries get_progress() to answer user questions.
"""
from __future__ import annotations
import asyncio
import logging
import time
from dataclasses import dataclass, field
from typing import Any
from claude_agent_sdk import (
AssistantMessage,
ClaudeAgentOptions,
ClaudeSDKClient,
PermissionMode,
PermissionResult,
PermissionResultAllow,
PermissionResultDeny,
ResultMessage,
SystemMessage,
TextBlock,
ToolPermissionContext,
ToolUseBlock,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Permission mode names recognised by this module.
VALID_PERMISSION_MODES = [
    "default",
    "acceptEdits",
    "plan",
    "bypassPermissions",
    "dontAsk",
]

# Mode used when a session is created without an explicit permission mode.
DEFAULT_PERMISSION_MODE = "default"

# How long (in seconds) to wait for the user to answer an approval or
# question card before giving up.
APPROVAL_TIMEOUT = 120
@dataclass
class SessionProgress:
    """Point-in-time view of a session, built for the secretary AI to read.

    Every field has a default, so an empty snapshot describes an idle session
    with nothing buffered.
    """

    # --- Current task ---
    busy: bool = False  # True while a prompt is being executed
    current_prompt: str = ""  # the prompt most recently sent
    started_at: float = 0.0  # time.time() value recorded when the prompt was sent
    elapsed_seconds: float = 0.0  # running time of the current task; 0 when idle

    # --- Buffered output ---
    text_messages: list[str] = field(default_factory=list)  # recent assistant text blocks
    tool_calls: list[str] = field(default_factory=list)  # recent tool-call summaries

    # --- Outcome ---
    last_result: str = ""  # final result text of the last finished task
    error: str = ""  # last error message, if any

    # --- Waiting on the user ---
    # Non-empty → waiting for approval; the value is the tool description.
    pending_approval: str = ""
    # Non-None → waiting for the user's answer to an AskUserQuestion call.
    pending_question: dict | None = None
class SDKSession:
    """One session = one long-lived ClaudeSDKClient + a background message loop.

    Secretary model design:
    - _message_loop buffers all messages in memory and does NOT push them to Feishu
    - Only key events are pushed: completion (ResultMessage), error, approval needed
    - get_progress() returns a snapshot for the secretary AI to inspect
    """

    # Upper bounds for the in-memory buffers; the oldest entry is dropped first.
    MAX_BUFFER_TEXTS = 20
    MAX_BUFFER_TOOLS = 50

    def __init__(
        self,
        conv_id: str,
        cwd: str,
        owner_id: str,
        permission_mode: str = DEFAULT_PERMISSION_MODE,
        chat_id: str | None = None,
    ):
        """Record the session configuration; the client is created by start().

        Args:
            conv_id: Conversation identifier, used in logs, cards and audit records.
            cwd: Working directory handed to the SDK client.
            owner_id: Identifier of the owning user, written to the audit log.
            permission_mode: Permission mode passed to the SDK client options.
            chat_id: Feishu chat to notify; when empty, no notifications are sent.
        """
        self.conv_id = conv_id
        self.cwd = cwd
        self.owner_id = owner_id
        self.permission_mode = permission_mode
        self.chat_id = chat_id
        self.client: ClaudeSDKClient | None = None
        self.session_id: str | None = None

        # Message buffers
        self._text_buffer: list[str] = []
        self._tool_buffer: list[str] = []
        self._last_result: str = ""
        self._error: str = ""
        self._current_prompt: str = ""
        self._started_at: float = 0.0

        # Task state
        self._message_loop_task: asyncio.Task | None = None
        self._busy = False
        self._busy_event = asyncio.Event()
        self._busy_event.set()  # initially idle

        # Approval mechanism
        self._pending_approval: asyncio.Future | None = None
        self._pending_approval_desc: str = ""

        # AskUserQuestion mechanism
        self._pending_question: asyncio.Future | None = None
        self._pending_question_data: dict | None = None  # {questions: [...], conv_id: ...}

    async def start(self) -> None:
        """Create and connect the ClaudeSDKClient, then start the message loop.

        Raises:
            Exception: whatever ``connect()`` raises; ``self.client`` is reset to
                ``None`` in that case so that a later ``send()`` retries the start.
        """
        from agent.sdk_hooks import build_hooks

        env = self._build_env()
        hooks = build_hooks(self.conv_id)
        options = ClaudeAgentOptions(
            cwd=self.cwd,
            permission_mode=self.permission_mode,
            # NOTE(review): tools listed here may be pre-approved by the CLI
            # before can_use_tool is consulted — confirm that Bash/Edit/Write
            # still reach the approval flow in _permission_callback.
            allowed_tools=[
                "Read", "Glob", "Grep", "Bash", "Edit", "Write",
                "MultiEdit", "WebFetch", "WebSearch",
            ],
            can_use_tool=self._permission_callback,
            hooks=hooks,
            env=env,
        )
        self.client = ClaudeSDKClient(options)
        try:
            await self.client.connect()
        except BaseException:
            # Do not keep a half-initialised client around: send() only calls
            # start() when self.client is unset.
            self.client = None
            raise
        self._message_loop_task = asyncio.create_task(
            self._message_loop(), name=f"sdk-loop-{self.conv_id}"
        )
        logger.info("SDKSession %s started in %s", self.conv_id, self.cwd)

    async def send(self, prompt: str, chat_id: str | None = None) -> str:
        """Send a message. Returns immediately; execution happens in background.

        If a task is already running it is interrupted first, and this call
        waits up to 10 seconds for it to wind down before sending anyway.

        Args:
            prompt: Text to send to the client.
            chat_id: When given, replaces the chat used for notifications.

        Returns:
            A short acknowledgement string.

        Raises:
            Exception: whatever ``client.query()`` raises; the session is put
                back into the idle state before the exception propagates.
        """
        if not self.client:
            await self.start()
        if chat_id:
            self.chat_id = chat_id

        # If busy, interrupt the current task first
        if self._busy:
            await self.interrupt()
            try:
                await asyncio.wait_for(self._busy_event.wait(), timeout=10)
            except asyncio.TimeoutError:
                logger.warning(
                    "SDKSession %s still busy 10s after interrupt; sending anyway",
                    self.conv_id,
                )

        self._busy = True
        self._busy_event.clear()
        self._current_prompt = prompt
        self._started_at = time.time()
        self._last_result = ""
        self._error = ""
        self._text_buffer.clear()
        self._tool_buffer.clear()
        try:
            await self.client.query(prompt)
        except Exception as exc:
            # Nothing was submitted, so no ResultMessage will ever arrive to
            # clear the busy state — release it here.
            self._error = str(exc)
            self._mark_idle()
            raise
        return "⏳ 已开始执行"

    async def send_and_wait(self, prompt: str, chat_id: str | None = None) -> str:
        """Send and wait for completion. For LLM agent tool calls.

        Returns:
            The last result, else the last error, else ``"(no output)"``.
        """
        await self.send(prompt, chat_id)
        await self._busy_event.wait()
        return self._last_result or self._error or "(no output)"

    def get_progress(self) -> SessionProgress:
        """Return a progress snapshot. Primary query interface for the secretary AI.

        Only the 5 most recent text messages, the 10 most recent tool calls and
        the first 1000 characters of the last result are included.
        """
        return SessionProgress(
            busy=self._busy,
            current_prompt=self._current_prompt,
            started_at=self._started_at,
            elapsed_seconds=time.time() - self._started_at if self._busy else 0,
            text_messages=list(self._text_buffer[-5:]),
            tool_calls=list(self._tool_buffer[-10:]),
            last_result=self._last_result[:1000],
            error=self._error,
            pending_approval=self._pending_approval_desc,
            pending_question=self._pending_question_data,
        )

    async def interrupt(self) -> None:
        """Interrupt the currently running task; no-op when idle or not started."""
        if self.client and self._busy:
            await self.client.interrupt()
            logger.info("SDKSession %s interrupted", self.conv_id)

    async def set_permission_mode(self, mode: PermissionMode) -> None:
        """Dynamically change the permission mode.

        The mode is forwarded to the client when one exists and is always
        stored on the session.
        """
        if self.client:
            await self.client.set_permission_mode(mode)
        self.permission_mode = mode
        logger.info("SDKSession %s permission_mode → %s", self.conv_id, mode)

    async def approve(self, approved: bool) -> None:
        """Resolve a pending tool approval; ignored when nothing is pending."""
        if self._pending_approval and not self._pending_approval.done():
            self._pending_approval.set_result(approved)

    async def answer_question(self, answers: dict[str, str]) -> None:
        """Resolve a pending AskUserQuestion with the user's selected answers.

        Ignored when no question is pending.

        Args:
            answers: maps question text → selected option label.
        """
        if self._pending_question and not self._pending_question.done():
            self._pending_question.set_result(answers)

    async def close(self) -> None:
        """Stop the message loop, release waiters and disconnect the client."""
        task = self._message_loop_task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # The loop is gone, so no ResultMessage can arrive any more: wake up
        # anyone still blocked in send_and_wait().
        self._mark_idle()

        if self.client:
            try:
                await self.client.disconnect()
            finally:
                # Drop the reference even if disconnect() failed, so that a
                # later send() starts a fresh client.
                self.client = None
            logger.info("SDKSession %s closed", self.conv_id)

    # --- Internal ---

    def _mark_idle(self) -> None:
        """Flag the session as idle and wake everyone waiting on the busy event."""
        self._busy = False
        self._busy_event.set()

    async def _message_loop(self) -> None:
        """Background message consumption loop. Buffers messages, notifies on key events."""
        from agent.audit import log_interaction

        try:
            async for msg in self.client.receive_messages():
                if isinstance(msg, SystemMessage) and msg.subtype == "init":
                    self.session_id = msg.data.get("session_id")
                elif isinstance(msg, AssistantMessage):
                    for block in msg.content:
                        if isinstance(block, TextBlock):
                            self._text_buffer.append(block.text)
                            if len(self._text_buffer) > self.MAX_BUFFER_TEXTS:
                                self._text_buffer.pop(0)
                        elif isinstance(block, ToolUseBlock):
                            summary = f"{block.name}({self._summarize_input(block.input)})"
                            self._tool_buffer.append(summary)
                            if len(self._tool_buffer) > self.MAX_BUFFER_TOOLS:
                                self._tool_buffer.pop(0)
                elif isinstance(msg, ResultMessage):
                    self._last_result = msg.result or ""
                    self._mark_idle()
                    # Key event: task completed → notify Feishu
                    if self.chat_id:
                        await self._notify_completion()
                    log_interaction(
                        conv_id=self.conv_id,
                        prompt=self._current_prompt,
                        response=self._last_result[:2000],
                        cwd=self.cwd,
                        user_id=self.owner_id,
                    )
            if self._busy:
                # The stream finished without a ResultMessage for the running
                # task; without this, send_and_wait() would block forever.
                logger.warning("Message stream for %s ended mid-task", self.conv_id)
                self._error = "message stream ended before the task finished"
                self._mark_idle()
        except asyncio.CancelledError:
            logger.debug("Message loop cancelled for %s", self.conv_id)
        except Exception as exc:
            logger.exception("Message loop error for %s", self.conv_id)
            self._error = str(exc)
            self._mark_idle()
            if self.chat_id:
                await self._notify_error(str(exc))

    async def _send_markdown_best_effort(self, text: str, failure_log: str) -> None:
        """Send a markdown message to the session chat, logging instead of raising."""
        from bot.feishu import send_markdown

        try:
            await send_markdown(self.chat_id, "chat_id", text)
        except Exception:
            logger.exception(failure_log)

    async def _notify_completion(self) -> None:
        """Tell the chat that the task finished, with a preview of the result."""
        result_preview = self._last_result[:800]
        if len(self._last_result) > 800:
            result_preview += "\n...[truncated]"
        elapsed = int(time.time() - self._started_at)
        tools_used = len(self._tool_buffer)
        msg = f"✅ **任务完成** ({elapsed}s, {tools_used} tool calls)\n\n{result_preview}"
        await self._send_markdown_best_effort(msg, "Failed to notify completion")

    async def _notify_error(self, error: str) -> None:
        """Tell the chat that the task failed; the error text is cut to 500 chars."""
        await self._send_markdown_best_effort(
            f"❌ **任务出错**\n\n```\n{error[:500]}\n```",
            "Failed to notify error",
        )

    async def _permission_callback(
        self, tool_name: str, input_data: dict, context: ToolPermissionContext
    ) -> PermissionResult:
        """can_use_tool — route to question card or approval card based on tool type."""
        # Auto-allow read-only tools
        if tool_name in ("Read", "Glob", "Grep", "WebSearch", "WebFetch"):
            return PermissionResultAllow()

        # AskUserQuestion: show options to user, collect answer, return via updated_input
        if tool_name == "AskUserQuestion":
            return await self._handle_ask_user_question(input_data)

        # Without a chat there is nobody to ask, so the tool is allowed.
        if not self.chat_id:
            return PermissionResultAllow()

        # Regular tools: approval flow
        return await self._handle_tool_approval(tool_name, input_data)

    async def _handle_ask_user_question(self, input_data: dict) -> PermissionResult:
        """Handle AskUserQuestion: send question card, wait for answer, return updated_input.

        On timeout the question is skipped and the tool is allowed with no
        additional answers. The caller's ``input_data`` is never mutated.
        """
        if not self.chat_id:
            return PermissionResultAllow()
        questions = input_data.get("questions", [])
        if not questions:
            return PermissionResultAllow()

        from bot.feishu import build_question_card, send_card

        card = build_question_card(
            conv_id=self.conv_id,
            questions=questions,
        )

        # Register the future BEFORE the card goes out so that an answer which
        # arrives while send_card() is still in flight is not lost.
        pending = asyncio.get_running_loop().create_future()
        self._pending_question = pending
        self._pending_question_data = {"questions": questions, "conv_id": self.conv_id}
        try:
            await send_card(self.chat_id, "chat_id", card)
            # Wait for user's answer (via card callback or text reply)
            try:
                answers = await asyncio.wait_for(pending, timeout=APPROVAL_TIMEOUT)
            except asyncio.TimeoutError:
                answers = {}
                await self._send_markdown_best_effort(
                    "⏰ 问题超时,已跳过。", "Failed to notify question timeout"
                )
        finally:
            # Always clear the pending state, including when send_card() fails.
            self._pending_question_data = None
            if self._pending_question is pending:
                self._pending_question = None

        # Pre-fill answers in a copy of the tool input, leaving input_data intact.
        existing = input_data.get("answers")
        merged = dict(existing) if isinstance(existing, dict) else {}
        merged.update(answers or {})
        modified_input = {**input_data, "answers": merged}
        return PermissionResultAllow(updated_input=modified_input)

    async def _handle_tool_approval(self, tool_name: str, input_data: dict) -> PermissionResult:
        """Handle regular tool approval: send approval card, wait for approve/deny.

        A timeout counts as a denial. Every decision is written to the audit log.
        """
        from bot.feishu import build_approval_card, send_card

        summary = self._format_tool_summary(tool_name, input_data)
        card = build_approval_card(
            conv_id=self.conv_id,
            tool_name=tool_name,
            summary=summary,
            timeout=APPROVAL_TIMEOUT,
        )

        # Register the future BEFORE the card goes out so that a decision which
        # arrives while send_card() is still in flight is not lost.
        pending = asyncio.get_running_loop().create_future()
        self._pending_approval = pending
        self._pending_approval_desc = f"{tool_name}: {summary}"
        try:
            await send_card(self.chat_id, "chat_id", card)
            try:
                approved = await asyncio.wait_for(pending, timeout=APPROVAL_TIMEOUT)
            except asyncio.TimeoutError:
                approved = False
                await self._send_markdown_best_effort(
                    "⏰ 审批超时,已自动拒绝。", "Failed to notify approval timeout"
                )
        finally:
            # Always clear the pending state, including when send_card() fails.
            self._pending_approval_desc = ""
            if self._pending_approval is pending:
                self._pending_approval = None

        from agent.audit import log_permission_decision

        log_permission_decision(
            conv_id=self.conv_id, tool_name=tool_name,
            tool_input=input_data, approved=approved,
        )
        if approved:
            return PermissionResultAllow()
        return PermissionResultDeny(message="用户拒绝了此操作")

    def _format_tool_summary(self, tool_name: str, input_data: dict) -> str:
        """Build the short, human-readable tool description shown on the approval card."""
        if tool_name == "Bash":
            # str() guards against a non-string command value.
            return f"`{str(input_data.get('command', ''))[:200]}`"
        if tool_name in ("Edit", "Write", "MultiEdit"):
            return f"file: `{input_data.get('file_path', input_data.get('path', ''))}`"
        return str(input_data)[:200]

    @staticmethod
    def _summarize_input(input_data: dict) -> str:
        """Return a compact one-line summary of a tool input for the tool buffer."""
        # str() keeps a malformed (non-string) value from raising inside the
        # message loop, where an exception would stop message consumption.
        if "command" in input_data:
            return str(input_data["command"])[:80]
        if "file_path" in input_data:
            return str(input_data["file_path"])
        return str(input_data)[:60]

    def _build_env(self) -> dict[str, str]:
        """Collect the non-empty credential/endpoint variables to pass to the client."""
        import os

        forwarded = ("ANTHROPIC_BASE_URL", "ANTHROPIC_AUTH_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN")
        return {key: os.environ[key] for key in forwarded if os.environ.get(key)}