- Add question card builder and answer handling in feishu.py
- Extend SDKSession with pending question state and answer method
- Update card callback handler to support question answers
- Add test cases for question flow and card responses
- Document usage with test_can_use_tool_ask.py example
418 lines
15 KiB
Python
418 lines
15 KiB
Python
"""Feishu event handler using lark-oapi long-connection (WebSocket) mode."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import threading
|
|
import time
|
|
from typing import Any, Dict
|
|
|
|
import lark_oapi as lark
|
|
from lark_oapi.api.im.v1 import P2ImMessageReceiveV1
|
|
|
|
from bot.commands import handle_command
|
|
from bot.feishu import send_text, send_markdown
|
|
from config import FEISHU_APP_ID, FEISHU_APP_SECRET, is_user_allowed
|
|
from orchestrator.agent import agent
|
|
from orchestrator.tools import set_current_chat
|
|
|
|
logger = logging.getLogger(__name__)

# --- Connection bookkeeping -------------------------------------------------
# Asyncio loop that owns message processing; handlers running on the
# WebSocket thread schedule coroutines onto it.
_main_loop: asyncio.AbstractEventLoop | None = None
# Counters/flags surfaced through get_ws_status().
_ws_connected: bool = False
_reconnect_count: int = 0
_last_message_time: float = 0.0

# --- Re-delivery suppression -------------------------------------------------
# Feishu retries a delivery (same payload) for roughly 60s after a network
# hiccup. Identical content from the same user inside a 10s window is treated
# as such a retry rather than a deliberate repeat: an intentional repeat
# normally arrives only after the bot has already replied.
_DEDUP_WINDOW: float = 10.0  # seconds
_recent_messages: dict[tuple[str, str], float] = {}  # (user_id, content) -> timestamp
|
|
|
|
|
|
def _is_duplicate(user_id: str, content: str) -> bool:
    """Tell whether the same user already sent this content within the dedup window.

    Entries older than ``_DEDUP_WINDOW`` seconds are purged on every call.
    A first sighting is recorded and reported as ``False``; a repeat is
    reported as ``True`` and leaves the stored timestamp untouched.
    """
    stamp = time.time()

    # Collect stale keys first: a dict must not be mutated while iterating it.
    stale_keys = [
        entry for entry, seen_at in _recent_messages.items()
        if stamp - seen_at > _DEDUP_WINDOW
    ]
    for entry in stale_keys:
        _recent_messages.pop(entry)

    fingerprint = (user_id, content)
    already_seen = fingerprint in _recent_messages
    if not already_seen:
        _recent_messages[fingerprint] = stamp
    return already_seen
|
|
|
|
|
|
def get_ws_status() -> dict[str, Any]:
    """Snapshot the WebSocket bookkeeping globals as a plain dict.

    Keys: ``connected``, ``last_message_time`` and ``reconnect_count``.
    """
    return dict(
        connected=_ws_connected,
        last_message_time=_last_message_time,
        reconnect_count=_reconnect_count,
    )
|
|
|
|
|
|
def _handle_message(data: P2ImMessageReceiveV1) -> None:
    """Handle an incoming ``im.message.receive_v1`` event.

    Validates the event, extracts the plain text, strips Feishu mention
    placeholders, drops re-deliveries, and finally schedules
    ``_process_message`` on the main asyncio loop. Any exception is logged
    and swallowed so it does not propagate to the caller.

    Args:
        data: The event object delivered by the lark-oapi dispatcher.
    """
    global _last_message_time
    _last_message_time = time.time()

    try:
        event = data.event
        if event is None:
            logger.warning("Received event with no data")
            return

        message = event.message
        sender = event.sender

        if message is None:
            logger.warning("Received event with no message")
            return

        logger.debug(
            "event type=%r chat_type=%r content=%r",
            getattr(message, "message_type", None),
            getattr(message, "chat_type", None),
            (getattr(message, "content", None) or "")[:100],
        )

        if message.message_type != "text":
            logger.info("Skipping non-text message_type=%r", message.message_type)
            return

        chat_id: str = message.chat_id or ""
        raw_content: str = message.content or "{}"
        content_obj = json.loads(raw_content)

        # Tolerate a non-object payload or a null/non-string "text" field
        # instead of tripping the catch-all handler with an AttributeError.
        raw_text = content_obj.get("text") if isinstance(content_obj, dict) else None
        text: str = raw_text.strip() if isinstance(raw_text, str) else ""

        import re

        # Feishu encodes mentions as "@_user_<n>" (or "@_all") placeholders.
        # Strip only those: a generic "@\S+" pattern would also eat ordinary
        # text such as the domain part of an e-mail address.
        text = re.sub(r"@_(?:user_\d+|all)\s*", "", text).strip()

        open_id: str = ""
        if sender and sender.sender_id:
            open_id = sender.sender_id.open_id or ""

        if not text:
            logger.info("Empty text after stripping, ignoring")
            return

        # Fall back to the chat id when the sender has no open_id.
        user_id = open_id or chat_id

        if _is_duplicate(user_id, text):
            logger.info("Dropping duplicate delivery: user=...%s text=%r", user_id[-8:], text[:60])
            return

        logger.info("✉ ...%s → %r", open_id[-8:], text[:80])

        if _main_loop is None:
            logger.error("Main event loop not set; cannot process message")
            return

        # This handler is synchronous; hand the real work over to the main
        # asyncio loop stored in _main_loop.
        asyncio.run_coroutine_threadsafe(
            _process_message(user_id, chat_id, text),
            _main_loop,
        )
    except Exception:
        logger.exception("Error in _handle_message")
|
|
|
|
|
|
def _lookup_sdk_session(user_id: str) -> tuple[Any, Any]:
    """Return ``(conv_id, sdk_session)`` for the user's active conversation.

    ``sdk_session`` is ``None`` when the user has no active conversation, the
    manager has no session for it, or the session has no SDK session attached.
    """
    from agent.manager import manager as _manager

    conv_id = agent.get_active_conv(user_id)
    if not conv_id:
        return None, None
    session = _manager._sessions.get(conv_id)
    if not session or not session.sdk_session:
        return conv_id, None
    return conv_id, session.sdk_session


async def _apply_text_approval(conv_id: Any, sdk_session: Any, chat_id: str, text: str) -> bool:
    """Treat a y/n/yes/no reply as the verdict for a pending tool approval.

    Returns True when the reply was consumed as an approval decision.
    """
    from agent.manager import manager as _manager

    verdict = text.lower()
    if verdict not in ("y", "n", "yes", "no"):
        return False
    pending = sdk_session._pending_approval
    if not pending or pending.done():
        return False

    approved = verdict in ("y", "yes")
    await _manager.approve(conv_id, approved)
    label = "✅ 已批准" if approved else "❌ 已拒绝"
    await send_text(chat_id, "chat_id", label)
    return True


async def _apply_text_answer(conv_id: Any, sdk_session: Any, chat_id: str, text: str) -> bool:
    """Use a free-text reply as the answer to the first pending question.

    Returns True when the reply was consumed as an answer.
    """
    from agent.manager import manager as _manager

    pending = sdk_session._pending_question
    if not pending or pending.done():
        return False
    question_data = sdk_session._pending_question_data
    if not question_data:
        return False
    questions = question_data.get("questions", [])
    if not questions:
        return False

    q_text = questions[0].get("question", "")
    await _manager.answer_question(conv_id, {q_text: text})
    await send_text(chat_id, "chat_id", f"✅ 已回答: {text}")
    return True


async def _announce_nodes_to_new_user(user_id: str, chat_id: str) -> None:
    """Router mode: tell a first-time user which of their nodes are online."""
    from router.nodes import get_node_registry

    registry = get_node_registry()
    if not registry.track_user(user_id):
        return
    online = [n for n in registry.get_nodes_for_user(user_id) if n.is_online]
    if online:
        names = ", ".join(n.display_name for n in online)
        await send_text(chat_id, "chat_id", f"Available nodes: {names}")


async def _send_node_overview(user_id: str, chat_id: str) -> None:
    """Router mode: reply with the list of connected nodes, marking the active one."""
    from router.nodes import get_node_registry

    registry = get_node_registry()
    nodes = registry.list_nodes()
    if not nodes:
        await send_text(chat_id, "chat_id", "No nodes connected.")
        return

    # The active node does not depend on the loop variable: resolve it once.
    active_node = registry.get_active_node(user_id)
    active_id = active_node.node_id if active_node else None

    lines = ["Connected Nodes:"]
    for n in nodes:
        marker = " → " if n.get("node_id") == active_id else " "
        lines.append(f"{marker}{n.get('display_name', 'unknown')} sessions={n.get('sessions', 0)} {n.get('status', 'unknown')}")
    lines.append("\nUse \"/node <name>\" to switch active node.")
    await send_text(chat_id, "chat_id", "\n".join(lines))


async def _dispatch_via_router(user_id: str, chat_id: str, text: str) -> None:
    """Router mode: pick a node for the message and forward it there."""
    from router.routing_agent import route
    from router.rpc import forward

    node_id, reason = await route(user_id, chat_id, text)

    if node_id is None:
        await send_text(chat_id, "chat_id", f"No host available: {reason}")
        return

    # "meta" is answered locally with a node overview, not forwarded.
    if node_id == "meta":
        await _send_node_overview(user_id, chat_id)
        return

    try:
        reply = await forward(node_id, user_id, chat_id, text)
        if reply:
            await send_markdown(chat_id, "chat_id", reply)
    except Exception as e:
        logger.exception("Failed to forward to node %s", node_id)
        await send_text(chat_id, "chat_id", f"Error communicating with node: {e}")


async def _process_message(user_id: str, chat_id: str, text: str) -> None:
    """Process message: check allowlist, then commands, then route to node or local agent.

    Order of handling:
      1. Reject users that fail ``is_user_allowed``.
      2. If a tool approval is pending, a y/n/yes/no reply resolves it.
      3. If a question is pending, any text reply answers its first question.
      4. Router mode only: announce online nodes to a newly tracked user.
      5. Built-in commands via ``handle_command``.
      6. Otherwise forward through the router or run the local agent.

    All exceptions are logged and swallowed.

    Args:
        user_id: Identifier of the sender (open_id, or chat id as fallback).
        chat_id: Chat that replies are sent to.
        text: Message text with mentions already stripped.
    """
    try:
        set_current_chat(chat_id)

        if not is_user_allowed(user_id):
            logger.warning("Rejected message from unauthorized user: ...%s", user_id[-8:])
            await send_text(chat_id, "chat_id", "Sorry, you are not authorized to use this bot.")
            return

        # Text fallbacks for interactive prompts. One lookup serves both
        # checks: nothing is awaited between them unless one consumes the reply.
        reply_text = text.strip()
        conv_id, sdk_session = _lookup_sdk_session(user_id)
        if sdk_session is not None:
            if await _apply_text_approval(conv_id, sdk_session, chat_id, reply_text):
                return
            if await _apply_text_answer(conv_id, sdk_session, chat_id, reply_text):
                return

        # Read the flag through the module so each check sees its current value.
        import config as _config

        if _config.ROUTER_MODE:
            await _announce_nodes_to_new_user(user_id, chat_id)

        reply = await handle_command(user_id, text)
        if reply is not None:
            # An empty (non-None) reply means "handled, nothing to say".
            if reply:
                await send_text(chat_id, "chat_id", reply)
            return

        if _config.ROUTER_MODE:
            await _dispatch_via_router(user_id, chat_id, text)
        else:
            reply = await agent.run(user_id, text)
            if reply:
                await send_markdown(chat_id, "chat_id", reply)
    except Exception:
        logger.exception("Error processing message for user %s", user_id)
|
|
|
|
|
|
def _handle_any(data: lark.CustomizedEvent) -> None:
    """Diagnostic catch-all: log the first 500 characters of any customized event."""
    payload = lark.JSON.marshal(data)
    if not payload:
        return
    logger.info("RAW CustomizedEvent: %s", payload[:500])
|
|
|
|
|
|
def _handle_card_action(data: "P2CardActionTrigger") -> "P2CardActionTriggerResponse":
    """Handle Feishu card button clicks via register_p2_card_action_trigger.

    Per docs/feishu/card_callback_communication.md:
    - Must respond within 3 seconds
    - Return P2CardActionTriggerResponse with toast + updated card

    The real work (approving a tool call / answering a question) is scheduled
    on the main asyncio loop; this handler only builds the immediate response.

    Args:
        data: Card action trigger event; ``event.action.value`` is expected to
            carry ``action`` and ``conv_id`` (plus ``question``/``answer`` for
            ``answer_question`` actions).

    Returns:
        A response with toast and replacement card on success; a toast-only
        error response when the main loop is not available; an empty response
        when the event is malformed or an exception occurs.
    """
    from lark_oapi.event.callback.model.p2_card_action_trigger import (
        CallBackCard,
        CallBackToast,
        P2CardActionTriggerResponse,
    )

    def _toast_only(toast_type: str, toast_text: str) -> P2CardActionTriggerResponse:
        # Response that shows a toast and sets no replacement card.
        resp = P2CardActionTriggerResponse()
        toast = CallBackToast()
        toast.type = toast_type
        toast.content = toast_text
        resp.toast = toast
        return resp

    def _response(toast_type: str, toast_text: str, card_data: dict) -> P2CardActionTriggerResponse:
        # Toast plus a raw replacement card.
        resp = _toast_only(toast_type, toast_text)
        card = CallBackCard()
        card.type = "raw"
        card.data = card_data
        resp.card = card
        return resp

    def _empty_response() -> P2CardActionTriggerResponse:
        return P2CardActionTriggerResponse()

    try:
        event = data.event
        if not event:
            return _empty_response()

        action = event.action
        if not action:
            return _empty_response()

        value: dict = action.value or {}
        action_type = value.get("action")
        conv_id = value.get("conv_id")

        if not action_type or not conv_id:
            logger.debug("Card action without action/conv_id: %s", value)
            return _empty_response()

        operator_open_id = (event.operator.open_id or "") if event.operator else ""
        logger.info(
            "Card action: %s for session %s by ...%s",
            action_type, conv_id, operator_open_id[-8:],
        )
        # NOTE(review): the operator is not checked against is_user_allowed();
        # confirm whether anyone able to see the card may approve/answer.

        # Without the main loop the action cannot be applied. Do not claim
        # success and do not replace the card.
        if _main_loop is None:
            logger.error(
                "Main event loop not set; dropping card action %s for session %s",
                action_type, conv_id,
            )
            return _toast_only("error", "⚠️ 服务尚未就绪，请稍后重试")

        # --- AskUserQuestion answer ---
        if action_type == "answer_question":
            question = value.get("question", "")
            answer = value.get("answer", "")
            asyncio.run_coroutine_threadsafe(
                _handle_question_answer_async(conv_id, question, answer), _main_loop
            )
            return _response(
                "success", f"已选择: {answer}",
                {
                    "schema": "2.0",
                    "header": {
                        "title": {"tag": "plain_text", "content": "❓ Claude Code 提问"},
                        "template": "green",
                    },
                    "body": {
                        "elements": [
                            {"tag": "markdown", "content": f"**{question}**\n\n✅ 已选择: **{answer}**"},
                        ],
                    },
                },
            )

        # --- Tool approval ---
        # NOTE(review): every action other than "approve" is treated as a
        # rejection, including unknown action names — confirm this is intended.
        approved = action_type == "approve"
        asyncio.run_coroutine_threadsafe(
            _handle_approval_async(conv_id, approved), _main_loop
        )

        if approved:
            toast_type, toast_text = "success", "✅ 已批准"
            card_status, template = "✅ **已批准**", "green"
        else:
            toast_type, toast_text = "warning", "❌ 已拒绝"
            card_status, template = "❌ **已拒绝**", "red"

        return _response(
            toast_type, toast_text,
            {
                "schema": "2.0",
                "header": {
                    "title": {"tag": "plain_text", "content": "🔐 权限审批"},
                    "template": template,
                },
                "body": {"elements": [{"tag": "markdown", "content": card_status}]},
            },
        )

    except Exception:
        logger.exception("Error handling card action")
        return P2CardActionTriggerResponse()
|
|
|
|
|
|
async def _handle_approval_async(conv_id: str, approved: bool) -> None:
    """Process a card approval action.

    ``_handle_card_action`` schedules this coroutine with
    ``asyncio.run_coroutine_threadsafe`` and discards the returned future, so
    a failure would otherwise go unnoticed. It is logged here and re-raised.

    Args:
        conv_id: Conversation/session the approval belongs to.
        approved: True to approve the pending tool call, False to reject it.

    Raises:
        Exception: Whatever the import or ``manager.approve`` raises, after logging.
    """
    try:
        from agent.manager import manager

        await manager.approve(conv_id, approved)
    except Exception:
        logger.exception(
            "Failed to apply approval=%s for session %s", approved, conv_id
        )
        raise
|
|
|
|
|
|
async def _handle_question_answer_async(conv_id: str, question: str, answer: str) -> None:
    """Process a question answer from card callback.

    ``_handle_card_action`` schedules this coroutine with
    ``asyncio.run_coroutine_threadsafe`` and discards the returned future, so
    a failure would otherwise go unnoticed. It is logged here and re-raised.

    Args:
        conv_id: Conversation/session the question belongs to.
        question: Question text, used as the key of the answers mapping.
        answer: The option the user selected.

    Raises:
        Exception: Whatever the import or ``manager.answer_question`` raises,
            after logging.
    """
    try:
        from agent.manager import manager

        await manager.answer_question(conv_id, {question: answer})
    except Exception:
        logger.exception("Failed to deliver question answer for session %s", conv_id)
        raise
|
|
|
|
|
|
def build_event_handler() -> lark.EventDispatcherHandler:
    """Assemble the dispatcher wiring every callback defined in this module.

    Registers the message receiver, the raw-event logger and the card action
    handler, in that order, on a builder created with empty credentials.
    """
    builder = lark.EventDispatcherHandler.builder("", "")
    builder = builder.register_p2_im_message_receive_v1(_handle_message)
    builder = builder.register_p1_customized_event("im.message.receive_v1", _handle_any)
    builder = builder.register_p2_card_action_trigger(_handle_card_action)
    return builder.build()
|
|
|
|
|
|
def start_websocket_client(loop: asyncio.AbstractEventLoop) -> None:
    """
    Start the lark-oapi WebSocket long-connection client in a background thread.
    Must be called after the asyncio event loop is running.

    The thread reconnects forever with exponential backoff (1s doubling up to
    60s). The backoff is reset once a connection attempt has lasted at least
    ``max_backoff`` seconds, so that isolated drops of an otherwise healthy
    connection do not leave the delay stuck at its maximum.

    Args:
        loop: Main asyncio loop; stored in ``_main_loop`` so the event
            handlers can schedule coroutines on it.
    """
    global _main_loop
    _main_loop = loop

    def _run_with_reconnect() -> None:
        global _ws_connected, _reconnect_count
        initial_backoff = 1.0
        max_backoff = 60.0
        backoff = initial_backoff

        # lark_oapi.ws.client captures the event loop at module import time.
        # In standalone mode uvicorn already owns the main loop, so we create
        # a fresh loop for this thread and redirect the lark module to use it.
        thread_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(thread_loop)
        import lark_oapi.ws.client as _lark_ws_client
        _lark_ws_client.loop = thread_loop

        while True:
            # Monotonic clock: measuring a duration must not be affected by
            # wall-clock adjustments.
            started_at = time.monotonic()
            try:
                _ws_connected = False
                event_handler = build_event_handler()
                ws_client = lark.ws.Client(
                    FEISHU_APP_ID,
                    FEISHU_APP_SECRET,
                    event_handler=event_handler,
                    log_level=lark.LogLevel.INFO,
                )

                logger.info("Starting Feishu long-connection client...")
                # NOTE(review): the flag is set before start() is called, so it
                # means "attempting / running", not "handshake completed".
                _ws_connected = True
                _reconnect_count += 1
                ws_client.start()
                logger.warning("WebSocket disconnected, will reconnect...")

            except Exception as e:
                # logger.exception keeps the traceback for diagnosis.
                logger.exception("WebSocket error: %s", e)

            finally:
                _ws_connected = False

            # A run that survived at least max_backoff seconds counts as
            # stable: start the backoff sequence over.
            if time.monotonic() - started_at >= max_backoff:
                backoff = initial_backoff

            logger.info("Reconnecting in %.1fs (attempt %d)...", backoff, _reconnect_count)
            time.sleep(backoff)
            backoff = min(backoff * 2, max_backoff)

    thread = threading.Thread(target=_run_with_reconnect, daemon=True, name="feishu-ws")
    thread.start()
    logger.info("Feishu WebSocket thread started")
|