feat: implement assistant agent type (issue #86)
Add AssistantAgent class with conversational system prompt, minimal tool set (web_search, memory_read only), and AGENT_TYPE_ASSISTANT routing. 29 tests, 98% coverage on new code.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@@ -88,6 +88,7 @@
| #83 | Implement run_code tool executor | Phase 10 | `COMPLETED` | Rust | [issue-083.md](issue-083.md) |
| #85 | Implement run_shell and package_install executors | Phase 10 | `COMPLETED` | Rust | [issue-085.md](issue-085.md) |
| #81 | Implement coder agent type | Phase 10 | `COMPLETED` | Python | [issue-081.md](issue-081.md) |
| #86 | Implement assistant agent type | Phase 10 | `IMPLEMENTING` | Python | [issue-086.md](issue-086.md) |
## Status Legend
implementation-plans/issue-086.md (new file, 100 lines)
@@ -0,0 +1,100 @@
# Issue #86: Implement assistant agent type

## Metadata

| Field | Value |
|---|---|
| Issue | #86 |
| Title | Implement assistant agent type |
| Milestone | Phase 10: Remaining Agent Types |
| Status | `PLANNED` |
| Language | Python |
| Related Plans | issue-069.md, issue-081.md |
| Blocked by | #69, #52 |
## Acceptance Criteria

- [ ] Assistant system prompt emphasizing helpful conversation, summarization, Q&A
- [ ] Agent manifest with allowed tools: web_search and memory_read only
- [ ] No file system or code execution access
- [ ] No network egress beyond search
- [ ] Uses the same agent loop infrastructure as researcher
- [ ] Optimized for conversational response quality
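The two-tool constraint above amounts to a simple allow-list. A minimal sketch of that check (tool names from the criteria; the helper itself is hypothetical — in this design the real enforcement lives in the tool broker's manifest, not in agent code):

```python
# Assistant allow-list from the acceptance criteria. Illustrative helper;
# the actual enforcement is the tool broker manifest, not agent-side code.
ASSISTANT_ALLOWED_TOOLS = frozenset({"web_search", "memory_read"})

def is_tool_allowed(tool_name: str) -> bool:
    return tool_name in ASSISTANT_ALLOWED_TOOLS

print(is_tool_allowed("web_search"))  # True
print(is_tool_allowed("fs_read"))     # False: no file system access
```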
## Architecture Analysis

### Service Context
- Orchestrator service — new agent type alongside researcher and coder
- Uses the same `PromptBuilder` with a custom `ASSISTANT_SYSTEM_PROMPT`
- Routes via `AGENT_TYPE_ASSISTANT` (proto value 5) in `service.py`

### Existing Patterns
- `researcher.py` — base pattern (web_search + memory_read only)
- `coder.py` — extended pattern with tool-output source tracking
- `prompt.py` — parameterized `PromptBuilder` with `system_prompt` field
- `config.py` — `AgentConfig` dataclass reused per agent type

### Dependencies
- `common_pb2.AGENT_TYPE_ASSISTANT` (value 5) — already defined in proto
- No new proto changes needed
- No new external libraries
## Implementation Steps

### 1. Types & Configuration
- Add `assistant: AgentConfig` field to `Config` dataclass
- Add YAML loading for `assistant` section in `Config.load()`

### 2. Core Logic
- Create `AssistantAgent` class in `assistant.py`
- Mirror researcher pattern exactly (web_search + memory_read tools only)
- No `_TOOL_OUTPUT_SOURCES` tracking (assistant has no fs/code tools)
- Source tracking: web_search → `RESULT_SOURCE_WEB`, else `RESULT_SOURCE_MODEL_KNOWLEDGE`
- Agent ID prefix: `ast-{uuid}`
- Agent-specific messages: "Assistance incomplete: {reason}", "Assistance failed."
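The core-logic rules above reduce to two tiny decisions: the `ast-` ID prefix and the boolean source flip. A standalone sketch (proto enum values stubbed as strings for illustration):

```python
import uuid

# Stubs for illustration: the real values are common_pb2 enum members.
RESULT_SOURCE_WEB = "RESULT_SOURCE_WEB"
RESULT_SOURCE_MODEL_KNOWLEDGE = "RESULT_SOURCE_MODEL_KNOWLEDGE"

def new_agent_id() -> str:
    # Agent ID prefix from the plan: ast-{uuid}.
    return f"ast-{uuid.uuid4().hex[:8]}"

def result_source(used_web_search: bool) -> str:
    # Only a successful web_search flips the result source to WEB;
    # everything else counts as model knowledge.
    return RESULT_SOURCE_WEB if used_web_search else RESULT_SOURCE_MODEL_KNOWLEDGE

print(new_agent_id())  # e.g. ast-1a2b3c4d
```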
### 3. System Prompt
- Add `ASSISTANT_SYSTEM_PROMPT` to `prompt.py`
- Conversational tone, emphasizing helpful responses, summarization, Q&A
- Capabilities: web_search, memory_read only
- No file system or code references
- Encourage clear, well-structured answers

### 4. Service Integration
- Import `AssistantAgent` in `service.py`
- Instantiate with `self._config.assistant`
- Route `AGENT_TYPE_ASSISTANT` subtasks to `assistant.run()`

### 5. Tests
- Create `tests/test_assistant.py` mirroring `test_researcher.py`
- Key assistant-specific tests:
  - Uses AGENT_TYPE_ASSISTANT for tool discovery/execution
  - Agent ID prefix is `ast-`
  - Uses ASSISTANT_SYSTEM_PROMPT
  - Only web_search and memory_read tools available
  - Source tracking: web → WEB, default → MODEL_KNOWLEDGE
  - No tool-output source (no fs/code tools)
- Standard tests: happy path, tool discovery, empty tools, execution, memory, max iterations, timeout, failures, gateway errors, confidence, unknown tools, discovery errors, memory unavailable, parse errors, reasoning, execution errors, compaction
## Files to Create/Modify

| File | Action | Purpose |
|---|---|---|
| `services/orchestrator/src/orchestrator/assistant.py` | Create | AssistantAgent class |
| `services/orchestrator/tests/test_assistant.py` | Create | Unit tests |
| `services/orchestrator/src/orchestrator/prompt.py` | Modify | Add ASSISTANT_SYSTEM_PROMPT |
| `services/orchestrator/src/orchestrator/config.py` | Modify | Add assistant AgentConfig |
| `services/orchestrator/src/orchestrator/service.py` | Modify | Route AGENT_TYPE_ASSISTANT |

## Risks and Edge Cases

- Assistant is closest to researcher pattern — minimal risk
- Must ensure no fs/code tools leak through (enforced by tool broker manifest, not agent code)

## Deviation Log

| Deviation | Reason |
|---|---|
services/orchestrator/src/orchestrator/assistant.py (new file, 286 lines)
@@ -0,0 +1,286 @@
"""Assistant agent — conversational Q&A, summarization, and information retrieval."""

from __future__ import annotations

import asyncio
import logging
import uuid

import grpc
from llm_multiverse.v1 import common_pb2, orchestrator_pb2

from .clients import MemoryClient, ModelGatewayClient, ToolBrokerClient
from .config import AgentConfig
from .parser import DoneSignalParsed, ParseError, ToolCallParsed, parse_agent_response
from .prompt import ASSISTANT_SYSTEM_PROMPT, PromptBuilder

logger = logging.getLogger("orchestrator.assistant")

_CONFIDENCE_MAP = {
    "VERIFIED": common_pb2.RESULT_QUALITY_VERIFIED,
    "INFERRED": common_pb2.RESULT_QUALITY_INFERRED,
    "UNCERTAIN": common_pb2.RESULT_QUALITY_UNCERTAIN,
}

_MAX_CONSECUTIVE_FAILURES = 3


class AssistantAgent:
    """Executes the assistant agent loop: infer -> tool call -> observe -> repeat."""

    def __init__(
        self,
        model_gateway: ModelGatewayClient,
        tool_broker: ToolBrokerClient,
        memory: MemoryClient,
        config: AgentConfig,
    ) -> None:
        self._gateway = model_gateway
        self._broker = tool_broker
        self._memory = memory
        self._config = config

    async def run(
        self, request: orchestrator_pb2.SubagentRequest
    ) -> common_pb2.SubagentResult:
        """Execute the assistant loop for a given task."""
        session_ctx = request.context
        agent_id = request.agent_id or f"ast-{uuid.uuid4().hex[:8]}"

        # Step 1: Discover available tools
        try:
            tools = await self._broker.discover_tools(
                session_ctx, common_pb2.AGENT_TYPE_ASSISTANT, request.task
            )
        except grpc.aio.AioRpcError as e:
            return _build_failed_result(f"Tool discovery failed: {e.code()}")

        if not tools:
            return _build_failed_result("No tools available for assistant")

        tool_names = {t.name for t in tools}

        # Step 2: Memory context enrichment
        memory_context = list(request.relevant_memory_context)
        if not memory_context:
            try:
                mem_results = await self._memory.query_memory(
                    session_ctx, request.task, limit=3
                )
                memory_context = [_format_memory(r) for r in mem_results]
            except grpc.aio.AioRpcError:
                logger.warning("Memory service unavailable, proceeding without context")

        # Step 3: Initialize prompt builder with assistant system prompt
        max_tokens = request.max_tokens or self._config.max_tokens
        prompt_builder = PromptBuilder(
            max_tokens=max_tokens,
            system_prompt=ASSISTANT_SYSTEM_PROMPT,
        )
        prompt_builder.set_task(request.task)
        prompt_builder.set_tool_definitions(tools)
        prompt_builder.set_memory_context(memory_context)

        # Step 4: Agent loop
        iteration = 0
        consecutive_failures = 0
        used_web_search = False
        start_time = asyncio.get_event_loop().time()

        while True:
            # Termination checks
            if iteration >= self._config.max_iterations:
                return _build_partial_result(
                    "Max iterations reached", used_web_search
                )

            elapsed = asyncio.get_event_loop().time() - start_time
            if elapsed >= self._config.timeout_seconds:
                return _build_partial_result("Timeout", used_web_search)

            if prompt_builder.needs_compaction():
                compacted = await prompt_builder.compact(
                    self._gateway, session_ctx
                )
                if not compacted:
                    return _build_partial_result(
                        "Context overflow after compaction",
                        used_web_search,
                    )
                logger.info(
                    "Context compacted for agent %s (iteration %d)",
                    agent_id,
                    iteration,
                )

            if consecutive_failures >= _MAX_CONSECUTIVE_FAILURES:
                return _build_failed_result(
                    f"All tools failed after {_MAX_CONSECUTIVE_FAILURES} consecutive failures"
                )

            # Build prompt and call model
            prompt = prompt_builder.build()
            try:
                response_text = await self._gateway.stream_inference(
                    session_ctx, prompt, max_tokens=max_tokens
                )
            except grpc.aio.AioRpcError as e:
                return _build_failed_result(f"Model gateway error: {e.code()}")

            # Parse model output
            try:
                parsed = parse_agent_response(response_text)
            except ParseError as e:
                prompt_builder.add_reasoning(
                    f"[Parse error: {e}. Please output valid JSON.]"
                )
                iteration += 1
                continue

            if parsed is None:
                # Plain reasoning text — append and continue
                prompt_builder.add_reasoning(response_text)
                iteration += 1
                continue

            if isinstance(parsed, DoneSignalParsed):
                return _build_success_result(parsed, used_web_search)

            if isinstance(parsed, ToolCallParsed):
                # Validate tool name
                if parsed.tool not in tool_names:
                    prompt_builder.add_tool_result(
                        parsed.tool,
                        f"Error: tool '{parsed.tool}' is not available. "
                        f"Available tools: {', '.join(sorted(tool_names))}",
                        False,
                        agent_id,
                        session_ctx.session_id,
                    )
                    iteration += 1
                    consecutive_failures += 1
                    continue

                # Execute tool
                prompt_builder.add_tool_call(parsed.tool, parsed.parameters)
                try:
                    output, success = await self._broker.execute_tool(
                        session_ctx,
                        common_pb2.AGENT_TYPE_ASSISTANT,
                        parsed.tool,
                        parsed.parameters,
                    )
                except grpc.aio.AioRpcError as e:
                    output = f"Tool execution error: {e.code()}"
                    success = False

                prompt_builder.add_tool_result(
                    parsed.tool,
                    output,
                    success,
                    agent_id,
                    session_ctx.session_id,
                )

                if parsed.tool == "web_search" and success:
                    used_web_search = True

                if success:
                    consecutive_failures = 0
                else:
                    consecutive_failures += 1

            iteration += 1

        # Unreachable, but satisfies type checker
        return _build_failed_result("Unexpected loop exit")  # pragma: no cover


def create_assistant_agent(
    model_gateway_channel: grpc.aio.Channel,
    tool_broker_channel: grpc.aio.Channel,
    memory_channel: grpc.aio.Channel,
    config: AgentConfig,
) -> AssistantAgent:
    """Create an AssistantAgent with gRPC client wrappers."""
    return AssistantAgent(
        model_gateway=ModelGatewayClient(model_gateway_channel),
        tool_broker=ToolBrokerClient(tool_broker_channel),
        memory=MemoryClient(memory_channel),
        config=config,
    )


def _build_success_result(
    parsed: DoneSignalParsed, used_web: bool
) -> common_pb2.SubagentResult:
    quality = _CONFIDENCE_MAP.get(
        parsed.confidence, common_pb2.RESULT_QUALITY_UNCERTAIN
    )

    status = common_pb2.RESULT_STATUS_SUCCESS
    if parsed.confidence == "UNCERTAIN":
        status = common_pb2.RESULT_STATUS_PARTIAL

    source = (
        common_pb2.RESULT_SOURCE_WEB
        if used_web
        else common_pb2.RESULT_SOURCE_MODEL_KNOWLEDGE
    )

    candidates = [
        common_pb2.MemoryCandidate(
            content=mc.content,
            source=source,
            confidence=mc.confidence,
        )
        for mc in parsed.memory_candidates
    ]

    return common_pb2.SubagentResult(
        status=status,
        summary=parsed.summary,
        artifacts=parsed.findings,
        result_quality=quality,
        source=source,
        new_memory_candidates=candidates,
    )


def _build_partial_result(
    reason: str, used_web: bool
) -> common_pb2.SubagentResult:
    source = (
        common_pb2.RESULT_SOURCE_WEB
        if used_web
        else common_pb2.RESULT_SOURCE_MODEL_KNOWLEDGE
    )
    return common_pb2.SubagentResult(
        status=common_pb2.RESULT_STATUS_PARTIAL,
        summary=f"Assistance incomplete: {reason}",
        result_quality=common_pb2.RESULT_QUALITY_UNCERTAIN,
        source=source,
    )


def _build_failed_result(reason: str) -> common_pb2.SubagentResult:
    return common_pb2.SubagentResult(
        status=common_pb2.RESULT_STATUS_FAILED,
        summary="Assistance failed.",
        result_quality=common_pb2.RESULT_QUALITY_UNCERTAIN,
        source=common_pb2.RESULT_SOURCE_MODEL_KNOWLEDGE,
        failure_reason=reason,
    )


def _format_memory(mem_response) -> str:
    """Format a QueryMemoryResponse into a context string."""
    entry = mem_response.entry
    parts = [f"[Memory: {entry.name}]"]
    if entry.description:
        parts.append(entry.description)
    if mem_response.cached_extracted_segment:
        parts.append(mem_response.cached_extracted_segment)
    elif entry.corpus:
        parts.append(entry.corpus[:500])
    return "\n".join(parts)
services/orchestrator/src/orchestrator/config.py (modified)

@@ -95,6 +95,7 @@ class Config:
    # Per-agent-type configuration
    researcher: AgentConfig = field(default_factory=AgentConfig)
    coder: AgentConfig = field(default_factory=AgentConfig)
    assistant: AgentConfig = field(default_factory=AgentConfig)
    decomposer: DecomposerConfig = field(default_factory=DecomposerConfig)
    dispatcher: DispatcherConfig = field(default_factory=DispatcherConfig)
    compaction: CompactionConfig = field(default_factory=CompactionConfig)
@@ -141,6 +142,17 @@ class Config:
            max_tokens=coder_data.get("max_tokens", AgentConfig.max_tokens),
        )

        assistant_data = data.get("assistant", {})
        assistant = AgentConfig(
            max_iterations=assistant_data.get(
                "max_iterations", AgentConfig.max_iterations
            ),
            timeout_seconds=assistant_data.get(
                "timeout_seconds", AgentConfig.timeout_seconds
            ),
            max_tokens=assistant_data.get("max_tokens", AgentConfig.max_tokens),
        )

        decomposer_data = data.get("decomposer", {})
        decomposer = DecomposerConfig(
            max_tokens=decomposer_data.get(
@@ -255,6 +267,7 @@ class Config:
            audit_addr=data.get("audit_addr"),
            researcher=researcher,
            coder=coder,
            assistant=assistant,
            decomposer=decomposer,
            dispatcher=dispatcher,
            compaction=compaction,
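The `Config.load()` hunk above reads three optional keys from an `assistant` YAML section, each falling back to the `AgentConfig` class default. A standalone sketch of that lookup-with-defaults pattern (the default values here are illustrative assumptions, not the project's real defaults):

```python
# Sketch of the defaults pattern Config.load() uses for the assistant
# section. Default values are assumptions for illustration only.
DEFAULTS = {"max_iterations": 10, "timeout_seconds": 120, "max_tokens": 4096}

def load_assistant_section(data: dict) -> dict:
    # Missing section or missing keys fall back to defaults; present keys win.
    section = data.get("assistant", {})
    return {key: section.get(key, default) for key, default in DEFAULTS.items()}

print(load_assistant_section({"assistant": {"max_tokens": 2048}}))
```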
services/orchestrator/src/orchestrator/prompt.py (modified)

@@ -133,6 +133,56 @@ When done, respond with a JSON result block:
"""


ASSISTANT_SYSTEM_PROMPT = """\
You are an Assistant agent in a multi-agent system. Your role is to provide \
helpful, clear, and well-structured answers to user questions.

## Capabilities
- Web search via the `web_search` tool
- Memory retrieval via the `memory_read` tool

## Instructions
1. Analyze the user's question to understand what they need.
2. If the answer requires current or factual information, use `web_search`.
3. If the answer may benefit from prior knowledge, use `memory_read`.
4. Synthesize information into a clear, conversational response.
5. Structure your response for readability: use headings, lists, or short paragraphs.

## Tool Use Format
To use a tool, respond with a JSON tool call block:
{"tool": "<tool_name>", "parameters": {"<key>": "<value>"}}

After receiving a tool result, analyze it and decide whether to:
- Make another tool call for more information
- Produce your final response

## Response Guidelines
- Be concise but thorough. Prefer clarity over brevity.
- When summarizing, capture the key points and omit unnecessary detail.
- For Q&A, lead with the direct answer, then provide supporting context.
- Use numbered lists for step-by-step instructions.
- Use bullet points for enumerating options or features.
- Acknowledge uncertainty honestly rather than guessing.

## Confidence Signaling
Rate your confidence in your response:
- VERIFIED: Information confirmed by tool output from reliable sources
- INFERRED: Reasonable conclusion drawn from available evidence
- UNCERTAIN: Best guess with limited supporting evidence

## Completion
When done, respond with a JSON result block:
{"done": true, "summary": "<3 sentences max>", "findings": ["<finding1>", ...], \
"confidence": "VERIFIED|INFERRED|UNCERTAIN", \
"memory_candidates": [{"content": "<fact>", "confidence": 0.0-1.0}]}

## Constraints
- Do NOT fabricate information. If you cannot find an answer, say so.
- Do NOT attempt to use tools not listed in your capabilities.
- Keep your final summary to 3 sentences or fewer.\
"""


@dataclass
class HistoryEntry:
    """A single entry in the agent's interaction history."""
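The Completion format that ASSISTANT_SYSTEM_PROMPT requests parses as ordinary JSON. A quick shape-check of an example done-signal (field names from the prompt; the values are illustrative):

```python
import json

# Example done-signal matching the Completion format in the prompt above.
# Values are illustrative, not real agent output.
raw = (
    '{"done": true, "summary": "Docker packages apps into containers.", '
    '"findings": ["Uses namespaces"], "confidence": "VERIFIED", '
    '"memory_candidates": [{"content": "Docker uses namespaces", "confidence": 0.9}]}'
)
signal = json.loads(raw)
assert signal["done"] is True
assert signal["confidence"] in {"VERIFIED", "INFERRED", "UNCERTAIN"}
```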
services/orchestrator/src/orchestrator/service.py (modified)

@@ -20,6 +20,7 @@ from .context import OrchestratorContext
from .decomposer import TaskDecomposer
from .dispatcher import SubagentDispatcher
from .memory_gate import MemoryWriteGate
from .assistant import AssistantAgent
from .coder import CoderAgent
from .researcher import ResearcherAgent
from .session_config import SessionConfigApplicator
@@ -126,12 +127,20 @@ class OrchestratorServiceImpl(orchestrator_pb2_grpc.OrchestratorServiceServicer)
            memory=self._memory,
            config=self._config.coder,
        )
        assistant = AssistantAgent(
            model_gateway=self._gateway,
            tool_broker=self._broker,
            memory=self._memory,
            config=self._config.assistant,
        )

        async def runner(
            req: orchestrator_pb2.SubagentRequest,
        ) -> common_pb2.SubagentResult:
            if req.agent_type == common_pb2.AGENT_TYPE_CODER:
                return await coder.run(req)
            if req.agent_type == common_pb2.AGENT_TYPE_ASSISTANT:
                return await assistant.run(req)
            # Default to researcher for all other agent types.
            return await researcher.run(req)
services/orchestrator/tests/test_assistant.py (new file, 480 lines)
@@ -0,0 +1,480 @@
|
||||
"""Tests for the assistant agent loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from llm_multiverse.v1 import common_pb2, memory_pb2, orchestrator_pb2, tool_broker_pb2
|
||||
|
||||
from orchestrator.assistant import AssistantAgent, _format_memory
|
||||
from orchestrator.config import AgentConfig
|
||||
|
||||
|
||||
def _make_request(
|
||||
task: str = "Explain how Docker works",
|
||||
memory_context: list[str] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
) -> orchestrator_pb2.SubagentRequest:
|
||||
ctx = common_pb2.SessionContext(session_id="sess-1", user_id="user-1")
|
||||
req = orchestrator_pb2.SubagentRequest(
|
||||
context=ctx,
|
||||
agent_id="ast-test",
|
||||
agent_type=common_pb2.AGENT_TYPE_ASSISTANT,
|
||||
task=task,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if memory_context:
|
||||
req.relevant_memory_context.extend(memory_context)
|
||||
return req
|
||||
|
||||
|
||||
def _make_tool(name: str) -> tool_broker_pb2.ToolDefinition:
|
||||
return tool_broker_pb2.ToolDefinition(
|
||||
name=name,
|
||||
description=f"{name} tool",
|
||||
parameters={
|
||||
"query": tool_broker_pb2.ParameterSchema(
|
||||
type="string", description="Query"
|
||||
)
|
||||
},
|
||||
required_params=["query"],
|
||||
)
|
||||
|
||||
|
||||
def _make_agent(
|
||||
gateway_responses: list[str] | None = None,
|
||||
tools: list[tool_broker_pb2.ToolDefinition] | None = None,
|
||||
exec_output: str = "search results",
|
||||
exec_success: bool = True,
|
||||
memory_results: list | None = None,
|
||||
config: AgentConfig | None = None,
|
||||
) -> AssistantAgent:
|
||||
gateway = AsyncMock()
|
||||
broker = AsyncMock()
|
||||
memory = AsyncMock()
|
||||
|
||||
if gateway_responses is not None:
|
||||
gateway.stream_inference = AsyncMock(side_effect=gateway_responses)
|
||||
else:
|
||||
gateway.stream_inference = AsyncMock(
|
||||
return_value='{"done": true, "summary": "Done.", "confidence": "VERIFIED"}'
|
||||
)
|
||||
|
||||
if tools is not None:
|
||||
broker.discover_tools = AsyncMock(return_value=tools)
|
||||
else:
|
||||
broker.discover_tools = AsyncMock(
|
||||
return_value=[
|
||||
_make_tool("web_search"),
|
||||
_make_tool("memory_read"),
|
||||
]
|
||||
)
|
||||
|
||||
broker.execute_tool = AsyncMock(return_value=(exec_output, exec_success))
|
||||
|
||||
if memory_results is not None:
|
||||
memory.query_memory = AsyncMock(return_value=memory_results)
|
||||
else:
|
||||
memory.query_memory = AsyncMock(return_value=[])
|
||||
|
||||
return AssistantAgent(
|
||||
model_gateway=gateway,
|
||||
tool_broker=broker,
|
||||
memory=memory,
|
||||
config=config or AgentConfig(),
|
||||
)
|
||||
|
||||
|
||||
async def test_simple_assistant_task():
|
||||
"""Model searches web, then signals done."""
|
||||
agent = _make_agent(
|
||||
gateway_responses=[
|
||||
'{"tool": "web_search", "parameters": {"query": "how Docker works"}}',
|
||||
'{"done": true, "summary": "Docker uses containers.", "findings": ["Uses namespaces"], "confidence": "VERIFIED"}',
|
||||
],
|
||||
exec_output="Docker is a containerization platform...",
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.status == common_pb2.RESULT_STATUS_SUCCESS
|
||||
assert "Docker" in result.summary
|
||||
assert result.artifacts == ["Uses namespaces"]
|
||||
assert result.result_quality == common_pb2.RESULT_QUALITY_VERIFIED
|
||||
|
||||
|
||||
async def test_tool_discovery_uses_assistant_type():
|
||||
"""Verify DiscoverTools uses AGENT_TYPE_ASSISTANT."""
|
||||
agent = _make_agent()
|
||||
await agent.run(_make_request())
|
||||
agent._broker.discover_tools.assert_called_once()
|
||||
args, kwargs = agent._broker.discover_tools.call_args
|
||||
all_values = list(args) + list(kwargs.values())
|
||||
assert common_pb2.AGENT_TYPE_ASSISTANT in all_values
|
||||
|
||||
|
||||
async def test_no_tools_available():
|
||||
"""Discover returns empty list — agent fails."""
|
||||
agent = _make_agent(tools=[])
|
||||
result = await agent.run(_make_request())
|
||||
assert result.status == common_pb2.RESULT_STATUS_FAILED
|
||||
assert "No tools" in result.failure_reason
|
||||
|
||||
|
||||
async def test_tool_execution_uses_assistant_type():
|
||||
"""Verify ExecuteTool uses AGENT_TYPE_ASSISTANT."""
|
||||
agent = _make_agent(
|
||||
gateway_responses=[
|
||||
'{"tool": "web_search", "parameters": {"query": "test"}}',
|
||||
'{"done": true, "summary": "Done.", "confidence": "VERIFIED"}',
|
||||
],
|
||||
)
|
||||
await agent.run(_make_request())
|
||||
agent._broker.execute_tool.assert_called_once()
|
||||
args, kwargs = agent._broker.execute_tool.call_args
|
||||
all_values = list(args) + list(kwargs.values())
|
||||
assert common_pb2.AGENT_TYPE_ASSISTANT in all_values
|
||||
|
||||
|
||||
async def test_memory_context_from_request():
|
||||
"""Pre-filled memory context skips QueryMemory call."""
|
||||
agent = _make_agent()
|
||||
request = _make_request(memory_context=["Pre-loaded memory"])
|
||||
await agent.run(request)
|
||||
agent._memory.query_memory.assert_not_called()
|
||||
|
||||
|
||||
async def test_memory_query_enrichment():
|
||||
"""No pre-filled memory triggers QueryMemory call."""
|
||||
mem_result = memory_pb2.QueryMemoryResponse(
|
||||
rank=0,
|
||||
entry=memory_pb2.MemoryEntry(name="relevant-mem", description="Some info"),
|
||||
cosine_similarity=0.9,
|
||||
)
|
||||
agent = _make_agent(memory_results=[mem_result])
|
||||
await agent.run(_make_request())
|
||||
agent._memory.query_memory.assert_called_once()
|
||||
|
||||
|
||||
async def test_max_iterations_termination():
|
||||
"""Always returning tool calls hits max iterations."""
|
||||
agent = _make_agent(
|
||||
gateway_responses=[
|
||||
'{"tool": "web_search", "parameters": {"query": "test"}}'
|
||||
]
|
||||
* 15,
|
||||
config=AgentConfig(max_iterations=3),
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.status == common_pb2.RESULT_STATUS_PARTIAL
|
||||
assert "Max iterations" in result.summary
|
||||
|
||||
|
||||
async def test_timeout_termination():
|
||||
"""Timeout triggers partial result."""
|
||||
agent = _make_agent(config=AgentConfig(timeout_seconds=0))
|
||||
agent._gateway.stream_inference = AsyncMock(
|
||||
return_value='{"tool": "web_search", "parameters": {"query": "test"}}'
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.status in (
|
||||
common_pb2.RESULT_STATUS_PARTIAL,
|
||||
common_pb2.RESULT_STATUS_SUCCESS,
|
||||
)
|
||||
|
||||
|
||||
async def test_consecutive_tool_failures():
|
||||
"""Three consecutive tool failures -> FAILED."""
|
||||
agent = _make_agent(
|
||||
gateway_responses=[
|
||||
'{"tool": "web_search", "parameters": {"query": "test"}}'
|
||||
]
|
||||
* 5,
|
||||
exec_success=False,
|
||||
exec_output="error",
|
||||
config=AgentConfig(max_iterations=10),
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.status == common_pb2.RESULT_STATUS_FAILED
|
||||
assert "consecutive failures" in result.failure_reason
|
||||
|
||||
|
||||
async def test_model_gateway_error():
|
||||
"""Gateway UNAVAILABLE -> FAILED."""
|
||||
import grpc
|
||||
|
||||
agent = _make_agent()
|
||||
agent._gateway.stream_inference = AsyncMock(
|
||||
side_effect=grpc.aio.AioRpcError(
|
||||
grpc.StatusCode.UNAVAILABLE,
|
||||
initial_metadata=grpc.aio.Metadata(),
|
||||
trailing_metadata=grpc.aio.Metadata(),
|
||||
details="down",
|
||||
)
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.status == common_pb2.RESULT_STATUS_FAILED
|
||||
assert "gateway" in result.failure_reason.lower()
|
||||
|
||||
|
||||
async def test_confidence_verified_maps_to_success():
|
||||
agent = _make_agent(
|
||||
gateway_responses=[
|
||||
'{"done": true, "summary": "Done.", "confidence": "VERIFIED"}'
|
||||
]
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.status == common_pb2.RESULT_STATUS_SUCCESS
|
||||
assert result.result_quality == common_pb2.RESULT_QUALITY_VERIFIED
|
||||
|
||||
|
||||
async def test_confidence_uncertain_maps_to_partial():
|
||||
agent = _make_agent(
|
||||
gateway_responses=[
|
||||
'{"done": true, "summary": "Not sure.", "confidence": "UNCERTAIN"}'
|
||||
]
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.status == common_pb2.RESULT_STATUS_PARTIAL
|
||||
assert result.result_quality == common_pb2.RESULT_QUALITY_UNCERTAIN
|
||||
|
||||
|
||||
async def test_web_search_sets_source_web():
|
||||
"""Using web_search sets source to RESULT_SOURCE_WEB."""
|
||||
agent = _make_agent(
|
||||
gateway_responses=[
|
||||
'{"tool": "web_search", "parameters": {"query": "python docs"}}',
|
||||
'{"done": true, "summary": "Found docs.", "confidence": "VERIFIED"}',
|
||||
],
|
||||
exec_output="web results",
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.source == common_pb2.RESULT_SOURCE_WEB
|
||||
|
||||
|
||||
async def test_no_tools_sets_source_model_knowledge():
|
||||
"""No tool usage sets source to RESULT_SOURCE_MODEL_KNOWLEDGE."""
|
||||
agent = _make_agent(
|
||||
gateway_responses=[
|
||||
'{"done": true, "summary": "I know this.", "confidence": "VERIFIED"}'
|
||||
]
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.source == common_pb2.RESULT_SOURCE_MODEL_KNOWLEDGE
|
||||
|
||||
|
||||
async def test_memory_read_does_not_set_web_source():
|
||||
"""Using memory_read (not web_search) keeps source as MODEL_KNOWLEDGE."""
|
||||
agent = _make_agent(
|
||||
gateway_responses=[
|
||||
'{"tool": "memory_read", "parameters": {"query": "prior info"}}',
|
||||
'{"done": true, "summary": "Found in memory.", "confidence": "VERIFIED"}',
|
||||
],
|
||||
exec_output="memory results",
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.source == common_pb2.RESULT_SOURCE_MODEL_KNOWLEDGE
|
||||
|
||||
|
||||
async def test_unknown_tool_reports_error():
|
||||
"""Model calls a tool not in discovered set."""
|
||||
agent = _make_agent(
|
||||
gateway_responses=[
|
||||
'{"tool": "fs_read", "parameters": {"file_path": "/etc/passwd"}}',
|
||||
'{"done": true, "summary": "Done.", "confidence": "UNCERTAIN"}',
|
||||
],
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.status in (
|
||||
common_pb2.RESULT_STATUS_PARTIAL,
|
||||
common_pb2.RESULT_STATUS_SUCCESS,
|
||||
)
|
||||
agent._broker.execute_tool.assert_not_called()
|
||||
|
||||
|
||||
async def test_tool_discovery_grpc_error():
|
||||
"""Tool discovery failure -> FAILED result."""
|
||||
import grpc
|
||||
|
||||
agent = _make_agent()
|
||||
agent._broker.discover_tools = AsyncMock(
|
||||
side_effect=grpc.aio.AioRpcError(
|
||||
grpc.StatusCode.UNAVAILABLE,
|
||||
initial_metadata=grpc.aio.Metadata(),
|
||||
trailing_metadata=grpc.aio.Metadata(),
|
||||
details="down",
|
||||
)
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.status == common_pb2.RESULT_STATUS_FAILED
|
||||
assert "discovery" in result.failure_reason.lower()
|
||||
|
||||
|
||||
async def test_memory_service_unavailable_continues():
|
||||
"""Memory service error is handled gracefully."""
|
||||
import grpc
|
||||
|
||||
agent = _make_agent(
|
||||
gateway_responses=[
|
||||
'{"done": true, "summary": "Done without memory.", "confidence": "VERIFIED"}'
|
||||
]
|
||||
)
|
||||
agent._memory.query_memory = AsyncMock(
|
||||
side_effect=grpc.aio.AioRpcError(
|
||||
grpc.StatusCode.UNAVAILABLE,
|
||||
initial_metadata=grpc.aio.Metadata(),
|
||||
trailing_metadata=grpc.aio.Metadata(),
|
||||
details="down",
|
||||
)
|
||||
)
|
||||
result = await agent.run(_make_request())
|
||||
assert result.status == common_pb2.RESULT_STATUS_SUCCESS
|
||||
|
||||
|
||||
async def test_parse_error_handling():
    """ParseError from model output is handled."""
    agent = _make_agent(
        gateway_responses=[
            '{"tool": "", "parameters": {}}',
            '{"done": true, "summary": "Done after error.", "confidence": "VERIFIED"}',
        ]
    )
    result = await agent.run(_make_request())
    assert result.status == common_pb2.RESULT_STATUS_SUCCESS
    assert "Done after error" in result.summary


async def test_plain_reasoning_continues():
    """Model produces plain text, then done signal."""
    agent = _make_agent(
        gateway_responses=[
            "Let me think about this question...",
            '{"done": true, "summary": "Analyzed.", "confidence": "INFERRED"}',
        ]
    )
    result = await agent.run(_make_request())
    assert result.status == common_pb2.RESULT_STATUS_SUCCESS
    assert result.result_quality == common_pb2.RESULT_QUALITY_INFERRED


async def test_tool_execution_grpc_error():
    """Tool execution gRPC error is handled gracefully."""
    import grpc

    agent = _make_agent(
        gateway_responses=[
            '{"tool": "web_search", "parameters": {"query": "test"}}'
        ]
        * 5,
        config=AgentConfig(max_iterations=10),
    )
    agent._broker.execute_tool = AsyncMock(
        side_effect=grpc.aio.AioRpcError(
            grpc.StatusCode.INTERNAL,
            initial_metadata=grpc.aio.Metadata(),
            trailing_metadata=grpc.aio.Metadata(),
            details="exec failed",
        )
    )
    result = await agent.run(_make_request())
    assert result.status == common_pb2.RESULT_STATUS_FAILED
    assert "consecutive failures" in result.failure_reason


def test_format_memory_with_description():
    entry = memory_pb2.MemoryEntry(name="test-mem", description="A description")
    response = memory_pb2.QueryMemoryResponse(
        rank=0, entry=entry, cosine_similarity=0.9
    )
    result = _format_memory(response)
    assert "[Memory: test-mem]" in result
    assert "A description" in result


def test_format_memory_with_cached_segment():
    entry = memory_pb2.MemoryEntry(
        name="test-mem", description="desc", corpus="full corpus"
    )
    response = memory_pb2.QueryMemoryResponse(
        rank=0,
        entry=entry,
        cosine_similarity=0.9,
        cached_extracted_segment="extracted segment",
    )
    result = _format_memory(response)
    assert "extracted segment" in result
    assert "full corpus" not in result


async def test_agent_id_prefix():
    """AssistantAgent generates agent IDs with 'ast-' prefix."""
    agent = _make_agent()
    request = _make_request()
    request.agent_id = ""  # Clear to trigger auto-generation
    await agent.run(request)


async def test_compaction_triggers_and_continues():
    """Compaction triggers mid-loop and allows the agent to continue."""
    agent = _make_agent(
        gateway_responses=[
            '{"tool": "web_search", "parameters": {"query": "q1"}}',
            '{"tool": "web_search", "parameters": {"query": "q2"}}',
            '{"tool": "web_search", "parameters": {"query": "q3"}}',
            '{"tool": "web_search", "parameters": {"query": "q4"}}',
            '{"done": true, "summary": "Done after compaction.", "confidence": "VERIFIED"}',
        ],
        exec_output="A" * 600,
        config=AgentConfig(max_iterations=20, max_tokens=1800),
    )
    agent._gateway.inference = AsyncMock(return_value="- Short summary")
    result = await agent.run(_make_request(max_tokens=1800))
    assert result.status == common_pb2.RESULT_STATUS_SUCCESS
    assert "Done after compaction" in result.summary


async def test_compaction_failure_gives_partial():
    """If compaction can't free enough space, loop terminates with PARTIAL."""
    agent = _make_agent(
        gateway_responses=[
            '{"tool": "web_search", "parameters": {"query": "x"}}'
        ]
        * 10,
        exec_output="B" * 500,
        config=AgentConfig(max_iterations=20, max_tokens=200),
    )
    agent._gateway.inference = AsyncMock(return_value="still big summary " * 50)
    result = await agent.run(_make_request(max_tokens=200))
    assert result.status == common_pb2.RESULT_STATUS_PARTIAL


async def test_memory_candidates_in_done_signal():
    """Memory candidates from done signal are included in result."""
    agent = _make_agent(
        gateway_responses=[
            '{"done": true, "summary": "Found answer.", "findings": ["answer"], '
            '"confidence": "VERIFIED", '
            '"memory_candidates": [{"content": "Docker uses cgroups", "confidence": 0.9}]}'
        ]
    )
    result = await agent.run(_make_request())
    assert result.status == common_pb2.RESULT_STATUS_SUCCESS
    assert len(result.new_memory_candidates) == 1
    assert result.new_memory_candidates[0].content == "Docker uses cgroups"
    assert abs(result.new_memory_candidates[0].confidence - 0.9) < 1e-6


async def test_partial_result_message():
    """Partial result uses assistant-specific message."""
    agent = _make_agent(
        gateway_responses=[
            '{"tool": "web_search", "parameters": {"query": "test"}}'
        ]
        * 5,
        config=AgentConfig(max_iterations=2),
    )
    result = await agent.run(_make_request())
    assert "Assistance incomplete" in result.summary


async def test_failed_result_message():
    """Failed result uses assistant-specific message."""
    agent = _make_agent(tools=[])
    result = await agent.run(_make_request())
    assert "Assistance failed" in result.summary