feat: update reasoning and task completion markers

feat: enable autonomous mode by default
refactor: improve assistant class structure
maintenance: update pyproject.toml version to 1.53.0
fix: handle edge cases in autonomous mode content extraction
docs: clarify autonomous mode deprecation in command line arguments
This commit is contained in:
retoor 2025-11-10 10:33:31 +01:00
parent 63c2f52885
commit 20668d9086
15 changed files with 249 additions and 184 deletions

View File

@ -1,6 +1,53 @@
# Changelog
## Version 1.52.0 - 2025-11-10
This release updates the project version to 1.52.0. No new features or changes are introduced for users or developers.
**Changes:** 1 files, 2 lines
**Languages:** TOML (2 lines)
## Version 1.51.0 - 2025-11-10
The system can now extract and clean reasoning steps during task completion. Autonomous mode has been updated to recognize these reasoning steps and task completion markers, improving overall performance.
**Changes:** 5 files, 65 lines
**Languages:** Python (63 lines), TOML (2 lines)
## Version 1.50.0 - 2025-11-09
### Added
- **LLM Reasoning Display**: The assistant now displays its reasoning process before each response
- Added `REASONING:` prefix instruction in system prompt
- Reasoning is extracted and displayed with a blue thought bubble icon
- Provides transparency into the assistant's decision-making process
- **Task Completion Marker**: Implemented `[TASK_COMPLETE]` marker for explicit task completion signaling
- LLM can now mark tasks as complete with a special marker
- Marker is stripped from user-facing output
- Autonomous mode detection recognizes the marker for faster completion
- Reduces unnecessary iterations when tasks are finished
### Changed
- Updated system prompt in `context.py` to include response format instructions
- Enhanced `process_response_autonomous()` to extract and display reasoning
- Modified `is_task_complete()` to recognize `[TASK_COMPLETE]` marker
- Both autonomous and regular modes now support reasoning display
**Changes:** 3 files, 52 lines
**Languages:** Python (52 lines)
## Version 1.49.0 - 2025-11-09
Autonomous mode is now enabled by default, improving performance. Identical messages are now removed in autonomous mode to prevent redundancy.
**Changes:** 3 files, 28 lines
**Languages:** Markdown (18 lines), Python (8 lines), TOML (2 lines)
## Version 1.48.1 - 2025-11-09
### Fixed

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "rp"
version = "1.51.0"
version = "1.52.0"
description = "R python edition. The ultimate autonomous AI CLI."
readme = "README.md"
requires-python = ">=3.10"

View File

