- Implement SDK session with secretary model for tool approval flow - Add audit logging for tool usage and permission decisions - Support Feishu card interactions for approval requests - Add new commands for task interruption and progress checking - Remove old test files and update documentation
356 lines
12 KiB
Python
356 lines
12 KiB
Python
"""SDK-based Claude Code session — secretary model.
|
|
|
|
Messages are buffered in memory, not pushed to Feishu in real-time.
|
|
Only key events (completion, error, approval) trigger notifications.
|
|
The secretary AI queries get_progress() to answer user questions.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from claude_agent_sdk import (
|
|
AssistantMessage,
|
|
ClaudeAgentOptions,
|
|
ClaudeSDKClient,
|
|
PermissionMode,
|
|
PermissionResult,
|
|
PermissionResultAllow,
|
|
PermissionResultDeny,
|
|
ResultMessage,
|
|
SystemMessage,
|
|
TextBlock,
|
|
ToolPermissionContext,
|
|
ToolUseBlock,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# Permission-mode names a session may use. SDKSession passes the chosen mode
# straight to ClaudeAgentOptions(permission_mode=...); validation against this
# list presumably happens in callers (not visible here).
VALID_PERMISSION_MODES = [
    "default",
    "acceptEdits",
    "plan",
    "bypassPermissions",
    "dontAsk",
]
DEFAULT_PERMISSION_MODE = "default"

# How long (seconds) a tool-approval card waits for an answer before the
# request is automatically denied.
APPROVAL_TIMEOUT = 2 * 60
|
|
|
|
|
|
@dataclass
class SessionProgress:
    """Point-in-time view of an SDKSession, as built by ``get_progress()``.

    The secretary AI reads this to answer questions about a running task
    without the session pushing every intermediate message to the chat.
    """

    # True while a prompt is executing.
    busy: bool = field(default=False)
    # Prompt of the current (or most recent) task.
    current_prompt: str = field(default="")
    # time.time() value recorded when the task was sent.
    started_at: float = field(default=0.0)
    # Seconds since started_at while busy; 0 otherwise.
    elapsed_seconds: float = field(default=0.0)
    # Most recent assistant text blocks.
    text_messages: list[str] = field(default_factory=list)
    # Most recent tool-call summaries, e.g. "Bash(ls -la)".
    tool_calls: list[str] = field(default_factory=list)
    # Final result text of the last finished task (possibly truncated).
    last_result: str = field(default="")
    # Error text of the last failure; empty when there was none.
    error: str = field(default="")
    # Non-empty while a tool approval is pending: "<tool>: <summary>".
    pending_approval: str = field(default="")
|
|
|
|
|
|
class SDKSession:
    """One session = one long-lived ClaudeSDKClient + a background message loop.

    Secretary model design:
    - ``_message_loop`` buffers messages in memory; it does NOT push them to Feishu.
    - Feishu is only notified on key events: completion (ResultMessage), errors,
      and tool-approval requests.
    - ``get_progress()`` returns a snapshot for the secretary AI to inspect.

    Lifecycle: ``send()`` lazily calls ``start()``, which connects the client and
    spawns the message loop; ``close()`` tears both down. If the loop has died,
    the next ``send()`` closes the old client and starts a fresh one.
    """

    MAX_BUFFER_TEXTS = 20
    MAX_BUFFER_TOOLS = 50

    def __init__(
        self,
        conv_id: str,
        cwd: str,
        owner_id: str,
        permission_mode: str = DEFAULT_PERMISSION_MODE,
        chat_id: str | None = None,
    ):
        from collections import deque

        self.conv_id = conv_id
        self.cwd = cwd
        self.owner_id = owner_id
        self.permission_mode = permission_mode
        self.chat_id = chat_id  # Feishu chat that receives cards/notifications

        self.client: ClaudeSDKClient | None = None
        self.session_id: str | None = None

        # Message buffers; bounded deques drop the oldest entry automatically.
        self._text_buffer: deque[str] = deque(maxlen=self.MAX_BUFFER_TEXTS)
        self._tool_buffer: deque[str] = deque(maxlen=self.MAX_BUFFER_TOOLS)
        # Tool calls made by the current task. _tool_buffer is capped, so its
        # length is not a reliable count.
        self._tool_count: int = 0
        self._last_result: str = ""
        self._error: str = ""
        self._current_prompt: str = ""
        self._started_at: float = 0.0

        # Task state
        self._message_loop_task: asyncio.Task | None = None
        self._busy = False
        self._busy_event = asyncio.Event()
        self._busy_event.set()  # initially idle

        # Approval mechanism
        self._pending_approval: asyncio.Future | None = None
        self._pending_approval_desc: str = ""

    async def start(self) -> None:
        """Create and connect the ClaudeSDKClient, then start the message loop.

        ``self.client`` is assigned only after ``connect()`` succeeds, so a
        failed start leaves the session unstarted and the next ``send()``
        retries instead of querying a half-initialised client.
        """
        from agent.sdk_hooks import build_hooks

        env = self._build_env()
        hooks = build_hooks(self.conv_id)

        options = ClaudeAgentOptions(
            cwd=self.cwd,
            permission_mode=self.permission_mode,
            allowed_tools=[
                "Read", "Glob", "Grep", "Bash", "Edit", "Write",
                "MultiEdit", "WebFetch", "WebSearch",
            ],
            can_use_tool=self._permission_callback,
            hooks=hooks,
            env=env,
        )
        client = ClaudeSDKClient(options)
        await client.connect()
        self.client = client

        self._message_loop_task = asyncio.create_task(
            self._message_loop(), name=f"sdk-loop-{self.conv_id}"
        )
        logger.info("SDKSession %s started in %s", self.conv_id, self.cwd)

    async def send(self, prompt: str, chat_id: str | None = None) -> str:
        """Send a prompt. Returns immediately; execution happens in the background.

        Starts the client on first use. If the background message loop is no
        longer running, the old client is closed and a new one is started
        (without resuming the old session_id) so the prompt is actually
        consumed. A task that is still running is interrupted first, waiting
        up to 10 s for it to wind down.

        Raises whatever ``client.query`` raises; the session is marked idle
        with the error recorded before re-raising.
        """
        loop_task = self._message_loop_task
        if self.client is not None and (loop_task is None or loop_task.done()):
            logger.warning(
                "SDKSession %s message loop is not running; reconnecting", self.conv_id
            )
            await self.close()
        if not self.client:
            await self.start()

        if chat_id:
            self.chat_id = chat_id

        # If busy, interrupt the current task first.
        if self._busy:
            await self.interrupt()
            try:
                await asyncio.wait_for(self._busy_event.wait(), timeout=10)
            except asyncio.TimeoutError:
                logger.warning(
                    "SDKSession %s: previous task still running after interrupt",
                    self.conv_id,
                )

        self._busy = True
        self._busy_event.clear()
        self._current_prompt = prompt
        self._started_at = time.time()
        self._last_result = ""
        self._error = ""
        self._text_buffer.clear()
        self._tool_buffer.clear()
        self._tool_count = 0

        try:
            await self.client.query(prompt)
        except Exception as exc:
            # Don't leave the session marked busy for a prompt that never started.
            self._error = str(exc)
            self._set_idle()
            raise
        return "⏳ 已开始执行"

    async def send_and_wait(self, prompt: str, chat_id: str | None = None) -> str:
        """Send and wait for completion. For LLM agent tool calls.

        Returns the final result text, else the error text, else "(no output)".
        """
        await self.send(prompt, chat_id)
        await self._busy_event.wait()
        return self._last_result or self._error or "(no output)"

    def get_progress(self) -> SessionProgress:
        """Return a progress snapshot. Primary query interface for the secretary AI.

        Includes the last 5 text messages, the last 10 tool calls and at most
        1000 characters of the final result.
        """
        return SessionProgress(
            busy=self._busy,
            current_prompt=self._current_prompt,
            started_at=self._started_at,
            elapsed_seconds=time.time() - self._started_at if self._busy else 0.0,
            text_messages=list(self._text_buffer)[-5:],
            tool_calls=list(self._tool_buffer)[-10:],
            last_result=self._last_result[:1000],
            error=self._error,
            pending_approval=self._pending_approval_desc,
        )

    async def interrupt(self) -> None:
        """Interrupt the currently running task (no-op when idle).

        A tool approval still waiting on the user is denied first, so the
        permission callback returns at once instead of blocking until
        APPROVAL_TIMEOUT.
        """
        if not (self.client and self._busy):
            return
        await self.approve(False)
        await self.client.interrupt()
        logger.info("SDKSession %s interrupted", self.conv_id)

    async def set_permission_mode(self, mode: PermissionMode) -> None:
        """Dynamically change the permission mode (also used by the next start())."""
        if self.client:
            await self.client.set_permission_mode(mode)
        self.permission_mode = mode
        logger.info("SDKSession %s permission_mode → %s", self.conv_id, mode)

    async def approve(self, approved: bool) -> None:
        """Resolve the pending tool approval (card callback or y/n text reply).

        No-op when nothing is pending or the request has already finished.
        """
        pending = self._pending_approval
        if pending is not None and not pending.done():
            pending.set_result(approved)

    async def close(self) -> None:
        """Disconnect and clean up.

        Denies any pending approval, stops the message loop, disconnects the
        client (a failing disconnect is logged, not raised), and marks the
        session idle so a concurrent send_and_wait() returns instead of
        blocking forever.
        """
        await self.approve(False)

        loop_task = self._message_loop_task
        if loop_task and not loop_task.done():
            loop_task.cancel()
            try:
                await loop_task
            except asyncio.CancelledError:
                pass
        self._message_loop_task = None

        client, self.client = self.client, None
        if client:
            try:
                await client.disconnect()
            except Exception:
                logger.exception("SDKSession %s: disconnect failed", self.conv_id)

        self._set_idle()
        logger.info("SDKSession %s closed", self.conv_id)

    # --- Internal ---

    def _set_idle(self) -> None:
        """Mark the session idle and wake anyone blocked in send_and_wait()."""
        self._busy = False
        self._busy_event.set()

    async def _message_loop(self) -> None:
        """Background consumer: buffer messages, notify Feishu on key events.

        Runs until cancelled, the message stream ends, or an exception occurs.
        On failure the error is recorded, the session is marked idle and (if a
        chat is bound) an error notification is sent. The loop does not
        restart itself; send() reconnects on the next prompt.
        """
        try:
            async for msg in self.client.receive_messages():
                if isinstance(msg, SystemMessage) and msg.subtype == "init":
                    self.session_id = msg.data.get("session_id")
                elif isinstance(msg, AssistantMessage):
                    self._buffer_assistant_message(msg)
                elif isinstance(msg, ResultMessage):
                    await self._finish_task(msg)
            if self._busy:
                # The stream ended (presumably the connection closed) mid-task;
                # report it instead of leaving send_and_wait() blocked forever.
                raise RuntimeError("message stream ended while a task was running")
        except asyncio.CancelledError:
            logger.debug("Message loop cancelled for %s", self.conv_id)
            raise
        except Exception as exc:
            logger.exception("Message loop error for %s", self.conv_id)
            self._error = str(exc)
            self._set_idle()
            if self.chat_id:
                await self._notify_error(str(exc))

    def _buffer_assistant_message(self, msg: AssistantMessage) -> None:
        """Record text blocks and tool-use summaries from one assistant message."""
        for block in msg.content:
            if isinstance(block, TextBlock):
                self._text_buffer.append(block.text)
            elif isinstance(block, ToolUseBlock):
                self._tool_buffer.append(
                    f"{block.name}({self._summarize_input(block.input)})"
                )
                self._tool_count += 1

    async def _finish_task(self, msg: ResultMessage) -> None:
        """Record a task's final result, notify the chat and write the audit log."""
        from agent.audit import log_interaction

        self._last_result = msg.result or ""
        failed = bool(msg.is_error)
        if failed:
            self._error = self._last_result or f"task failed ({msg.subtype})"
        self._set_idle()

        # Key event: task finished → notify Feishu.
        if self.chat_id:
            if failed:
                await self._notify_error(self._error)
            else:
                await self._notify_completion()

        # Audit logging is best-effort: a failure here must not kill the loop.
        try:
            log_interaction(
                conv_id=self.conv_id,
                prompt=self._current_prompt,
                response=self._last_result[:2000],
                cwd=self.cwd,
                user_id=self.owner_id,
            )
        except Exception:
            logger.exception("Failed to write interaction audit log for %s", self.conv_id)

    async def _notify_completion(self) -> None:
        """Send a completion summary (elapsed time, tool count, result preview)."""
        from bot.feishu import send_markdown

        result_preview = self._last_result[:800]
        if len(self._last_result) > 800:
            result_preview += "\n...[truncated]"
        elapsed = int(time.time() - self._started_at)
        tools_used = self._tool_count
        msg = f"✅ **任务完成** ({elapsed}s, {tools_used} tool calls)\n\n{result_preview}"
        try:
            await send_markdown(self.chat_id, "chat_id", msg)
        except Exception:
            logger.exception("Failed to notify completion")

    async def _notify_error(self, error: str) -> None:
        """Send an error notification (first 500 characters of *error*)."""
        from bot.feishu import send_markdown

        try:
            await send_markdown(
                self.chat_id, "chat_id",
                f"❌ **任务出错**\n\n```\n{error[:500]}\n```",
            )
        except Exception:
            logger.exception("Failed to notify error")

    async def _permission_callback(
        self, tool_name: str, input_data: dict, context: ToolPermissionContext
    ) -> PermissionResult:
        """can_use_tool hook: ask the user for approval through a Feishu card.

        Read-only tools are always allowed, as is everything when no chat is
        bound. Otherwise the call waits for approve() or APPROVAL_TIMEOUT; a
        timeout, a failure to send the card, or an explicit "no" denies the
        tool. Decisions made through the card flow are written to the audit log.
        """
        # Auto-allow read-only tools.
        if tool_name in ("Read", "Glob", "Grep", "WebSearch", "WebFetch"):
            return PermissionResultAllow()
        # Nobody to ask.
        if not self.chat_id:
            return PermissionResultAllow()

        from agent.audit import log_permission_decision

        summary = self._format_tool_summary(tool_name, input_data)

        # Register the future before the card goes out, so an answer that
        # arrives while send_card() is still awaiting is not lost.
        approval = asyncio.get_running_loop().create_future()
        self._pending_approval = approval
        self._pending_approval_desc = f"{tool_name}: {summary}"
        try:
            approved = await self._request_approval(tool_name, summary, approval)
        finally:
            # Only clear state that still belongs to this request; a newer
            # request may already have replaced it.
            if self._pending_approval is approval:
                self._pending_approval = None
                self._pending_approval_desc = ""

        try:
            log_permission_decision(
                conv_id=self.conv_id,
                tool_name=tool_name,
                tool_input=input_data,
                approved=approved,
            )
        except Exception:
            logger.exception("Failed to write permission audit log for %s", self.conv_id)

        if approved:
            return PermissionResultAllow()
        return PermissionResultDeny(message="用户拒绝了此操作")

    async def _request_approval(
        self, tool_name: str, summary: str, approval: asyncio.Future
    ) -> bool:
        """Send the approval card and wait for *approval*; any failure means deny."""
        from bot.feishu import build_approval_card, send_card, send_markdown

        try:
            card = build_approval_card(
                conv_id=self.conv_id,
                tool_name=tool_name,
                summary=summary,
                timeout=APPROVAL_TIMEOUT,
            )
            await send_card(self.chat_id, "chat_id", card)
        except Exception:
            logger.exception("Failed to send approval card for %s; denying", tool_name)
            return False

        # Wait for card callback or text reply y/n.
        try:
            return await asyncio.wait_for(approval, timeout=APPROVAL_TIMEOUT)
        except asyncio.TimeoutError:
            try:
                await send_markdown(self.chat_id, "chat_id", "⏰ 审批超时,已自动拒绝。")
            except Exception:
                logger.exception("Failed to notify approval timeout")
            return False

    def _format_tool_summary(self, tool_name: str, input_data: dict) -> str:
        """One-line, human-readable description of a tool call for approval cards."""
        if tool_name == "Bash":
            command = str(input_data.get("command") or "")
            return f"`{command[:200]}`"
        if tool_name in ("Edit", "Write", "MultiEdit"):
            return f"file: `{input_data.get('file_path', input_data.get('path', ''))}`"
        return str(input_data)[:200]

    @staticmethod
    def _summarize_input(input_data: dict) -> str:
        """Short summary of a tool's input for the progress buffer."""
        if "command" in input_data:
            return str(input_data["command"])[:80]
        if "file_path" in input_data:
            return str(input_data["file_path"])
        return str(input_data)[:60]

    def _build_env(self) -> dict[str, str]:
        """Copy selected credential variables from os.environ, skipping unset/empty ones."""
        import os

        keys = ("ANTHROPIC_BASE_URL", "ANTHROPIC_AUTH_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN")
        return {key: os.environ[key] for key in keys if os.environ.get(key, "")}
|