"""Shared protocol module for Router <-> Host Client communication.

All message types are dataclasses that serialize to/from JSON.
Both router and host client import from this module.
"""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Any, Dict, List, Optional


@dataclass
class RegisterMessage:
    """Host client -> Router: Register this node."""

    type: str = "register"
    node_id: str = ""
    # User ids this node handles (assumed to match ForwardRequest.user_id — confirm).
    serves_users: List[str] = field(default_factory=list)
    working_dir: str = ""
    capabilities: List[str] = field(default_factory=list)
    display_name: str = ""


@dataclass
class ForwardRequest:
    """Router -> Host client: Forward a user message."""

    type: str = "forward"
    # Request id; presumably echoed back in ForwardResponse.id — confirm.
    id: str = ""
    user_id: str = ""
    chat_id: str = ""
    text: str = ""


@dataclass
class ForwardResponse:
    """Host client -> Router: Reply to a forwarded message."""

    type: str = "forward_response"
    id: str = ""
    reply: str = ""
    # Empty string when there is no error (the default).
    error: str = ""


@dataclass
class TaskComplete:
    """Host client -> Router: Background task finished."""

    type: str = "task_complete"
    task_id: str = ""
    user_id: str = ""
    chat_id: str = ""
    result: str = ""


@dataclass
class Heartbeat:
    """Bidirectional ping/pong.

    ``type`` defaults to ``"ping"``; decoding a ``"pong"`` message also yields
    a Heartbeat, with ``type`` set to ``"pong"``.
    """

    type: str = "ping"


@dataclass
class NodeStatus:
    """Host client -> Router: Periodic status update."""

    type: str = "node_status"
    node_id: str = ""
    sessions: int = 0
    # Free-form per-session dicts; the schema is not defined in this module.
    active_sessions: List[Dict[str, Any]] = field(default_factory=list)


# Wire "type" tag -> dataclass used by decode(). "ping" and "pong" share
# Heartbeat; the original tag is preserved in the decoded object's ``type``.
MESSAGE_TYPES = {
    "register": RegisterMessage,
    "forward": ForwardRequest,
    "forward_response": ForwardResponse,
    "task_complete": TaskComplete,
    "ping": Heartbeat,
    "pong": Heartbeat,
    "node_status": NodeStatus,
}


def encode(msg: Any) -> str:
    """Encode a message dataclass instance to a JSON string.

    Args:
        msg: An instance of one of the message dataclasses in this module
            (any dataclass instance with a ``type`` field is accepted).

    Returns:
        The JSON text of ``asdict(msg)``. Non-ASCII characters are written
        as-is rather than ``\\u``-escaped.

    Raises:
        ValueError: If *msg* is not a dataclass instance with a ``type``
            attribute.
    """
    # asdict() only accepts dataclass *instances*. A bare hasattr(msg, "type")
    # check let plain objects and dataclass classes (whose ``type`` default is
    # a class attribute) through, and those then failed with a TypeError.
    if is_dataclass(msg) and not isinstance(msg, type) and hasattr(msg, "type"):
        return json.dumps(asdict(msg), ensure_ascii=False)
    raise ValueError(f"Invalid message type: {type(msg)}")


def decode(data: str) -> Any:
    """Decode a JSON string to a message object.

    Keys in the payload that are not fields of the target dataclass are
    ignored. Missing fields take the dataclass defaults.

    Args:
        data: JSON text of a single message object.

    Returns:
        An instance of the dataclass registered in ``MESSAGE_TYPES`` for the
        payload's ``type`` value.

    Raises:
        ValueError: If *data* is not valid JSON (``json.JSONDecodeError`` is a
            ValueError subclass), is not a JSON object, or carries a missing
            or unknown ``type``.
    """
    obj = json.loads(data)
    # Valid JSON is not necessarily an object; without this check a list or
    # scalar payload would raise AttributeError on .get().
    if not isinstance(obj, dict):
        raise ValueError(f"Message must be a JSON object, got {type(obj).__name__}")
    msg_type = obj.get("type")
    # Only string tags can be registered. Rejecting non-strings up front also
    # avoids a TypeError when an unhashable value (e.g. a list) is looked up.
    cls = MESSAGE_TYPES.get(msg_type) if isinstance(msg_type, str) else None
    if cls is None:
        raise ValueError(f"Unknown message type: {msg_type}")
    known = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in obj.items() if k in known})