feat: 优化WebSocket连接和心跳机制
- 在main.py和standalone.py中添加ws_ping_interval和ws_ping_timeout配置
- 调整ws.py中的心跳发送逻辑,先发送ping再等待
- 在host_client中优化消息处理,将转发请求作为独立的asyncio任务派发,避免阻塞心跳处理
- 更新WebTool以适配新的API格式并增加搜索结果限制
- 在agent.py中添加日期显示和web调用次数限制
- 修复bot/handler.py中的事件循环问题
This commit is contained in:
parent
a3622ce26d
commit
80e4953cf9
@ -184,6 +184,14 @@ def start_websocket_client(loop: asyncio.AbstractEventLoop) -> None:
|
|||||||
backoff = 1.0
|
backoff = 1.0
|
||||||
max_backoff = 60.0
|
max_backoff = 60.0
|
||||||
|
|
||||||
|
# lark_oapi.ws.client captures the event loop at module import time.
|
||||||
|
# In standalone mode uvicorn already owns the main loop, so we create
|
||||||
|
# a fresh loop for this thread and redirect the lark module to use it.
|
||||||
|
thread_loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(thread_loop)
|
||||||
|
import lark_oapi.ws.client as _lark_ws_client
|
||||||
|
_lark_ws_client.loop = thread_loop
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
_ws_connected = False
|
_ws_connected = False
|
||||||
|
|||||||
@ -43,6 +43,7 @@ class NodeClient:
|
|||||||
self._running = False
|
self._running = False
|
||||||
self._last_heartbeat = time.time()
|
self._last_heartbeat = time.time()
|
||||||
self._reconnect_delay = 1.0
|
self._reconnect_delay = 1.0
|
||||||
|
self._forward_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
async def connect(self) -> bool:
|
async def connect(self) -> bool:
|
||||||
"""Connect to the router WebSocket."""
|
"""Connect to the router WebSocket."""
|
||||||
@ -53,9 +54,9 @@ class NodeClient:
|
|||||||
try:
|
try:
|
||||||
self.ws = await websockets.connect(
|
self.ws = await websockets.connect(
|
||||||
self.config.router_url,
|
self.config.router_url,
|
||||||
extra_headers=headers,
|
additional_headers=headers,
|
||||||
ping_interval=30,
|
ping_interval=20,
|
||||||
ping_timeout=10,
|
ping_timeout=60,
|
||||||
)
|
)
|
||||||
logger.info("Connected to router: %s", self.config.router_url)
|
logger.info("Connected to router: %s", self.config.router_url)
|
||||||
self._reconnect_delay = 1.0
|
self._reconnect_delay = 1.0
|
||||||
@ -145,17 +146,9 @@ class NodeClient:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to send status: %s", e)
|
logger.error("Failed to send status: %s", e)
|
||||||
|
|
||||||
async def handle_message(self, data: str) -> None:
|
async def handle_message_decoded(self, msg: Any) -> None:
|
||||||
"""Handle an incoming message from the router."""
|
"""Handle an already-decoded message from the router."""
|
||||||
try:
|
if isinstance(msg, Heartbeat):
|
||||||
msg = decode(data)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to decode message: %s", e)
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(msg, ForwardRequest):
|
|
||||||
await self.handle_forward(msg)
|
|
||||||
elif isinstance(msg, Heartbeat):
|
|
||||||
if msg.type == "ping":
|
if msg.type == "ping":
|
||||||
if self.ws:
|
if self.ws:
|
||||||
try:
|
try:
|
||||||
@ -165,7 +158,7 @@ class NodeClient:
|
|||||||
elif msg.type == "pong":
|
elif msg.type == "pong":
|
||||||
self._last_heartbeat = time.time()
|
self._last_heartbeat = time.time()
|
||||||
else:
|
else:
|
||||||
logger.debug("Received message type: %s", msg.type)
|
logger.debug("Received message type: %s", type(msg).__name__)
|
||||||
|
|
||||||
async def receive_loop(self) -> None:
|
async def receive_loop(self) -> None:
|
||||||
"""Main receive loop for incoming messages."""
|
"""Main receive loop for incoming messages."""
|
||||||
@ -174,7 +167,20 @@ class NodeClient:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async for data in self.ws:
|
async for data in self.ws:
|
||||||
await self.handle_message(data)
|
try:
|
||||||
|
msg = decode(data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to decode message: %s", e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(msg, ForwardRequest):
|
||||||
|
# Dispatch as a task so pings are handled without waiting
|
||||||
|
# for the full agent run to complete.
|
||||||
|
task = asyncio.create_task(self.handle_forward(msg))
|
||||||
|
self._forward_tasks.add(task)
|
||||||
|
task.add_done_callback(self._forward_tasks.discard)
|
||||||
|
else:
|
||||||
|
await self.handle_message_decoded(msg)
|
||||||
except websockets.ConnectionClosed as e:
|
except websockets.ConnectionClosed as e:
|
||||||
logger.warning("Connection closed: %s", e)
|
logger.warning("Connection closed: %s", e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -243,6 +249,10 @@ class NodeClient:
|
|||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop the client."""
|
"""Stop the client."""
|
||||||
self._running = False
|
self._running = False
|
||||||
|
for task in list(self._forward_tasks):
|
||||||
|
task.cancel()
|
||||||
|
if self._forward_tasks:
|
||||||
|
await asyncio.gather(*self._forward_tasks, return_exceptions=True)
|
||||||
if self.ws:
|
if self.ws:
|
||||||
await self.ws.close()
|
await self.ws.close()
|
||||||
await manager.stop()
|
await manager.stop()
|
||||||
|
|||||||
2
main.py
2
main.py
@ -111,4 +111,6 @@ if __name__ == "__main__":
|
|||||||
port=8000,
|
port=8000,
|
||||||
reload=False,
|
reload=False,
|
||||||
log_level="info",
|
log_level="info",
|
||||||
|
ws_ping_interval=20,
|
||||||
|
ws_ping_timeout=60,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -30,6 +30,8 @@ logger = logging.getLogger(__name__)
|
|||||||
SYSTEM_PROMPT_TEMPLATE = """You are PhoneWork, an AI assistant that helps users control Claude Code \
|
SYSTEM_PROMPT_TEMPLATE = """You are PhoneWork, an AI assistant that helps users control Claude Code \
|
||||||
from their phone via Feishu (飞书).
|
from their phone via Feishu (飞书).
|
||||||
|
|
||||||
|
Today's date: {today}
|
||||||
|
|
||||||
You manage Claude Code sessions. Each session has a conv_id and runs in a project directory.
|
You manage Claude Code sessions. Each session has a conv_id and runs in a project directory.
|
||||||
|
|
||||||
Base working directory: {working_dir}
|
Base working directory: {working_dir}
|
||||||
@ -46,12 +48,16 @@ Your responsibilities:
|
|||||||
4. Close session: call `close_conversation`.
|
4. Close session: call `close_conversation`.
|
||||||
5. GENERAL QUESTIONS: If the user asks a general question (not about a specific project or file), \
|
5. GENERAL QUESTIONS: If the user asks a general question (not about a specific project or file), \
|
||||||
answer directly using your own knowledge. Do NOT create a session for simple Q&A.
|
answer directly using your own knowledge. Do NOT create a session for simple Q&A.
|
||||||
|
6. WEB / SEARCH: Use the `web` tool when the user needs current information. \
|
||||||
|
Call it ONCE (or at most twice with a refined query). Then synthesize and reply — \
|
||||||
|
do NOT keep searching in a loop. If the first search returns results, use them.
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
- Relay Claude Code's output verbatim.
|
- Relay Claude Code's output verbatim.
|
||||||
- If no active session and the user sends a task without naming a directory, ask them which project.
|
- If no active session and the user sends a task without naming a directory, ask them which project.
|
||||||
- For general knowledge questions (e.g., "what is a Python generator?", "explain async/await"), \
|
- For general knowledge questions (e.g., "what is a Python generator?", "explain async/await"), \
|
||||||
answer directly without creating a session.
|
answer directly without creating a session.
|
||||||
|
- After using any tool, always produce a final text reply to the user. Never end a turn on a tool call.
|
||||||
- Keep your own words brief — let Claude Code's output speak.
|
- Keep your own words brief — let Claude Code's output speak.
|
||||||
- Reply in the same language the user uses (Chinese or English).
|
- Reply in the same language the user uses (Chinese or English).
|
||||||
"""
|
"""
|
||||||
@ -111,6 +117,8 @@ class OrchestrationAgent:
|
|||||||
self._passthrough: dict[str, bool] = defaultdict(lambda: False)
|
self._passthrough: dict[str, bool] = defaultdict(lambda: False)
|
||||||
|
|
||||||
def _build_system_prompt(self, user_id: str) -> str:
|
def _build_system_prompt(self, user_id: str) -> str:
|
||||||
|
from datetime import date
|
||||||
|
today = date.today().strftime("%Y-%m-%d")
|
||||||
conv_id = self._active_conv[user_id]
|
conv_id = self._active_conv[user_id]
|
||||||
if conv_id:
|
if conv_id:
|
||||||
active_line = f"ACTIVE SESSION: conv_id={conv_id!r} ← use this for all follow-up messages"
|
active_line = f"ACTIVE SESSION: conv_id={conv_id!r} ← use this for all follow-up messages"
|
||||||
@ -119,6 +127,7 @@ class OrchestrationAgent:
|
|||||||
return SYSTEM_PROMPT_TEMPLATE.format(
|
return SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
working_dir=WORKING_DIR,
|
working_dir=WORKING_DIR,
|
||||||
active_session_line=active_line,
|
active_session_line=active_line,
|
||||||
|
today=today,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_active_conv(self, user_id: str) -> Optional[str]:
|
def get_active_conv(self, user_id: str) -> Optional[str]:
|
||||||
@ -181,6 +190,7 @@ class OrchestrationAgent:
|
|||||||
|
|
||||||
reply = ""
|
reply = ""
|
||||||
try:
|
try:
|
||||||
|
web_calls = 0
|
||||||
for iteration in range(MAX_ITERATIONS):
|
for iteration in range(MAX_ITERATIONS):
|
||||||
logger.debug(" LLM call #%d", iteration)
|
logger.debug(" LLM call #%d", iteration)
|
||||||
ai_msg: AIMessage = await self._llm_with_tools.ainvoke(messages)
|
ai_msg: AIMessage = await self._llm_with_tools.ainvoke(messages)
|
||||||
@ -201,6 +211,16 @@ class OrchestrationAgent:
|
|||||||
)
|
)
|
||||||
logger.info(" ⚙ %s(%s)", tool_name, args_summary)
|
logger.info(" ⚙ %s(%s)", tool_name, args_summary)
|
||||||
|
|
||||||
|
if tool_name == "web":
|
||||||
|
web_calls += 1
|
||||||
|
if web_calls > 2:
|
||||||
|
result = "Web search limit reached. Synthesize from results already obtained."
|
||||||
|
logger.warning(" web call limit exceeded, blocking")
|
||||||
|
messages.append(
|
||||||
|
ToolMessage(content=str(result), tool_call_id=tool_id)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
tool_obj = _TOOL_MAP.get(tool_name)
|
tool_obj = _TOOL_MAP.get(tool_name)
|
||||||
if tool_obj is None:
|
if tool_obj is None:
|
||||||
result = f"Unknown tool: {tool_name}"
|
result = f"Unknown tool: {tool_name}"
|
||||||
|
|||||||
@ -553,18 +553,31 @@ class WebTool(BaseTool):
|
|||||||
payload = {
|
payload = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": 1,
|
"id": 1,
|
||||||
"method": "metaso_web_search",
|
"method": "tools/call",
|
||||||
"params": {"query": query, "scope": scope or "webpage"},
|
"params": {
|
||||||
|
"name": "metaso_web_search",
|
||||||
|
"arguments": {"q": query, "scope": scope or "webpage", "size": 5, "includeSummary": True},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
resp = await client.post(base_url, json=payload, headers=headers)
|
resp = await client.post(base_url, json=payload, headers=headers)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
if "error" in data:
|
if "error" in data:
|
||||||
return json.dumps({"error": data["error"]}, ensure_ascii=False)
|
return json.dumps({"error": data["error"]}, ensure_ascii=False)
|
||||||
results = data.get("result", {}).get("results", [])[:5]
|
content_text = data.get("result", {}).get("content", [{}])[0].get("text", "")
|
||||||
|
result_data = json.loads(content_text) if content_text else {}
|
||||||
|
webpages = result_data.get("webpages", [])[:5]
|
||||||
output = []
|
output = []
|
||||||
for r in results:
|
for r in webpages:
|
||||||
output.append(f"**{r.get('title', 'No title')}**\n{r.get('snippet', '')}\n{r.get('url', '')}")
|
date = r.get("date", "")
|
||||||
return json.dumps({"results": "\n\n".join(output)[:max_chars]}, ensure_ascii=False)
|
title = r.get("title", "No title")
|
||||||
|
snippet = r.get("snippet", "")[:300]
|
||||||
|
link = r.get("link", "")
|
||||||
|
output.append(f"[{date}] **{title}**\n{snippet}\n{link}")
|
||||||
|
total = result_data.get("total", 0)
|
||||||
|
return json.dumps({
|
||||||
|
"total": total,
|
||||||
|
"results": "\n\n".join(output)[:max_chars],
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
|
||||||
elif action == "fetch":
|
elif action == "fetch":
|
||||||
if not url:
|
if not url:
|
||||||
@ -572,15 +585,18 @@ class WebTool(BaseTool):
|
|||||||
payload = {
|
payload = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": 1,
|
"id": 1,
|
||||||
"method": "metaso_web_reader",
|
"method": "tools/call",
|
||||||
"params": {"url": url, "format": "markdown"},
|
"params": {
|
||||||
|
"name": "metaso_web_reader",
|
||||||
|
"arguments": {"url": url, "format": "markdown"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
resp = await client.post(base_url, json=payload, headers=headers)
|
resp = await client.post(base_url, json=payload, headers=headers)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
if "error" in data:
|
if "error" in data:
|
||||||
return json.dumps({"error": data["error"]}, ensure_ascii=False)
|
return json.dumps({"error": data["error"]}, ensure_ascii=False)
|
||||||
content = data.get("result", {}).get("content", "")
|
content_text = data.get("result", {}).get("content", [{}])[0].get("text", "")
|
||||||
return json.dumps({"content": content[:max_chars]}, ensure_ascii=False)
|
return json.dumps({"content": content_text[:max_chars]}, ensure_ascii=False)
|
||||||
|
|
||||||
elif action == "ask":
|
elif action == "ask":
|
||||||
if not query:
|
if not query:
|
||||||
@ -588,15 +604,18 @@ class WebTool(BaseTool):
|
|||||||
payload = {
|
payload = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": 1,
|
"id": 1,
|
||||||
"method": "metaso_chat",
|
"method": "tools/call",
|
||||||
"params": {"query": query},
|
"params": {
|
||||||
|
"name": "metaso_chat",
|
||||||
|
"arguments": {"message": query},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
resp = await client.post(base_url, json=payload, headers=headers)
|
resp = await client.post(base_url, json=payload, headers=headers)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
if "error" in data:
|
if "error" in data:
|
||||||
return json.dumps({"error": data["error"]}, ensure_ascii=False)
|
return json.dumps({"error": data["error"]}, ensure_ascii=False)
|
||||||
answer = data.get("result", {}).get("answer", "")
|
content_text = data.get("result", {}).get("content", [{}])[0].get("text", "")
|
||||||
return json.dumps({"answer": answer[:max_chars]}, ensure_ascii=False)
|
return json.dumps({"answer": content_text[:max_chars]}, ensure_ascii=False)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return json.dumps({"error": f"Unknown action: {action}"}, ensure_ascii=False)
|
return json.dumps({"error": f"Unknown action: {action}"}, ensure_ascii=False)
|
||||||
|
|||||||
@ -53,11 +53,11 @@ async def ws_node_endpoint(websocket: WebSocket) -> None:
|
|||||||
"""Send periodic pings to the host client."""
|
"""Send periodic pings to the host client."""
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(30)
|
|
||||||
try:
|
try:
|
||||||
await websocket.send_text(encode(Heartbeat(type="ping")))
|
await websocket.send_text(encode(Heartbeat(type="ping")))
|
||||||
except Exception:
|
except Exception:
|
||||||
break
|
break
|
||||||
|
await asyncio.sleep(30)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -41,6 +41,8 @@ async def run_standalone() -> None:
|
|||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
port=8000,
|
port=8000,
|
||||||
log_level="info",
|
log_level="info",
|
||||||
|
ws_ping_interval=20,
|
||||||
|
ws_ping_timeout=60,
|
||||||
)
|
)
|
||||||
server = uvicorn.Server(config_obj)
|
server = uvicorn.Server(config_obj)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user