feat: 优化WebSocket连接和心跳机制

- 在main.py和standalone.py中添加ws_ping_interval和ws_ping_timeout配置
- 调整ws.py中的心跳发送逻辑:连接建立后先立即发送ping,再等待30秒间隔
- 在host_client中优化消息处理:将转发请求作为独立的asyncio任务派发,避免长时间的转发处理阻塞心跳响应
- 更新WebTool以适配新的API格式并增加搜索结果限制
- 在agent.py中添加日期显示和web调用次数限制
- 修复bot/handler.py中的事件循环问题:为WebSocket线程创建独立的事件循环,并让lark_oapi使用该循环
This commit is contained in:
Yuyao Huang (Sam) 2026-03-28 15:53:44 +08:00
parent a3622ce26d
commit 80e4953cf9
7 changed files with 92 additions and 31 deletions

View File

@ -184,6 +184,14 @@ def start_websocket_client(loop: asyncio.AbstractEventLoop) -> None:
backoff = 1.0
max_backoff = 60.0
# lark_oapi.ws.client captures the event loop at module import time.
# In standalone mode uvicorn already owns the main loop, so we create
# a fresh loop for this thread and redirect the lark module to use it.
thread_loop = asyncio.new_event_loop()
asyncio.set_event_loop(thread_loop)
import lark_oapi.ws.client as _lark_ws_client
_lark_ws_client.loop = thread_loop
while True:
try:
_ws_connected = False

View File

@ -43,6 +43,7 @@ class NodeClient:
self._running = False
self._last_heartbeat = time.time()
self._reconnect_delay = 1.0
self._forward_tasks: set[asyncio.Task] = set()
async def connect(self) -> bool:
"""Connect to the router WebSocket."""
@ -53,9 +54,9 @@ class NodeClient:
try:
self.ws = await websockets.connect(
self.config.router_url,
extra_headers=headers,
ping_interval=30,
ping_timeout=10,
additional_headers=headers,
ping_interval=20,
ping_timeout=60,
)
logger.info("Connected to router: %s", self.config.router_url)
self._reconnect_delay = 1.0
@ -145,17 +146,9 @@ class NodeClient:
except Exception as e:
logger.error("Failed to send status: %s", e)
async def handle_message(self, data: str) -> None:
"""Handle an incoming message from the router."""
try:
msg = decode(data)
except Exception as e:
logger.error("Failed to decode message: %s", e)
return
if isinstance(msg, ForwardRequest):
await self.handle_forward(msg)
elif isinstance(msg, Heartbeat):
async def handle_message_decoded(self, msg: Any) -> None:
"""Handle an already-decoded message from the router."""
if isinstance(msg, Heartbeat):
if msg.type == "ping":
if self.ws:
try:
@ -165,7 +158,7 @@ class NodeClient:
elif msg.type == "pong":
self._last_heartbeat = time.time()
else:
logger.debug("Received message type: %s", msg.type)
logger.debug("Received message type: %s", type(msg).__name__)
async def receive_loop(self) -> None:
"""Main receive loop for incoming messages."""
@ -174,7 +167,20 @@ class NodeClient:
try:
async for data in self.ws:
await self.handle_message(data)
try:
msg = decode(data)
except Exception as e:
logger.error("Failed to decode message: %s", e)
continue
if isinstance(msg, ForwardRequest):
# Dispatch as a task so pings are handled without waiting
# for the full agent run to complete.
task = asyncio.create_task(self.handle_forward(msg))
self._forward_tasks.add(task)
task.add_done_callback(self._forward_tasks.discard)
else:
await self.handle_message_decoded(msg)
except websockets.ConnectionClosed as e:
logger.warning("Connection closed: %s", e)
except Exception as e:
@ -243,6 +249,10 @@ class NodeClient:
async def stop(self) -> None:
"""Stop the client."""
self._running = False
for task in list(self._forward_tasks):
task.cancel()
if self._forward_tasks:
await asyncio.gather(*self._forward_tasks, return_exceptions=True)
if self.ws:
await self.ws.close()
await manager.stop()

View File

@ -111,4 +111,6 @@ if __name__ == "__main__":
port=8000,
reload=False,
log_level="info",
ws_ping_interval=20,
ws_ping_timeout=60,
)

View File