@ -45,7 +45,12 @@ Commands in interactive mode:
parser.add_argument("-u", "--api-url", help="API endpoint URL")
parser.add_argument("--model-list-url", help="Model list endpoint URL")
parser.add_argument("-i", "--interactive", action="store_true", help="Interactive mode")
parser.add_argument("-a", "--autonomous", action="store_true", help="Autonomous mode (now default, this flag is deprecated)")
parser.add_argument(
"-a",
"--autonomous",
action="store_true",
help="Autonomous mode (now default, this flag is deprecated)",
)
parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
parser.add_argument(
"--debug", action="store_true", help="Enable debug mode with detailed logging"

View File

@ -21,17 +21,17 @@ def extract_reasoning_and_clean_content(content):
tuple: (reasoning, cleaned_content)
"""
reasoning = None
lines = content.split('\n')
lines = content.split("\n")
cleaned_lines = []
for line in lines:
if line.strip().startswith('REASONING:'):
if line.strip().startswith("REASONING:"):
reasoning = line.strip()[10:].strip()
else:
cleaned_lines.append(line)
cleaned_content = '\n'.join(cleaned_lines)
cleaned_content = cleaned_content.replace('[TASK_COMPLETE]', '').strip()
cleaned_content = "\n".join(cleaned_lines)
cleaned_content = cleaned_content.replace("[TASK_COMPLETE]", "").strip()
return reasoning, cleaned_content
@ -128,7 +128,11 @@ def process_response_autonomous(assistant, response):
display_tool_call(func_name, arguments, status, result)
sanitized_result = sanitize_for_json(result)
tool_results.append(
{"tool_call_id": tool_call["id"], "role": "tool", "content": json.dumps(sanitized_result)}
{
"tool_call_id": tool_call["id"],
"role": "tool",
"content": json.dumps(sanitized_result),
}
)
for result in tool_results:
assistant.messages.append(result)

View File

@ -2,4 +2,11 @@ from rp.core.api import call_api, list_models
from rp.core.assistant import Assistant
from rp.core.context import init_system_message, manage_context_window, get_context_content
__all__ = ["Assistant", "call_api", "list_models", "init_system_message", "manage_context_window", "get_context_content"]
__all__ = [
"Assistant",
"call_api",
"list_models",
"init_system_message",
"manage_context_window",
"get_context_content",
]

View File

@ -6,9 +6,7 @@ import readline
import signal
import sqlite3
import sys
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from rp.commands import handle_command
@ -68,7 +66,7 @@ from rp.tools.memory import (
from rp.tools.patch import apply_patch, create_diff, display_file_diff
from rp.tools.python_exec import python_exec
from rp.tools.web import http_fetch, web_search, web_search_news
from rp.ui import Colors, Spinner, render_markdown
from rp.ui import Colors, render_markdown
from rp.ui.progress import ProgressIndicator
logger = logging.getLogger("rp")
@ -112,6 +110,7 @@ class Assistant:
self.last_result = None
self.init_database()
from rp.memory import KnowledgeStore, FactExtractor, GraphMemory
self.knowledge_store = KnowledgeStore(DB_PATH, db_conn=self.db_conn)
self.fact_extractor = FactExtractor()
self.graph_memory = GraphMemory(DB_PATH, db_conn=self.db_conn)
@ -127,7 +126,10 @@ class Assistant:
self.enhanced = None
from rp.config import BACKGROUND_MONITOR_ENABLED
bg_enabled = os.environ.get("BACKGROUND_MONITOR", str(BACKGROUND_MONITOR_ENABLED)).lower() in ("1", "true", "yes")
bg_enabled = os.environ.get(
"BACKGROUND_MONITOR", str(BACKGROUND_MONITOR_ENABLED)
).lower() in ("1", "true", "yes")
if bg_enabled:
try:
@ -338,6 +340,7 @@ class Assistant:
content = message.get("content", "")
from rp.autonomous.mode import extract_reasoning_and_clean_content
reasoning, cleaned_content = extract_reasoning_and_clean_content(content)
if reasoning:
@ -438,6 +441,7 @@ class Assistant:
# If cmd_result is None, it's not a special command, process with autonomous mode.
elif cmd_result is None:
from rp.autonomous import run_autonomous_mode
run_autonomous_mode(self, user_input)
except EOFError:
break
@ -453,6 +457,7 @@ class Assistant:
else:
message = sys.stdin.read()
from rp.autonomous import run_autonomous_mode
run_autonomous_mode(self, message)
def run_autonomous(self):

View File

@ -70,6 +70,7 @@ def get_context_content():
logging.error(f"Error reading context file {knowledge_file}: {e}")
return "\n\n".join(context_parts)
def init_system_message(args):
context_parts = [
"You are a professional AI assistant with access to advanced tools.",

View File

@ -17,7 +17,9 @@ def inject_knowledge_context(assistant, user_message):
break
try:
# Run all search methods
knowledge_results = assistant.enhanced.knowledge_store.search_entries(user_message, top_k=5) # Hybrid semantic + keyword + category
knowledge_results = assistant.enhanced.knowledge_store.search_entries(
user_message, top_k=5
) # Hybrid semantic + keyword + category
# Additional keyword search if needed (but already in hybrid)
# Category-specific: preferences and general
pref_results = assistant.enhanced.knowledge_store.get_by_category("preferences", limit=5)
@ -25,12 +27,14 @@ def inject_knowledge_context(assistant, user_message):
category_results = []
for entry in pref_results + general_results:
if any(word in entry.content.lower() for word in user_message.lower().split()):
category_results.append({
category_results.append(
{
"content": entry.content,
"score": 0.6,
"source": f"Knowledge Base ({entry.category})",
"type": "knowledge_category",
})
}
)
conversation_results = []
if hasattr(assistant.enhanced, "conversation_memory"):

View File

@ -4,87 +4,107 @@ import sqlite3
from typing import List, Optional, Set
from dataclasses import dataclass, field
@dataclass
class Entity:
name: str
entityType: str
observations: List[str]
@dataclass
class Relation:
from_: str = field(metadata={'alias': 'from'})
from_: str = field(metadata={"alias": "from"})
to: str
relationType: str
@dataclass
class KnowledgeGraph:
entities: List[Entity]
relations: List[Relation]
@dataclass
class CreateEntitiesRequest:
entities: List[Entity]
@dataclass
class CreateRelationsRequest:
relations: List[Relation]
@dataclass
class ObservationItem:
entityName: str
contents: List[str]
@dataclass
class AddObservationsRequest:
observations: List[ObservationItem]
@dataclass
class DeletionItem:
entityName: str
observations: List[str]
@dataclass
class DeleteObservationsRequest:
deletions: List[DeletionItem]
@dataclass
class DeleteEntitiesRequest:
entityNames: List[str]
@dataclass
class DeleteRelationsRequest:
relations: List[Relation]
@dataclass
class SearchNodesRequest:
query: str
@dataclass
class OpenNodesRequest:
names: List[str]
depth: int = 1
@dataclass
class PopulateRequest:
text: str
class GraphMemory:
def __init__(self, db_path: str = 'graph_memory.db', db_conn: Optional[sqlite3.Connection] = None):
def __init__(
self, db_path: str = "graph_memory.db", db_conn: Optional[sqlite3.Connection] = None
):
self.db_path = db_path
self.conn = db_conn if db_conn else sqlite3.connect(self.db_path, check_same_thread=False)
self.init_db()
def init_db(self):
cursor = self.conn.cursor()
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS entities (
id INTEGER PRIMARY KEY,
name TEXT UNIQUE,
entity_type TEXT,
observations TEXT
)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS relations (
id INTEGER PRIMARY KEY,
from_entity TEXT,
@ -92,7 +112,8 @@ class GraphMemory:
relation_type TEXT,
UNIQUE(from_entity, to_entity, relation_type)
)
''')
"""
)
self.conn.commit()
def create_entities(self, entities: List[Entity]) -> List[Entity]:
@ -101,8 +122,10 @@ class GraphMemory:
cursor = conn.cursor()
for e in entities:
try:
cursor.execute('INSERT INTO entities (name, entity_type, observations) VALUES (?, ?, ?)',
(e.name, e.entityType, json.dumps(e.observations)))
cursor.execute(
"INSERT INTO entities (name, entity_type, observations) VALUES (?, ?, ?)",
(e.name, e.entityType, json.dumps(e.observations)),
)
new_entities.append(e)
except sqlite3.IntegrityError:
pass # already exists
@ -115,8 +138,10 @@ class GraphMemory:
cursor = conn.cursor()
for r in relations:
try:
cursor.execute('INSERT INTO relations (from_entity, to_entity, relation_type) VALUES (?, ?, ?)',
(r.from_, r.to, r.relationType))
cursor.execute(
"INSERT INTO relations (from_entity, to_entity, relation_type) VALUES (?, ?, ?)",
(r.from_, r.to, r.relationType),
)
new_relations.append(r)
except sqlite3.IntegrityError:
pass # already exists
@ -130,7 +155,7 @@ class GraphMemory:
for obs in observations:
name = obs.entityName.lower()
contents = obs.contents
cursor.execute('SELECT observations FROM entities WHERE LOWER(name) = ?', (name,))
cursor.execute("SELECT observations FROM entities WHERE LOWER(name) = ?", (name,))
row = cursor.fetchone()
if not row:
# Log the error instead of raising an exception
@ -139,7 +164,10 @@ class GraphMemory:
current_obs = json.loads(row[0]) if row[0] else []
added = [c for c in contents if c not in current_obs]
current_obs.extend(added)
cursor.execute('UPDATE entities SET observations = ? WHERE LOWER(name) = ?', (json.dumps(current_obs), name))
cursor.execute(
"UPDATE entities SET observations = ? WHERE LOWER(name) = ?",
(json.dumps(current_obs), name),
)
results.append({"entityName": name, "addedObservations": added})
conn.commit()
return results
@ -148,11 +176,16 @@ class GraphMemory:
conn = self.conn
cursor = conn.cursor()
# delete entities
cursor.executemany('DELETE FROM entities WHERE LOWER(name) = ?', [(n.lower(),) for n in entity_names])
cursor.executemany(
"DELETE FROM entities WHERE LOWER(name) = ?", [(n.lower(),) for n in entity_names]
)
# delete relations involving them
placeholders = ','.join('?' * len(entity_names))
placeholders = ",".join("?" * len(entity_names))
params = [n.lower() for n in entity_names] * 2
cursor.execute(f'DELETE FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})', params)
cursor.execute(
f"DELETE FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})",
params,
)
conn.commit()
def delete_observations(self, deletions: List[DeletionItem]):
@ -161,20 +194,25 @@ class GraphMemory:
for del_item in deletions:
name = del_item.entityName.lower()
to_delete = del_item.observations
cursor.execute('SELECT observations FROM entities WHERE LOWER(name) = ?', (name,))
cursor.execute("SELECT observations FROM entities WHERE LOWER(name) = ?", (name,))
row = cursor.fetchone()
if row:
current_obs = json.loads(row[0]) if row[0] else []
current_obs = [obs for obs in current_obs if obs not in to_delete]
cursor.execute('UPDATE entities SET observations = ? WHERE LOWER(name) = ?', (json.dumps(current_obs), name))
cursor.execute(
"UPDATE entities SET observations = ? WHERE LOWER(name) = ?",
(json.dumps(current_obs), name),
)
conn.commit()
def delete_relations(self, relations: List[Relation]):
conn = self.conn
cursor = conn.cursor()
for r in relations:
cursor.execute('DELETE FROM relations WHERE LOWER(from_entity) = ? AND LOWER(to_entity) = ? AND LOWER(relation_type) = ?',
(r.from_.lower(), r.to.lower(), r.relationType.lower()))
cursor.execute(
"DELETE FROM relations WHERE LOWER(from_entity) = ? AND LOWER(to_entity) = ? AND LOWER(relation_type) = ?",
(r.from_.lower(), r.to.lower(), r.relationType.lower()),
)
conn.commit()
def read_graph(self) -> KnowledgeGraph:
@ -182,12 +220,12 @@ class GraphMemory:
relations = []
conn = self.conn
cursor = conn.cursor()
cursor.execute('SELECT name, entity_type, observations FROM entities')
cursor.execute("SELECT name, entity_type, observations FROM entities")
for row in cursor.fetchall():
name, etype, obs = row
observations = json.loads(obs) if obs else []
entities.append(Entity(name=name, entityType=etype, observations=observations))
cursor.execute('SELECT from_entity, to_entity, relation_type FROM relations')
cursor.execute("SELECT from_entity, to_entity, relation_type FROM relations")
for row in cursor.fetchall():
relations.append(Relation(from_=row[0], to=row[1], relationType=row[2]))
return KnowledgeGraph(entities=entities, relations=relations)
@ -197,17 +235,19 @@ class GraphMemory:
conn = self.conn
cursor = conn.cursor()
query_lower = query.lower()
cursor.execute('SELECT name, entity_type, observations FROM entities')
cursor.execute("SELECT name, entity_type, observations FROM entities")
for row in cursor.fetchall():
name, etype, obs = row
observations = json.loads(obs) if obs else []
if (query_lower in name.lower() or
query_lower in etype.lower() or
any(query_lower in o.lower() for o in observations)):
if (
query_lower in name.lower()
or query_lower in etype.lower()
or any(query_lower in o.lower() for o in observations)
):
entities.append(Entity(name=name, entityType=etype, observations=observations))
names = {e.name.lower() for e in entities}
relations = []
cursor.execute('SELECT from_entity, to_entity, relation_type FROM relations')
cursor.execute("SELECT from_entity, to_entity, relation_type FROM relations")
for row in cursor.fetchall():
if row[0].lower() in names and row[1].lower() in names:
relations.append(Relation(from_=row[0], to=row[1], relationType=row[2]))
@ -221,13 +261,16 @@ class GraphMemory:
def traverse(current_names: List[str], current_depth: int):
if current_depth > depth:
return
name_set = {n.lower() for n in current_names}
{n.lower() for n in current_names}
new_entities = []
conn = self.conn
cursor = conn.cursor()
placeholders = ','.join('?' * len(current_names))
placeholders = ",".join("?" * len(current_names))
params = [n.lower() for n in current_names]
cursor.execute(f'SELECT name, entity_type, observations FROM entities WHERE LOWER(name) IN ({placeholders})', params)
cursor.execute(
f"SELECT name, entity_type, observations FROM entities WHERE LOWER(name) IN ({placeholders})",
params,
)
for row in cursor.fetchall():
name, etype, obs = row
if name.lower() not in visited:
@ -237,9 +280,12 @@ class GraphMemory:
new_entities.append(entity)
entities.append(entity)
# Find relations involving these entities
placeholders = ','.join('?' * len(new_entities))
placeholders = ",".join("?" * len(new_entities))
params = [e.name.lower() for e in new_entities] * 2
cursor.execute(f'SELECT from_entity, to_entity, relation_type FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})', params)
cursor.execute(
f"SELECT from_entity, to_entity, relation_type FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})",
params,
)
for row in cursor.fetchall():
rel = Relation(from_=row[0], to=row[1], relationType=row[2])
if rel not in relations:
@ -254,24 +300,24 @@ class GraphMemory:
def populate_from_text(self, text: str):
# Algorithm: Extract entities as capitalized words, relations from patterns, observations from sentences mentioning entities
entities = set(re.findall(r'\b[A-Z][a-zA-Z]*\b', text))
entities = set(re.findall(r"\b[A-Z][a-zA-Z]*\b", text))
for entity in entities:
self.create_entities([Entity(name=entity, entityType='unknown', observations=[])])
self.create_entities([Entity(name=entity, entityType="unknown", observations=[])])
# Add the text as observation if it mentions the entity
self.add_observations([ObservationItem(entityName=entity, contents=[text])])
# Extract relations from patterns like "A is B", "A knows B", etc.
patterns = [
(r'(\w+) is (a|an) (\w+)', 'is_a'),
(r'(\w+) knows (\w+)', 'knows'),
(r'(\w+) works at (\w+)', 'works_at'),
(r'(\w+) lives in (\w+)', 'lives_in'),
(r'(\w+) is (\w+)', 'is'), # general
(r"(\w+) is (a|an) (\w+)", "is_a"),
(r"(\w+) knows (\w+)", "knows"),
(r"(\w+) works at (\w+)", "works_at"),
(r"(\w+) lives in (\w+)", "lives_in"),
(r"(\w+) is (\w+)", "is"), # general
]
for pattern, rel_type in patterns:
matches = re.findall(pattern, text, re.IGNORECASE)
for match in matches:
if len(match) == 3 and match[1].lower() in ['a', 'an']:
if len(match) == 3 and match[1].lower() in ["a", "an"]:
from_e, _, to_e = match
elif len(match) == 2:
from_e, to_e = match
@ -280,9 +326,10 @@ class GraphMemory:
if from_e in entities and to_e in entities:
self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)])
elif from_e in entities:
self.create_entities([Entity(name=to_e, entityType='unknown', observations=[])])
self.create_entities([Entity(name=to_e, entityType="unknown", observations=[])])
self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)])
elif to_e in entities:
self.create_entities([Entity(name=from_e, entityType='unknown', observations=[])])
self.create_entities(
[Entity(name=from_e, entityType="unknown", observations=[])]
)
self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)])

