feat: 优化WebSocket连接和心跳机制
- 在main.py和standalone.py中添加ws_ping_interval和ws_ping_timeout配置 - 调整ws.py中的心跳发送逻辑,先发送ping再等待 - 在host_client中优化消息处理,使用任务队列处理转发请求 - 更新WebTool以适配新的API格式并增加搜索结果限制 - 在agent.py中添加日期显示和web调用次数限制 - 修复bot/handler.py中的事件循环问题
This commit is contained in:
parent
a3622ce26d
commit
80e4953cf9
@ -184,6 +184,14 @@ def start_websocket_client(loop: asyncio.AbstractEventLoop) -> None:
|
||||
backoff = 1.0
|
||||
max_backoff = 60.0
|
||||
|
||||
# lark_oapi.ws.client captures the event loop at module import time.
|
||||
# In standalone mode uvicorn already owns the main loop, so we create
|
||||
# a fresh loop for this thread and redirect the lark module to use it.
|
||||
thread_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(thread_loop)
|
||||
import lark_oapi.ws.client as _lark_ws_client
|
||||
_lark_ws_client.loop = thread_loop
|
||||
|
||||
while True:
|
||||
try:
|
||||
_ws_connected = False
|
||||
|
||||
@ -43,6 +43,7 @@ class NodeClient:
|
||||
self._running = False
|
||||
self._last_heartbeat = time.time()
|
||||
self._reconnect_delay = 1.0
|
||||
self._forward_tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to the router WebSocket."""
|
||||
@ -53,9 +54,9 @@ class NodeClient:
|
||||
try:
|
||||
self.ws = await websockets.connect(
|
||||
self.config.router_url,
|
||||
extra_headers=headers,
|
||||
ping_interval=30,
|
||||
ping_timeout=10,
|
||||
additional_headers=headers,
|
||||
ping_interval=20,
|
||||
ping_timeout=60,
|
||||
)
|
||||
logger.info("Connected to router: %s", self.config.router_url)
|
||||
self._reconnect_delay = 1.0
|
||||
@ -145,17 +146,9 @@ class NodeClient:
|
||||
except Exception as e:
|
||||
logger.error("Failed to send status: %s", e)
|
||||
|
||||
async def handle_message(self, data: str) -> None:
|
||||
"""Handle an incoming message from the router."""
|
||||
try:
|
||||
msg = decode(data)
|
||||
except Exception as e:
|
||||
logger.error("Failed to decode message: %s", e)
|
||||
return
|
||||
|
||||
if isinstance(msg, ForwardRequest):
|
||||
await self.handle_forward(msg)
|
||||
elif isinstance(msg, Heartbeat):
|
||||
async def handle_message_decoded(self, msg: Any) -> None:
|
||||
"""Handle an already-decoded message from the router."""
|
||||
if isinstance(msg, Heartbeat):
|
||||
if msg.type == "ping":
|
||||
if self.ws:
|
||||
try:
|
||||
@ -165,7 +158,7 @@ class NodeClient:
|
||||
elif msg.type == "pong":
|
||||
self._last_heartbeat = time.time()
|
||||
else:
|
||||
logger.debug("Received message type: %s", msg.type)
|
||||
logger.debug("Received message type: %s", type(msg).__name__)
|
||||
|
||||
async def receive_loop(self) -> None:
|
||||
"""Main receive loop for incoming messages."""
|
||||
@ -174,7 +167,20 @@ class NodeClient:
|
||||
|
||||
try:
|
||||
async for data in self.ws:
|
||||
await self.handle_message(data)
|
||||
try:
|
||||
msg = decode(data)
|
||||
except Exception as e:
|
||||
logger.error("Failed to decode message: %s", e)
|
||||
continue
|
||||
|
||||
if isinstance(msg, ForwardRequest):
|
||||
# Dispatch as a task so pings are handled without waiting
|
||||
# for the full agent run to complete.
|
||||
task = asyncio.create_task(self.handle_forward(msg))
|
||||
self._forward_tasks.add(task)
|
||||
task.add_done_callback(self._forward_tasks.discard)
|
||||
else:
|
||||
await self.handle_message_decoded(msg)
|
||||
except websockets.ConnectionClosed as e:
|
||||
logger.warning("Connection closed: %s", e)
|
||||
except Exception as e:
|
||||
@ -243,6 +249,10 @@ class NodeClient:
|
||||
async def stop(self) -> None:
|
||||
"""Stop the client."""
|
||||
self._running = False
|
||||
for task in list(self._forward_tasks):
|
||||
task.cancel()
|
||||
if self._forward_tasks:
|
||||
await asyncio.gather(*self._forward_tasks, return_exceptions=True)
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
await manager.stop()
|
||||
|
||||
2
main.py
2
main.py
@ -111,4 +111,6 @@ if __name__ == "__main__":
|
||||
port=8000,
|
||||
reload=False,
|
||||
log_level="info",
|
||||
ws_ping_interval=20,
|
||||
ws_ping_timeout=60,
|
||||
)
|
||||
|
||||
@ -30,6 +30,8 @@ logger = logging.getLogger(__name__)
|
||||
SYSTEM_PROMPT_TEMPLATE = """You are PhoneWork, an AI assistant that helps users control Claude Code \
|
||||
from their phone via Feishu (飞书).
|
||||
|
||||
Today's date: {today}
|
||||
|
||||
You manage Claude Code sessions. Each session has a conv_id and runs in a project directory.
|
||||
|
||||
Base working directory: {working_dir}
|
||||
@ -46,12 +48,16 @@ Your responsibilities:
|
||||
4. Close session: call `close_conversation`.
|
||||
5. GENERAL QUESTIONS: If the user asks a general question (not about a specific project or file), \
|
||||
answer directly using your own knowledge. Do NOT create a session for simple Q&A.
|
||||
6. WEB / SEARCH: Use the `web` tool when the user needs current information. \
|
||||
Call it ONCE (or at most twice with a refined query). Then synthesize and reply — \
|
||||
do NOT keep searching in a loop. If the first search returns results, use them.
|
||||
|
||||
Guidelines:
|
||||
- Relay Claude Code's output verbatim.
|
||||
- If no active session and the user sends a task without naming a directory, ask them which project.
|
||||
- For general knowledge questions (e.g., "what is a Python generator?", "explain async/await"), \
|
||||
answer directly without creating a session.
|
||||
- After using any tool, always produce a final text reply to the user. Never end a turn on a tool call.
|
||||
- Keep your own words brief — let Claude Code's output speak.
|
||||
- Reply in the same language the user uses (Chinese or English).
|
||||
"""
|
||||
@ -111,6 +117,8 @@ class OrchestrationAgent:
|
||||
self._passthrough: dict[str, bool] = defaultdict(lambda: False)
|
||||
|
||||
def _build_system_prompt(self, user_id: str) -> str:
|
||||
from datetime import date
|
||||
today = date.today().strftime("%Y-%m-%d")
|
||||
conv_id = self._active_conv[user_id]
|
||||
if conv_id:
|
||||
active_line = f"ACTIVE SESSION: conv_id={conv_id!r} ← use this for all follow-up messages"
|
||||
@ -119,6 +127,7 @@ class OrchestrationAgent:
|
||||
return SYSTEM_PROMPT_TEMPLATE.format(
|
||||
working_dir=WORKING_DIR,
|
||||
active_session_line=active_line,
|
||||
today=today,
|
||||
)
|
||||
|
||||
def get_active_conv(self, user_id: str) -> Optional[str]:
|
||||
@ -181,6 +190,7 @@ class OrchestrationAgent:
|
||||
|
||||
reply = ""
|
||||
try:
|
||||
web_calls = 0
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
logger.debug(" LLM call #%d", iteration)
|
||||
ai_msg: AIMessage = await self._llm_with_tools.ainvoke(messages)
|
||||
@ -201,6 +211,16 @@ class OrchestrationAgent:
|
||||
)
|
||||
logger.info(" ⚙ %s(%s)", tool_name, args_summary)
|
||||
|
||||
if tool_name == "web":
|
||||
web_calls += 1
|
||||
if web_calls > 2:
|
||||
result = "Web search limit reached. Synthesize from results already obtained."
|
||||
logger.warning(" web call limit exceeded, blocking")
|
||||
messages.append(
|
||||
ToolMessage(content=str(result), tool_call_id=tool_id)
|
||||
)
|
||||
continue
|
||||
|
||||
tool_obj = _TOOL_MAP.get(tool_name)
|
||||
if tool_obj is None:
|
||||
result = f"Unknown tool: {tool_name}"
|
||||
|
||||
@ -553,18 +553,31 @@ class WebTool(BaseTool):
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "metaso_web_search",
|
||||
"params": {"query": query, "scope": scope or "webpage"},
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "metaso_web_search",
|
||||
"arguments": {"q": query, "scope": scope or "webpage", "size": 5, "includeSummary": True},
|
||||
},
|
||||
}
|
||||
resp = await client.post(base_url, json=payload, headers=headers)
|
||||
data = resp.json()
|
||||
if "error" in data:
|
||||
return json.dumps({"error": data["error"]}, ensure_ascii=False)
|
||||
results = data.get("result", {}).get("results", [])[:5]
|
||||
content_text = data.get("result", {}).get("content", [{}])[0].get("text", "")
|
||||
result_data = json.loads(content_text) if content_text else {}
|
||||
webpages = result_data.get("webpages", [])[:5]
|
||||
output = []
|
||||
for r in results:
|
||||
output.append(f"**{r.get('title', 'No title')}**\n{r.get('snippet', '')}\n{r.get('url', '')}")
|
||||
return json.dumps({"results": "\n\n".join(output)[:max_chars]}, ensure_ascii=False)
|
||||
for r in webpages:
|
||||
date = r.get("date", "")
|
||||
title = r.get("title", "No title")
|
||||
snippet = r.get("snippet", "")[:300]
|
||||
link = r.get("link", "")
|
||||
output.append(f"[{date}] **{title}**\n{snippet}\n{link}")
|
||||
total = result_data.get("total", 0)
|
||||
return json.dumps({
|
||||
"total": total,
|
||||
"results": "\n\n".join(output)[:max_chars],
|
||||
}, ensure_ascii=False)
|
||||
|
||||
elif action == "fetch":
|
||||
if not url:
|
||||
@ -572,15 +585,18 @@ class WebTool(BaseTool):
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "metaso_web_reader",
|
||||
"params": {"url": url, "format": "markdown"},
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "metaso_web_reader",
|
||||
"arguments": {"url": url, "format": "markdown"},
|
||||
},
|
||||
}
|
||||
resp = await client.post(base_url, json=payload, headers=headers)
|
||||
data = resp.json()
|
||||
if "error" in data:
|
||||
return json.dumps({"error": data["error"]}, ensure_ascii=False)
|
||||
content = data.get("result", {}).get("content", "")
|
||||
return json.dumps({"content": content[:max_chars]}, ensure_ascii=False)
|
||||
content_text = data.get("result", {}).get("content", [{}])[0].get("text", "")
|
||||
return json.dumps({"content": content_text[:max_chars]}, ensure_ascii=False)
|
||||
|
||||
elif action == "ask":
|
||||
if not query:
|
||||
@ -588,15 +604,18 @@ class WebTool(BaseTool):
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "metaso_chat",
|
||||
"params": {"query": query},
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "metaso_chat",
|
||||
"arguments": {"message": query},
|
||||
},
|
||||
}
|
||||
resp = await client.post(base_url, json=payload, headers=headers)
|
||||
data = resp.json()
|
||||
if "error" in data:
|
||||
return json.dumps({"error": data["error"]}, ensure_ascii=False)
|
||||
answer = data.get("result", {}).get("answer", "")
|
||||
return json.dumps({"answer": answer[:max_chars]}, ensure_ascii=False)
|
||||
content_text = data.get("result", {}).get("content", [{}])[0].get("text", "")
|
||||
return json.dumps({"answer": content_text[:max_chars]}, ensure_ascii=False)
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown action: {action}"}, ensure_ascii=False)
|
||||
|
||||
@ -53,11 +53,11 @@ async def ws_node_endpoint(websocket: WebSocket) -> None:
|
||||
"""Send periodic pings to the host client."""
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(30)
|
||||
try:
|
||||
await websocket.send_text(encode(Heartbeat(type="ping")))
|
||||
except Exception:
|
||||
break
|
||||
await asyncio.sleep(30)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@ -41,6 +41,8 @@ async def run_standalone() -> None:
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
log_level="info",
|
||||
ws_ping_interval=20,
|
||||
ws_ping_timeout=60,
|
||||
)
|
||||
server = uvicorn.Server(config_obj)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user