From 80e4953cf9295176e642dfaf6251f996b35c0225 Mon Sep 17 00:00:00 2001 From: "Yuyao Huang (Sam)" Date: Sat, 28 Mar 2026 15:53:44 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96WebSocket=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=E5=92=8C=E5=BF=83=E8=B7=B3=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在main.py和standalone.py中添加ws_ping_interval和ws_ping_timeout配置 - 调整ws.py中的心跳发送逻辑,先发送ping再等待 - 在host_client中优化消息处理,使用任务队列处理转发请求 - 更新WebTool以适配新的API格式并增加搜索结果限制 - 在agent.py中添加日期显示和web调用次数限制 - 修复bot/handler.py中的事件循环问题 --- bot/handler.py | 8 ++++++++ host_client/main.py | 42 +++++++++++++++++++++++--------------- main.py | 2 ++ orchestrator/agent.py | 20 ++++++++++++++++++ orchestrator/tools.py | 47 ++++++++++++++++++++++++++++++------------- router/ws.py | 2 +- standalone.py | 2 ++ 7 files changed, 92 insertions(+), 31 deletions(-) diff --git a/bot/handler.py b/bot/handler.py index 0d142d6..cc99990 100644 --- a/bot/handler.py +++ b/bot/handler.py @@ -184,6 +184,14 @@ def start_websocket_client(loop: asyncio.AbstractEventLoop) -> None: backoff = 1.0 max_backoff = 60.0 + # lark_oapi.ws.client captures the event loop at module import time. + # In standalone mode uvicorn already owns the main loop, so we create + # a fresh loop for this thread and redirect the lark module to use it. + thread_loop = asyncio.new_event_loop() + asyncio.set_event_loop(thread_loop) + import lark_oapi.ws.client as _lark_ws_client + _lark_ws_client.loop = thread_loop + while True: try: _ws_connected = False diff --git a/host_client/main.py b/host_client/main.py index 65bc991..9e8e8f4 100644 --- a/host_client/main.py +++ b/host_client/main.py @@ -43,6 +43,7 @@ class NodeClient: self._running = False self._last_heartbeat = time.time() self._reconnect_delay = 1.0 + self._forward_tasks: set[asyncio.Task] = set() async def connect(self) -> bool: """Connect to the router WebSocket.""" @@ -53,9 +54,9 @@ class NodeClient: try: self.ws = await websockets.connect( self.config.router_url, - extra_headers=headers, - ping_interval=30, - ping_timeout=10, + additional_headers=headers, + ping_interval=20, + ping_timeout=60, ) logger.info("Connected to router: %s", self.config.router_url) self._reconnect_delay = 1.0 @@ -145,17 +146,9 @@ class NodeClient: except Exception as e: logger.error("Failed to send status: %s", e) - async def handle_message(self, data: str) -> None: - """Handle an incoming message from the router.""" - try: - msg = decode(data) - except Exception as e: - logger.error("Failed to decode message: %s", e) - return - - if isinstance(msg, ForwardRequest): - await self.handle_forward(msg) - elif isinstance(msg, Heartbeat): + async def handle_message_decoded(self, msg: Any) -> None: + """Handle an already-decoded message from the router.""" + if isinstance(msg, Heartbeat): if msg.type == "ping": if self.ws: try: @@ -165,7 +158,7 @@ class NodeClient: elif msg.type == "pong": self._last_heartbeat = time.time() else: - logger.debug("Received message type: %s", msg.type) + logger.debug("Received message type: %s", type(msg).__name__) async def receive_loop(self) -> None: """Main receive loop for incoming messages.""" @@ -174,7 +167,20 @@ class NodeClient: try: async for data in self.ws: - await self.handle_message(data) + try: + msg = decode(data) + except Exception as e: + logger.error("Failed to decode message: %s", e) + continue + + if isinstance(msg, ForwardRequest): + # Dispatch as a task so pings are handled without waiting + # for the full agent run to complete. + task = asyncio.create_task(self.handle_forward(msg)) + self._forward_tasks.add(task) + task.add_done_callback(self._forward_tasks.discard) + else: + await self.handle_message_decoded(msg) except websockets.ConnectionClosed as e: logger.warning("Connection closed: %s", e) except Exception as e: @@ -243,6 +249,10 @@ class NodeClient: async def stop(self) -> None: """Stop the client.""" self._running = False + for task in list(self._forward_tasks): + task.cancel() + if self._forward_tasks: + await asyncio.gather(*self._forward_tasks, return_exceptions=True) if self.ws: await self.ws.close() await manager.stop() diff --git a/main.py b/main.py index 482ee4b..e8535d9 100644 --- a/main.py +++ b/main.py @@ -111,4 +111,6 @@ if __name__ == "__main__": port=8000, reload=False, log_level="info", + ws_ping_interval=20, + ws_ping_timeout=60, ) diff --git a/orchestrator/agent.py b/orchestrator/agent.py index 7619e5d..af32397 100644 --- a/orchestrator/agent.py +++ b/orchestrator/agent.py @@ -30,6 +30,8 @@ logger = logging.getLogger(__name__) SYSTEM_PROMPT_TEMPLATE = """You are PhoneWork, an AI assistant that helps users control Claude Code \ from their phone via Feishu (飞书). +Today's date: {today} + You manage Claude Code sessions. Each session has a conv_id and runs in a project directory. Base working directory: {working_dir} @@ -46,12 +48,16 @@ Your responsibilities: 4. Close session: call `close_conversation`. 5. GENERAL QUESTIONS: If the user asks a general question (not about a specific project or file), \ answer directly using your own knowledge. Do NOT create a session for simple Q&A. +6. WEB / SEARCH: Use the `web` tool when the user needs current information. \ + Call it ONCE (or at most twice with a refined query). Then synthesize and reply — \ + do NOT keep searching in a loop. If the first search returns results, use them. Guidelines: - Relay Claude Code's output verbatim. - If no active session and the user sends a task without naming a directory, ask them which project. - For general knowledge questions (e.g., "what is a Python generator?", "explain async/await"), \ answer directly without creating a session. +- After using any tool, always produce a final text reply to the user. Never end a turn on a tool call. - Keep your own words brief — let Claude Code's output speak. - Reply in the same language the user uses (Chinese or English). """ @@ -111,6 +117,8 @@ class OrchestrationAgent: self._passthrough: dict[str, bool] = defaultdict(lambda: False) def _build_system_prompt(self, user_id: str) -> str: + from datetime import date + today = date.today().strftime("%Y-%m-%d") conv_id = self._active_conv[user_id] if conv_id: active_line = f"ACTIVE SESSION: conv_id={conv_id!r} ← use this for all follow-up messages" @@ -119,6 +127,7 @@ class OrchestrationAgent: return SYSTEM_PROMPT_TEMPLATE.format( working_dir=WORKING_DIR, active_session_line=active_line, + today=today, ) def get_active_conv(self, user_id: str) -> Optional[str]: @@ -181,6 +190,7 @@ class OrchestrationAgent: reply = "" try: + web_calls = 0 for iteration in range(MAX_ITERATIONS): logger.debug(" LLM call #%d", iteration) ai_msg: AIMessage = await self._llm_with_tools.ainvoke(messages) @@ -201,6 +211,16 @@ class OrchestrationAgent: ) logger.info(" ⚙ %s(%s)", tool_name, args_summary) + if tool_name == "web": + web_calls += 1 + if web_calls > 2: + result = "Web search limit reached. Synthesize from results already obtained." + logger.warning(" web call limit exceeded, blocking") + messages.append( + ToolMessage(content=str(result), tool_call_id=tool_id) + ) + continue + tool_obj = _TOOL_MAP.get(tool_name) if tool_obj is None: result = f"Unknown tool: {tool_name}" diff --git a/orchestrator/tools.py b/orchestrator/tools.py index fd948ba..aac0d79 100644 --- a/orchestrator/tools.py +++ b/orchestrator/tools.py @@ -553,18 +553,31 @@ class WebTool(BaseTool): payload = { "jsonrpc": "2.0", "id": 1, - "method": "metaso_web_search", - "params": {"query": query, "scope": scope or "webpage"}, + "method": "tools/call", + "params": { + "name": "metaso_web_search", + "arguments": {"q": query, "scope": scope or "webpage", "size": 5, "includeSummary": True}, + }, } resp = await client.post(base_url, json=payload, headers=headers) data = resp.json() if "error" in data: return json.dumps({"error": data["error"]}, ensure_ascii=False) - results = data.get("result", {}).get("results", [])[:5] + content_text = data.get("result", {}).get("content", [{}])[0].get("text", "") + result_data = json.loads(content_text) if content_text else {} + webpages = result_data.get("webpages", [])[:5] output = [] - for r in results: - output.append(f"**{r.get('title', 'No title')}**\n{r.get('snippet', '')}\n{r.get('url', '')}") - return json.dumps({"results": "\n\n".join(output)[:max_chars]}, ensure_ascii=False) + for r in webpages: + date = r.get("date", "") + title = r.get("title", "No title") + snippet = r.get("snippet", "")[:300] + link = r.get("link", "") + output.append(f"[{date}] **{title}**\n{snippet}\n{link}") + total = result_data.get("total", 0) + return json.dumps({ + "total": total, + "results": "\n\n".join(output)[:max_chars], + }, ensure_ascii=False) elif action == "fetch": if not url: @@ -572,15 +585,18 @@ class WebTool(BaseTool): payload = { "jsonrpc": "2.0", "id": 1, - "method": "metaso_web_reader", - "params": {"url": url, "format": "markdown"}, + "method": "tools/call", + "params": { + "name": "metaso_web_reader", + "arguments": {"url": url, "format": "markdown"}, + }, } resp = await client.post(base_url, json=payload, headers=headers) data = resp.json() if "error" in data: return json.dumps({"error": data["error"]}, ensure_ascii=False) - content = data.get("result", {}).get("content", "") - return json.dumps({"content": content[:max_chars]}, ensure_ascii=False) + content_text = data.get("result", {}).get("content", [{}])[0].get("text", "") + return json.dumps({"content": content_text[:max_chars]}, ensure_ascii=False) elif action == "ask": if not query: @@ -588,15 +604,18 @@ class WebTool(BaseTool): payload = { "jsonrpc": "2.0", "id": 1, - "method": "metaso_chat", - "params": {"query": query}, + "method": "tools/call", + "params": { + "name": "metaso_chat", + "arguments": {"message": query}, + }, } resp = await client.post(base_url, json=payload, headers=headers) data = resp.json() if "error" in data: return json.dumps({"error": data["error"]}, ensure_ascii=False) - answer = data.get("result", {}).get("answer", "") - return json.dumps({"answer": answer[:max_chars]}, ensure_ascii=False) + content_text = data.get("result", {}).get("content", [{}])[0].get("text", "") + return json.dumps({"answer": content_text[:max_chars]}, ensure_ascii=False) else: return json.dumps({"error": f"Unknown action: {action}"}, ensure_ascii=False) diff --git a/router/ws.py b/router/ws.py index 7b688c5..4c3bc0c 100644 --- a/router/ws.py +++ b/router/ws.py @@ -53,11 +53,11 @@ async def ws_node_endpoint(websocket: WebSocket) -> None: """Send periodic pings to the host client.""" try: while True: - await asyncio.sleep(30) try: await websocket.send_text(encode(Heartbeat(type="ping"))) except Exception: break + await asyncio.sleep(30) except asyncio.CancelledError: pass diff --git a/standalone.py b/standalone.py index 86171b6..eb5e8cd 100644 --- a/standalone.py +++ b/standalone.py @@ -41,6 +41,8 @@ async def run_standalone() -> None: host="0.0.0.0", port=8000, log_level="info", + ws_ping_interval=20, + ws_ping_timeout=60, ) server = uvicorn.Server(config_obj)