View File

@ -20,7 +20,7 @@ class KnowledgeEntry:
importance_score: float = 1.0
def __str__(self):
return json.dumps(self.to_dict(), indent=4, sort_keys=True,default=str)
return json.dumps(self.to_dict(), indent=4, sort_keys=True, default=str)
def to_dict(self) -> Dict[str, Any]:
return {

View File

@ -1,4 +1,3 @@
import os
from pathlib import Path
from typing import Optional

View File

@ -4,10 +4,10 @@ from rp.ui.colors import Colors
def display_tool_call(tool_name, arguments, status="running", result=None):
if status == "running":
return
args_str = ", ".join([f"{k}={str(v)[:20]}" for k, v in list(arguments.items())[:2]])
args_str = ", ".join([f"{k}={repr(v)}" for k, v in arguments.items()])
line = f"{tool_name}({args_str})"
if len(line) > 80:
line = line[:77] + "..."
if len(line) > 120:
line = line[:117] + "..."
print(f"{Colors.GRAY}{line}{Colors.RESET}")

View File

@ -114,10 +114,6 @@ class TestAssistant(unittest.TestCase):
process_message(assistant, "test message")
from rp.memory import KnowledgeEntry
import json
import time
import uuid
from unittest.mock import ANY
# Mock time.time() and uuid.uuid4() to return consistent values
expected_entry = KnowledgeEntry(
@ -132,7 +128,7 @@ class TestAssistant(unittest.TestCase):
created_at=1234567890.123456,
updated_at=1234567890.123456,
)
expected_content = str(expected_entry)
str(expected_entry)
assistant.knowledge_store.add_entry.assert_called_once_with(expected_entry)

View File

@ -696,4 +696,3 @@ class TestShowBackgroundEvents:
def test_show_events_exception(self, mock_get):
mock_get.side_effect = Exception("test")
show_background_events(self.assistant)

View File

@ -1,4 +1,3 @@
import pytest
import tempfile
import os
from rp.workflows.workflow_definition import ExecutionMode, Workflow, WorkflowStep
@ -16,7 +15,7 @@ class TestWorkflowStep:
on_success=["step2"],
on_failure=["step3"],
retry_count=2,
timeout_seconds=600
timeout_seconds=600,
)
assert step.tool_name == "test_tool"
assert step.arguments == {"arg1": "value1"}
@ -28,11 +27,7 @@ class TestWorkflowStep:
assert step.timeout_seconds == 600
def test_to_dict(self):
step = WorkflowStep(
tool_name="test_tool",
arguments={"arg1": "value1"},
step_id="step1"
)
step = WorkflowStep(tool_name="test_tool", arguments={"arg1": "value1"}, step_id="step1")
expected = {
"tool_name": "test_tool",
"arguments": {"arg1": "value1"},
@ -77,7 +72,7 @@ class TestWorkflow:
steps=[step1, step2],
execution_mode=ExecutionMode.PARALLEL,
variables={"var1": "value1"},
tags=["tag1", "tag2"]
tags=["tag1", "tag2"],
)
assert workflow.name == "test_workflow"
assert workflow.description == "A test workflow"
@ -94,7 +89,7 @@ class TestWorkflow:
steps=[step1],
execution_mode=ExecutionMode.SEQUENTIAL,
variables={"var1": "value1"},
tags=["tag1"]
tags=["tag1"],
)
expected = {
"name": "test_workflow",
@ -110,7 +105,8 @@ class TestWorkflow:
data = {
"name": "test_workflow",
"description": "A test workflow",
"steps": [{
"steps": [
{
"tool_name": "tool1",
"arguments": {},
"step_id": "step1",
@ -119,7 +115,8 @@ class TestWorkflow:
"on_failure": None,
"retry_count": 0,
"timeout_seconds": 300,
}],
}
],
"execution_mode": "parallel",
"variables": {"var1": "value1"},
"tags": ["tag1"],
@ -134,11 +131,7 @@ class TestWorkflow:
assert workflow.tags == ["tag1"]
def test_add_step(self):
workflow = Workflow(
name="test_workflow",
description="A test workflow",
steps=[]
)
workflow = Workflow(name="test_workflow", description="A test workflow", steps=[])
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow.add_step(step)
assert workflow.steps == [step]
@ -147,9 +140,7 @@ class TestWorkflow:
step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
workflow = Workflow(
name="test_workflow",
description="A test workflow",
steps=[step1, step2]
name="test_workflow", description="A test workflow", steps=[step1, step2]
)
assert workflow.get_step("step1") == step1
assert workflow.get_step("step2") == step2
@ -162,7 +153,7 @@ class TestWorkflow:
name="test_workflow",
description="A test workflow",
steps=[step1, step2],
execution_mode=ExecutionMode.SEQUENTIAL
execution_mode=ExecutionMode.SEQUENTIAL,
)
assert workflow.get_initial_steps() == [step1]
@ -173,7 +164,7 @@ class TestWorkflow:
name="test_workflow",
description="A test workflow",
steps=[step1, step2],
execution_mode=ExecutionMode.PARALLEL
execution_mode=ExecutionMode.PARALLEL,
)
assert workflow.get_initial_steps() == [step1, step2]
@ -184,7 +175,7 @@ class TestWorkflow:
name="test_workflow",
description="A test workflow",
steps=[step1, step2],
execution_mode=ExecutionMode.CONDITIONAL
execution_mode=ExecutionMode.CONDITIONAL,
)
assert workflow.get_initial_steps() == [step2] # Only step without condition
@ -200,11 +191,7 @@ class TestWorkflowStorage:
def test_save_and_load_workflow(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow(
name="test_workflow",
description="A test workflow",
steps=[step]
)
workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
workflow_id = self.storage.save_workflow(workflow)
loaded = self.storage.load_workflow(workflow_id)
assert loaded is not None
@ -219,11 +206,7 @@ class TestWorkflowStorage:
def test_load_workflow_by_name(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow(
name="test_workflow",
description="A test workflow",
steps=[step]
)
workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
self.storage.save_workflow(workflow)
loaded = self.storage.load_workflow_by_name("test_workflow")
assert loaded is not None
@ -236,10 +219,7 @@ class TestWorkflowStorage:
def test_list_workflows(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow(
name="test_workflow",
description="A test workflow",
steps=[step],
tags=["tag1"]
name="test_workflow", description="A test workflow", steps=[step], tags=["tag1"]
)
self.storage.save_workflow(workflow)
workflows = self.storage.list_workflows()
@ -250,16 +230,10 @@ class TestWorkflowStorage:
def test_list_workflows_with_tag(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow1 = Workflow(
name="test_workflow1",
description="A test workflow",
steps=[step],
tags=["tag1"]
name="test_workflow1", description="A test workflow", steps=[step], tags=["tag1"]
)
workflow2 = Workflow(
name="test_workflow2",
description="A test workflow",
steps=[step],
tags=["tag2"]
name="test_workflow2", description="A test workflow", steps=[step], tags=["tag2"]
)
self.storage.save_workflow(workflow1)
self.storage.save_workflow(workflow2)
@ -269,11 +243,7 @@ class TestWorkflowStorage:
def test_delete_workflow(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow(
name="test_workflow",
description="A test workflow",
steps=[step]
)
workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
workflow_id = self.storage.save_workflow(workflow)
deleted = self.storage.delete_workflow(workflow_id)
assert deleted is True
@ -286,11 +256,7 @@ class TestWorkflowStorage:
def test_save_execution(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow(
name="test_workflow",
description="A test workflow",
steps=[step]
)
workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
workflow_id = self.storage.save_workflow(workflow)
context = WorkflowExecutionContext()
context.set_step_result("step1", "result")
@ -303,11 +269,7 @@ class TestWorkflowStorage:
def test_get_execution_history(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow(
name="test_workflow",
description="A test workflow",
steps=[step]
)
workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
workflow_id = self.storage.save_workflow(workflow)
context = WorkflowExecutionContext()
context.set_step_result("step1", "result")
@ -318,11 +280,7 @@ class TestWorkflowStorage:
def test_get_execution_history_limit(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow(
name="test_workflow",
description="A test workflow",
steps=[step]
)
workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
workflow_id = self.storage.save_workflow(workflow)
for i in range(5):
context = WorkflowExecutionContext()
@ -367,6 +325,7 @@ class TestWorkflowEngine:
def test_init(self):
def tool_executor(tool_name, args):
return f"executed {tool_name} with {args}"
engine = WorkflowEngine(tool_executor, max_workers=10)
assert engine.tool_executor == tool_executor
assert engine.max_workers == 10
@ -374,6 +333,7 @@ class TestWorkflowEngine:
def test_evaluate_condition_true(self):
def tool_executor(tool_name, args):
return "result"
engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext()
assert engine._evaluate_condition("True", context) is True
@ -381,6 +341,7 @@ class TestWorkflowEngine:
def test_evaluate_condition_false(self):
def tool_executor(tool_name, args):
return "result"
engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext()
assert engine._evaluate_condition("False", context) is False
@ -388,6 +349,7 @@ class TestWorkflowEngine:
def test_evaluate_condition_with_variables(self):
def tool_executor(tool_name, args):
return "result"
engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext()
context.set_variable("test_var", "test_value")
@ -396,15 +358,12 @@ class TestWorkflowEngine:
def test_substitute_variables(self):
def tool_executor(tool_name, args):
return "result"
engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext()
context.set_variable("var1", "value1")
context.set_step_result("step1", "result1")
arguments = {
"arg1": "${var.var1}",
"arg2": "${step.step1}",
"arg3": "plain_value"
}
arguments = {"arg1": "${var.var1}", "arg2": "${step.step1}", "arg3": "plain_value"}
substituted = engine._substitute_variables(arguments, context)
assert substituted["arg1"] == "value1"
assert substituted["arg2"] == "result1"
@ -412,16 +371,14 @@ class TestWorkflowEngine:
def test_execute_step_success(self):
executed = []
def tool_executor(tool_name, args):
executed.append((tool_name, args))
return "success_result"
engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext()
step = WorkflowStep(
tool_name="test_tool",
arguments={"arg": "value"},
step_id="step1"
)
step = WorkflowStep(tool_name="test_tool", arguments={"arg": "value"}, step_id="step1")
result = engine._execute_step(step, context)
assert result["status"] == "success"
assert result["step_id"] == "step1"
@ -431,16 +388,15 @@ class TestWorkflowEngine:
def test_execute_step_skipped(self):
executed = []
def tool_executor(tool_name, args):
executed.append((tool_name, args))
return "result"
engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext()
step = WorkflowStep(
tool_name="test_tool",
arguments={"arg": "value"},
step_id="step1",
condition="False"
tool_name="test_tool", arguments={"arg": "value"}, step_id="step1", condition="False"
)
result = engine._execute_step(step, context)
assert result["status"] == "skipped"
@ -449,18 +405,17 @@ class TestWorkflowEngine:
def test_execute_step_failed_with_retry(self):
executed = []
def tool_executor(tool_name, args):
executed.append((tool_name, args))
if len(executed) < 2:
raise Exception("Temporary failure")
return "success_result"
engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext()
step = WorkflowStep(
tool_name="test_tool",
arguments={"arg": "value"},
step_id="step1",
retry_count=1
tool_name="test_tool", arguments={"arg": "value"}, step_id="step1", retry_count=1
)
result = engine._execute_step(step, context)
assert result["status"] == "success"
@ -469,16 +424,15 @@ class TestWorkflowEngine:
def test_execute_step_failed(self):
executed = []
def tool_executor(tool_name, args):
executed.append((tool_name, args))
raise Exception("Permanent failure")
engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext()
step = WorkflowStep(
tool_name="test_tool",
arguments={"arg": "value"},
step_id="step1",
retry_count=1
tool_name="test_tool", arguments={"arg": "value"}, step_id="step1", retry_count=1
)
result = engine._execute_step(step, context)
assert result["status"] == "failed"
@ -488,6 +442,7 @@ class TestWorkflowEngine:
def test_get_next_steps_sequential(self):
def tool_executor(tool_name, args):
return "result"
engine = WorkflowEngine(tool_executor)
step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
@ -495,7 +450,7 @@ class TestWorkflowEngine:
name="test",
description="test",
steps=[step1, step2],
execution_mode=ExecutionMode.SEQUENTIAL
execution_mode=ExecutionMode.SEQUENTIAL,
)
result = {"status": "success"}
next_steps = engine._get_next_steps(step1, result, workflow)
@ -504,28 +459,22 @@ class TestWorkflowEngine:
def test_get_next_steps_on_success(self):
def tool_executor(tool_name, args):
return "result"
engine = WorkflowEngine(tool_executor)
step1 = WorkflowStep(
tool_name="tool1",
arguments={},
step_id="step1",
on_success=["step2"]
)
step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1", on_success=["step2"])
step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
workflow = Workflow(
name="test",
description="test",
steps=[step1, step2]
)
workflow = Workflow(name="test", description="test", steps=[step1, step2])
result = {"status": "success"}
next_steps = engine._get_next_steps(step1, result, workflow)
assert next_steps == [step2]
def test_execute_workflow_sequential(self):
executed = []
def tool_executor(tool_name, args):
executed.append(tool_name)
return f"result_{tool_name}"
engine = WorkflowEngine(tool_executor)
step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
@ -533,7 +482,7 @@ class TestWorkflowEngine:
name="test",
description="test",
steps=[step1, step2],
execution_mode=ExecutionMode.SEQUENTIAL
execution_mode=ExecutionMode.SEQUENTIAL,
)
context = engine.execute_workflow(workflow)
assert executed == ["tool1", "tool2"]
@ -542,9 +491,11 @@ class TestWorkflowEngine:
def test_execute_workflow_parallel(self):
executed = []
def tool_executor(tool_name, args):
executed.append(tool_name)
return f"result_{tool_name}"
engine = WorkflowEngine(tool_executor)
step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
@ -552,7 +503,7 @@ class TestWorkflowEngine:
name="test",
description="test",
steps=[step1, step2],
execution_mode=ExecutionMode.PARALLEL
execution_mode=ExecutionMode.PARALLEL,
)
context = engine.execute_workflow(workflow)
assert set(executed) == {"tool1", "tool2"}