@ -30,6 +30,8 @@ logger = logging.getLogger(__name__)
SYSTEM_PROMPT_TEMPLATE = """You are PhoneWork, an AI assistant that helps users control Claude Code \
from their phone via Feishu (飞书).
Today's date: {today}
You manage Claude Code sessions. Each session has a conv_id and runs in a project directory.
Base working directory: {working_dir}
@ -46,12 +48,16 @@ Your responsibilities:
4. Close session: call `close_conversation`.
5. GENERAL QUESTIONS: If the user asks a general question (not about a specific project or file), \
answer directly using your own knowledge. Do NOT create a session for simple Q&A.
6. WEB / SEARCH: Use the `web` tool when the user needs current information. \
Call it ONCE (or at most twice with a refined query). Then synthesize and reply — \
do NOT keep searching in a loop. If the first search returns results, use them.
Guidelines:
- Relay Claude Code's output verbatim.
- If no active session and the user sends a task without naming a directory, ask them which project.
- For general knowledge questions (e.g., "what is a Python generator?", "explain async/await"), \
answer directly without creating a session.
- After using any tool, always produce a final text reply to the user. Never end a turn on a tool call.
- Keep your own words brief — let Claude Code's output speak.
- Reply in the same language the user uses (Chinese or English).
"""
@ -111,6 +117,8 @@ class OrchestrationAgent:
self._passthrough: dict[str, bool] = defaultdict(lambda: False)
def _build_system_prompt(self, user_id: str) -> str:
from datetime import date
today = date.today().strftime("%Y-%m-%d")
conv_id = self._active_conv[user_id]
if conv_id:
active_line = f"ACTIVE SESSION: conv_id={conv_id!r} ← use this for all follow-up messages"
@ -119,6 +127,7 @@ class OrchestrationAgent:
return SYSTEM_PROMPT_TEMPLATE.format(
working_dir=WORKING_DIR,
active_session_line=active_line,
today=today,
)
def get_active_conv(self, user_id: str) -> Optional[str]:
@ -181,6 +190,7 @@ class OrchestrationAgent:
reply = ""
try:
web_calls = 0
for iteration in range(MAX_ITERATIONS):
logger.debug(" LLM call #%d", iteration)
ai_msg: AIMessage = await self._llm_with_tools.ainvoke(messages)
@ -201,6 +211,16 @@ class OrchestrationAgent:
)
logger.info("%s(%s)", tool_name, args_summary)
if tool_name == "web":
web_calls += 1
if web_calls > 2:
result = "Web search limit reached. Synthesize from results already obtained."
logger.warning(" web call limit exceeded, blocking")
messages.append(
ToolMessage(content=str(result), tool_call_id=tool_id)
)
continue
tool_obj = _TOOL_MAP.get(tool_name)
if tool_obj is None:
result = f"Unknown tool: {tool_name}"

View File

@ -553,18 +553,31 @@ class WebTool(BaseTool):
payload = {
"jsonrpc": "2.0",
"id": 1,
"method": "metaso_web_search",
"params": {"query": query, "scope": scope or "webpage"},
"method": "tools/call",
"params": {
"name": "metaso_web_search",
"arguments": {"q": query, "scope": scope or "webpage", "size": 5, "includeSummary": True},
},
}
resp = await client.post(base_url, json=payload, headers=headers)
data = resp.json()
if "error" in data:
return json.dumps({"error": data["error"]}, ensure_ascii=False)
results = data.get("result", {}).get("results", [])[:5]
content_text = data.get("result", {}).get("content", [{}])[0].get("text", "")
result_data = json.loads(content_text) if content_text else {}
webpages = result_data.get("webpages", [])[:5]
output = []
for r in results:
output.append(f"**{r.get('title', 'No title')}**\n{r.get('snippet', '')}\n{r.get('url', '')}")
return json.dumps({"results": "\n\n".join(output)[:max_chars]}, ensure_ascii=False)
for r in webpages:
date = r.get("date", "")
title = r.get("title", "No title")
snippet = r.get("snippet", "")[:300]
link = r.get("link", "")
output.append(f"[{date}] **{title}**\n{snippet}\n{link}")
total = result_data.get("total", 0)
return json.dumps({
"total": total,
"results": "\n\n".join(output)[:max_chars],
}, ensure_ascii=False)
elif action == "fetch":
if not url:
@ -572,15 +585,18 @@ class WebTool(BaseTool):
payload = {
"jsonrpc": "2.0",
"id": 1,
"method": "metaso_web_reader",
"params": {"url": url, "format": "markdown"},
"method": "tools/call",
"params": {
"name": "metaso_web_reader",
"arguments": {"url": url, "format": "markdown"},
},
}
resp = await client.post(base_url, json=payload, headers=headers)
data = resp.json()
if "error" in data:
return json.dumps({"error": data["error"]}, ensure_ascii=False)
content = data.get("result", {}).get("content", "")
return json.dumps({"content": content[:max_chars]}, ensure_ascii=False)
content_text = data.get("result", {}).get("content", [{}])[0].get("text", "")
return json.dumps({"content": content_text[:max_chars]}, ensure_ascii=False)
elif action == "ask":
if not query:
@ -588,15 +604,18 @@ class WebTool(BaseTool):
payload = {
"jsonrpc": "2.0",
"id": 1,
"method": "metaso_chat",
"params": {"query": query},
"method": "tools/call",
"params": {
"name": "metaso_chat",
"arguments": {"message": query},
},
}
resp = await client.post(base_url, json=payload, headers=headers)
data = resp.json()
if "error" in data:
return json.dumps({"error": data["error"]}, ensure_ascii=False)
answer = data.get("result", {}).get("answer", "")
return json.dumps({"answer": answer[:max_chars]}, ensure_ascii=False)
content_text = data.get("result", {}).get("content", [{}])[0].get("text", "")
return json.dumps({"answer": content_text[:max_chars]}, ensure_ascii=False)
else:
return json.dumps({"error": f"Unknown action: {action}"}, ensure_ascii=False)

View File

@ -53,11 +53,11 @@ async def ws_node_endpoint(websocket: WebSocket) -> None:
"""Send periodic pings to the host client."""
try:
while True:
await asyncio.sleep(30)
try:
await websocket.send_text(encode(Heartbeat(type="ping")))
except Exception:
break
await asyncio.sleep(30)
except asyncio.CancelledError:
pass

View File

@ -41,6 +41,8 @@ async def run_standalone() -> None:
host="0.0.0.0",
port=8000,
log_level="info",
ws_ping_interval=20,
ws_ping_timeout=60,
)
server = uvicorn.Server(config_obj)