feat: update reasoning and task completion markers

feat: enable autonomous mode by default
refactor: improve assistant class structure
maintenance: update pyproject.toml version to 1.53.0
fix: handle edge cases in autonomous mode content extraction
docs: clarify autonomous mode deprecation in command line arguments
This commit is contained in:
retoor 2025-11-10 10:33:31 +01:00
parent 63c2f52885
commit 20668d9086
15 changed files with 249 additions and 184 deletions

View File

@ -1,6 +1,53 @@
# Changelog # Changelog
## Version 1.52.0 - 2025-11-10
This release updates the project version to 1.52.0. No new features or changes are introduced for users or developers.
**Changes:** 1 files, 2 lines
**Languages:** TOML (2 lines)
## Version 1.51.0 - 2025-11-10
The system can now extract and clean reasoning steps during task completion. Autonomous mode has been updated to recognize these reasoning steps and task completion markers, improving overall performance.
**Changes:** 5 files, 65 lines
**Languages:** Python (63 lines), TOML (2 lines)
## Version 1.50.0 - 2025-11-09
### Added
- **LLM Reasoning Display**: The assistant now displays its reasoning process before each response
- Added `REASONING:` prefix instruction in system prompt
- Reasoning is extracted and displayed with a blue thought bubble icon
- Provides transparency into the assistant's decision-making process
- **Task Completion Marker**: Implemented `[TASK_COMPLETE]` marker for explicit task completion signaling
- LLM can now mark tasks as complete with a special marker
- Marker is stripped from user-facing output
- Autonomous mode detection recognizes the marker for faster completion
- Reduces unnecessary iterations when tasks are finished
### Changed
- Updated system prompt in `context.py` to include response format instructions
- Enhanced `process_response_autonomous()` to extract and display reasoning
- Modified `is_task_complete()` to recognize `[TASK_COMPLETE]` marker
- Both autonomous and regular modes now support reasoning display
**Changes:** 3 files, 52 lines
**Languages:** Python (52 lines)
## Version 1.49.0 - 2025-11-09
Autonomous mode is now enabled by default, improving performance. Identical messages are now removed in autonomous mode to prevent redundancy.
**Changes:** 3 files, 28 lines
**Languages:** Markdown (18 lines), Python (8 lines), TOML (2 lines)
## Version 1.48.1 - 2025-11-09 ## Version 1.48.1 - 2025-11-09
### Fixed ### Fixed

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "rp" name = "rp"
version = "1.51.0" version = "1.52.0"
description = "R python edition. The ultimate autonomous AI CLI." description = "R python edition. The ultimate autonomous AI CLI."
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -45,7 +45,12 @@ Commands in interactive mode:
parser.add_argument("-u", "--api-url", help="API endpoint URL") parser.add_argument("-u", "--api-url", help="API endpoint URL")
parser.add_argument("--model-list-url", help="Model list endpoint URL") parser.add_argument("--model-list-url", help="Model list endpoint URL")
parser.add_argument("-i", "--interactive", action="store_true", help="Interactive mode") parser.add_argument("-i", "--interactive", action="store_true", help="Interactive mode")
parser.add_argument("-a", "--autonomous", action="store_true", help="Autonomous mode (now default, this flag is deprecated)") parser.add_argument(
"-a",
"--autonomous",
action="store_true",
help="Autonomous mode (now default, this flag is deprecated)",
)
parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
parser.add_argument( parser.add_argument(
"--debug", action="store_true", help="Enable debug mode with detailed logging" "--debug", action="store_true", help="Enable debug mode with detailed logging"

View File

@ -21,17 +21,17 @@ def extract_reasoning_and_clean_content(content):
tuple: (reasoning, cleaned_content) tuple: (reasoning, cleaned_content)
""" """
reasoning = None reasoning = None
lines = content.split('\n') lines = content.split("\n")
cleaned_lines = [] cleaned_lines = []
for line in lines: for line in lines:
if line.strip().startswith('REASONING:'): if line.strip().startswith("REASONING:"):
reasoning = line.strip()[10:].strip() reasoning = line.strip()[10:].strip()
else: else:
cleaned_lines.append(line) cleaned_lines.append(line)
cleaned_content = '\n'.join(cleaned_lines) cleaned_content = "\n".join(cleaned_lines)
cleaned_content = cleaned_content.replace('[TASK_COMPLETE]', '').strip() cleaned_content = cleaned_content.replace("[TASK_COMPLETE]", "").strip()
return reasoning, cleaned_content return reasoning, cleaned_content
@ -128,7 +128,11 @@ def process_response_autonomous(assistant, response):
display_tool_call(func_name, arguments, status, result) display_tool_call(func_name, arguments, status, result)
sanitized_result = sanitize_for_json(result) sanitized_result = sanitize_for_json(result)
tool_results.append( tool_results.append(
{"tool_call_id": tool_call["id"], "role": "tool", "content": json.dumps(sanitized_result)} {
"tool_call_id": tool_call["id"],
"role": "tool",
"content": json.dumps(sanitized_result),
}
) )
for result in tool_results: for result in tool_results:
assistant.messages.append(result) assistant.messages.append(result)

View File

@ -2,4 +2,11 @@ from rp.core.api import call_api, list_models
from rp.core.assistant import Assistant from rp.core.assistant import Assistant
from rp.core.context import init_system_message, manage_context_window, get_context_content from rp.core.context import init_system_message, manage_context_window, get_context_content
__all__ = ["Assistant", "call_api", "list_models", "init_system_message", "manage_context_window", "get_context_content"] __all__ = [
"Assistant",
"call_api",
"list_models",
"init_system_message",
"manage_context_window",
"get_context_content",
]

View File

@ -6,9 +6,7 @@ import readline
import signal import signal
import sqlite3 import sqlite3
import sys import sys
import time
import traceback import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from rp.commands import handle_command from rp.commands import handle_command
@ -68,7 +66,7 @@ from rp.tools.memory import (
from rp.tools.patch import apply_patch, create_diff, display_file_diff from rp.tools.patch import apply_patch, create_diff, display_file_diff
from rp.tools.python_exec import python_exec from rp.tools.python_exec import python_exec
from rp.tools.web import http_fetch, web_search, web_search_news from rp.tools.web import http_fetch, web_search, web_search_news
from rp.ui import Colors, Spinner, render_markdown from rp.ui import Colors, render_markdown
from rp.ui.progress import ProgressIndicator from rp.ui.progress import ProgressIndicator
logger = logging.getLogger("rp") logger = logging.getLogger("rp")
@ -112,6 +110,7 @@ class Assistant:
self.last_result = None self.last_result = None
self.init_database() self.init_database()
from rp.memory import KnowledgeStore, FactExtractor, GraphMemory from rp.memory import KnowledgeStore, FactExtractor, GraphMemory
self.knowledge_store = KnowledgeStore(DB_PATH, db_conn=self.db_conn) self.knowledge_store = KnowledgeStore(DB_PATH, db_conn=self.db_conn)
self.fact_extractor = FactExtractor() self.fact_extractor = FactExtractor()
self.graph_memory = GraphMemory(DB_PATH, db_conn=self.db_conn) self.graph_memory = GraphMemory(DB_PATH, db_conn=self.db_conn)
@ -127,7 +126,10 @@ class Assistant:
self.enhanced = None self.enhanced = None
from rp.config import BACKGROUND_MONITOR_ENABLED from rp.config import BACKGROUND_MONITOR_ENABLED
bg_enabled = os.environ.get("BACKGROUND_MONITOR", str(BACKGROUND_MONITOR_ENABLED)).lower() in ("1", "true", "yes")
bg_enabled = os.environ.get(
"BACKGROUND_MONITOR", str(BACKGROUND_MONITOR_ENABLED)
).lower() in ("1", "true", "yes")
if bg_enabled: if bg_enabled:
try: try:
@ -338,6 +340,7 @@ class Assistant:
content = message.get("content", "") content = message.get("content", "")
from rp.autonomous.mode import extract_reasoning_and_clean_content from rp.autonomous.mode import extract_reasoning_and_clean_content
reasoning, cleaned_content = extract_reasoning_and_clean_content(content) reasoning, cleaned_content = extract_reasoning_and_clean_content(content)
if reasoning: if reasoning:
@ -438,6 +441,7 @@ class Assistant:
# If cmd_result is None, it's not a special command, process with autonomous mode. # If cmd_result is None, it's not a special command, process with autonomous mode.
elif cmd_result is None: elif cmd_result is None:
from rp.autonomous import run_autonomous_mode from rp.autonomous import run_autonomous_mode
run_autonomous_mode(self, user_input) run_autonomous_mode(self, user_input)
except EOFError: except EOFError:
break break
@ -453,6 +457,7 @@ class Assistant:
else: else:
message = sys.stdin.read() message = sys.stdin.read()
from rp.autonomous import run_autonomous_mode from rp.autonomous import run_autonomous_mode
run_autonomous_mode(self, message) run_autonomous_mode(self, message)
def run_autonomous(self): def run_autonomous(self):

View File

@ -70,6 +70,7 @@ def get_context_content():
logging.error(f"Error reading context file {knowledge_file}: {e}") logging.error(f"Error reading context file {knowledge_file}: {e}")
return "\n\n".join(context_parts) return "\n\n".join(context_parts)
def init_system_message(args): def init_system_message(args):
context_parts = [ context_parts = [
"You are a professional AI assistant with access to advanced tools.", "You are a professional AI assistant with access to advanced tools.",

View File

@ -17,7 +17,9 @@ def inject_knowledge_context(assistant, user_message):
break break
try: try:
# Run all search methods # Run all search methods
knowledge_results = assistant.enhanced.knowledge_store.search_entries(user_message, top_k=5) # Hybrid semantic + keyword + category knowledge_results = assistant.enhanced.knowledge_store.search_entries(
user_message, top_k=5
) # Hybrid semantic + keyword + category
# Additional keyword search if needed (but already in hybrid) # Additional keyword search if needed (but already in hybrid)
# Category-specific: preferences and general # Category-specific: preferences and general
pref_results = assistant.enhanced.knowledge_store.get_by_category("preferences", limit=5) pref_results = assistant.enhanced.knowledge_store.get_by_category("preferences", limit=5)
@ -25,12 +27,14 @@ def inject_knowledge_context(assistant, user_message):
category_results = [] category_results = []
for entry in pref_results + general_results: for entry in pref_results + general_results:
if any(word in entry.content.lower() for word in user_message.lower().split()): if any(word in entry.content.lower() for word in user_message.lower().split()):
category_results.append({ category_results.append(
{
"content": entry.content, "content": entry.content,
"score": 0.6, "score": 0.6,
"source": f"Knowledge Base ({entry.category})", "source": f"Knowledge Base ({entry.category})",
"type": "knowledge_category", "type": "knowledge_category",
}) }
)
conversation_results = [] conversation_results = []
if hasattr(assistant.enhanced, "conversation_memory"): if hasattr(assistant.enhanced, "conversation_memory"):

View File

@ -4,87 +4,107 @@ import sqlite3
from typing import List, Optional, Set from typing import List, Optional, Set
from dataclasses import dataclass, field from dataclasses import dataclass, field
@dataclass @dataclass
class Entity: class Entity:
name: str name: str
entityType: str entityType: str
observations: List[str] observations: List[str]
@dataclass @dataclass
class Relation: class Relation:
from_: str = field(metadata={'alias': 'from'}) from_: str = field(metadata={"alias": "from"})
to: str to: str
relationType: str relationType: str
@dataclass @dataclass
class KnowledgeGraph: class KnowledgeGraph:
entities: List[Entity] entities: List[Entity]
relations: List[Relation] relations: List[Relation]
@dataclass @dataclass
class CreateEntitiesRequest: class CreateEntitiesRequest:
entities: List[Entity] entities: List[Entity]
@dataclass @dataclass
class CreateRelationsRequest: class CreateRelationsRequest:
relations: List[Relation] relations: List[Relation]
@dataclass @dataclass
class ObservationItem: class ObservationItem:
entityName: str entityName: str
contents: List[str] contents: List[str]
@dataclass @dataclass
class AddObservationsRequest: class AddObservationsRequest:
observations: List[ObservationItem] observations: List[ObservationItem]
@dataclass @dataclass
class DeletionItem: class DeletionItem:
entityName: str entityName: str
observations: List[str] observations: List[str]
@dataclass @dataclass
class DeleteObservationsRequest: class DeleteObservationsRequest:
deletions: List[DeletionItem] deletions: List[DeletionItem]
@dataclass @dataclass
class DeleteEntitiesRequest: class DeleteEntitiesRequest:
entityNames: List[str] entityNames: List[str]
@dataclass @dataclass
class DeleteRelationsRequest: class DeleteRelationsRequest:
relations: List[Relation] relations: List[Relation]
@dataclass @dataclass
class SearchNodesRequest: class SearchNodesRequest:
query: str query: str
@dataclass @dataclass
class OpenNodesRequest: class OpenNodesRequest:
names: List[str] names: List[str]
depth: int = 1 depth: int = 1
@dataclass @dataclass
class PopulateRequest: class PopulateRequest:
text: str text: str
class GraphMemory: class GraphMemory:
def __init__(self, db_path: str = 'graph_memory.db', db_conn: Optional[sqlite3.Connection] = None): def __init__(
self, db_path: str = "graph_memory.db", db_conn: Optional[sqlite3.Connection] = None
):
self.db_path = db_path self.db_path = db_path
self.conn = db_conn if db_conn else sqlite3.connect(self.db_path, check_same_thread=False) self.conn = db_conn if db_conn else sqlite3.connect(self.db_path, check_same_thread=False)
self.init_db() self.init_db()
def init_db(self): def init_db(self):
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS entities ( CREATE TABLE IF NOT EXISTS entities (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT UNIQUE, name TEXT UNIQUE,
entity_type TEXT, entity_type TEXT,
observations TEXT observations TEXT
) )
''') """
cursor.execute(''' )
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS relations ( CREATE TABLE IF NOT EXISTS relations (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
from_entity TEXT, from_entity TEXT,
@ -92,7 +112,8 @@ class GraphMemory:
relation_type TEXT, relation_type TEXT,
UNIQUE(from_entity, to_entity, relation_type) UNIQUE(from_entity, to_entity, relation_type)
) )
''') """
)
self.conn.commit() self.conn.commit()
def create_entities(self, entities: List[Entity]) -> List[Entity]: def create_entities(self, entities: List[Entity]) -> List[Entity]:
@ -101,8 +122,10 @@ class GraphMemory:
cursor = conn.cursor() cursor = conn.cursor()
for e in entities: for e in entities:
try: try:
cursor.execute('INSERT INTO entities (name, entity_type, observations) VALUES (?, ?, ?)', cursor.execute(
(e.name, e.entityType, json.dumps(e.observations))) "INSERT INTO entities (name, entity_type, observations) VALUES (?, ?, ?)",
(e.name, e.entityType, json.dumps(e.observations)),
)
new_entities.append(e) new_entities.append(e)
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
pass # already exists pass # already exists
@ -115,8 +138,10 @@ class GraphMemory:
cursor = conn.cursor() cursor = conn.cursor()
for r in relations: for r in relations:
try: try:
cursor.execute('INSERT INTO relations (from_entity, to_entity, relation_type) VALUES (?, ?, ?)', cursor.execute(
(r.from_, r.to, r.relationType)) "INSERT INTO relations (from_entity, to_entity, relation_type) VALUES (?, ?, ?)",
(r.from_, r.to, r.relationType),
)
new_relations.append(r) new_relations.append(r)
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
pass # already exists pass # already exists
@ -130,7 +155,7 @@ class GraphMemory:
for obs in observations: for obs in observations:
name = obs.entityName.lower() name = obs.entityName.lower()
contents = obs.contents contents = obs.contents
cursor.execute('SELECT observations FROM entities WHERE LOWER(name) = ?', (name,)) cursor.execute("SELECT observations FROM entities WHERE LOWER(name) = ?", (name,))
row = cursor.fetchone() row = cursor.fetchone()
if not row: if not row:
# Log the error instead of raising an exception # Log the error instead of raising an exception
@ -139,7 +164,10 @@ class GraphMemory:
current_obs = json.loads(row[0]) if row[0] else [] current_obs = json.loads(row[0]) if row[0] else []
added = [c for c in contents if c not in current_obs] added = [c for c in contents if c not in current_obs]
current_obs.extend(added) current_obs.extend(added)
cursor.execute('UPDATE entities SET observations = ? WHERE LOWER(name) = ?', (json.dumps(current_obs), name)) cursor.execute(
"UPDATE entities SET observations = ? WHERE LOWER(name) = ?",
(json.dumps(current_obs), name),
)
results.append({"entityName": name, "addedObservations": added}) results.append({"entityName": name, "addedObservations": added})
conn.commit() conn.commit()
return results return results
@ -148,11 +176,16 @@ class GraphMemory:
conn = self.conn conn = self.conn
cursor = conn.cursor() cursor = conn.cursor()
# delete entities # delete entities
cursor.executemany('DELETE FROM entities WHERE LOWER(name) = ?', [(n.lower(),) for n in entity_names]) cursor.executemany(
"DELETE FROM entities WHERE LOWER(name) = ?", [(n.lower(),) for n in entity_names]
)
# delete relations involving them # delete relations involving them
placeholders = ','.join('?' * len(entity_names)) placeholders = ",".join("?" * len(entity_names))
params = [n.lower() for n in entity_names] * 2 params = [n.lower() for n in entity_names] * 2
cursor.execute(f'DELETE FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})', params) cursor.execute(
f"DELETE FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})",
params,
)
conn.commit() conn.commit()
def delete_observations(self, deletions: List[DeletionItem]): def delete_observations(self, deletions: List[DeletionItem]):
@ -161,20 +194,25 @@ class GraphMemory:
for del_item in deletions: for del_item in deletions:
name = del_item.entityName.lower() name = del_item.entityName.lower()
to_delete = del_item.observations to_delete = del_item.observations
cursor.execute('SELECT observations FROM entities WHERE LOWER(name) = ?', (name,)) cursor.execute("SELECT observations FROM entities WHERE LOWER(name) = ?", (name,))
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
current_obs = json.loads(row[0]) if row[0] else [] current_obs = json.loads(row[0]) if row[0] else []
current_obs = [obs for obs in current_obs if obs not in to_delete] current_obs = [obs for obs in current_obs if obs not in to_delete]
cursor.execute('UPDATE entities SET observations = ? WHERE LOWER(name) = ?', (json.dumps(current_obs), name)) cursor.execute(
"UPDATE entities SET observations = ? WHERE LOWER(name) = ?",
(json.dumps(current_obs), name),
)
conn.commit() conn.commit()
def delete_relations(self, relations: List[Relation]): def delete_relations(self, relations: List[Relation]):
conn = self.conn conn = self.conn
cursor = conn.cursor() cursor = conn.cursor()
for r in relations: for r in relations:
cursor.execute('DELETE FROM relations WHERE LOWER(from_entity) = ? AND LOWER(to_entity) = ? AND LOWER(relation_type) = ?', cursor.execute(
(r.from_.lower(), r.to.lower(), r.relationType.lower())) "DELETE FROM relations WHERE LOWER(from_entity) = ? AND LOWER(to_entity) = ? AND LOWER(relation_type) = ?",
(r.from_.lower(), r.to.lower(), r.relationType.lower()),
)
conn.commit() conn.commit()
def read_graph(self) -> KnowledgeGraph: def read_graph(self) -> KnowledgeGraph:
@ -182,12 +220,12 @@ class GraphMemory:
relations = [] relations = []
conn = self.conn conn = self.conn
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('SELECT name, entity_type, observations FROM entities') cursor.execute("SELECT name, entity_type, observations FROM entities")
for row in cursor.fetchall(): for row in cursor.fetchall():
name, etype, obs = row name, etype, obs = row
observations = json.loads(obs) if obs else [] observations = json.loads(obs) if obs else []
entities.append(Entity(name=name, entityType=etype, observations=observations)) entities.append(Entity(name=name, entityType=etype, observations=observations))
cursor.execute('SELECT from_entity, to_entity, relation_type FROM relations') cursor.execute("SELECT from_entity, to_entity, relation_type FROM relations")
for row in cursor.fetchall(): for row in cursor.fetchall():
relations.append(Relation(from_=row[0], to=row[1], relationType=row[2])) relations.append(Relation(from_=row[0], to=row[1], relationType=row[2]))
return KnowledgeGraph(entities=entities, relations=relations) return KnowledgeGraph(entities=entities, relations=relations)
@ -197,17 +235,19 @@ class GraphMemory:
conn = self.conn conn = self.conn
cursor = conn.cursor() cursor = conn.cursor()
query_lower = query.lower() query_lower = query.lower()
cursor.execute('SELECT name, entity_type, observations FROM entities') cursor.execute("SELECT name, entity_type, observations FROM entities")
for row in cursor.fetchall(): for row in cursor.fetchall():
name, etype, obs = row name, etype, obs = row
observations = json.loads(obs) if obs else [] observations = json.loads(obs) if obs else []
if (query_lower in name.lower() or if (
query_lower in etype.lower() or query_lower in name.lower()
any(query_lower in o.lower() for o in observations)): or query_lower in etype.lower()
or any(query_lower in o.lower() for o in observations)
):
entities.append(Entity(name=name, entityType=etype, observations=observations)) entities.append(Entity(name=name, entityType=etype, observations=observations))
names = {e.name.lower() for e in entities} names = {e.name.lower() for e in entities}
relations = [] relations = []
cursor.execute('SELECT from_entity, to_entity, relation_type FROM relations') cursor.execute("SELECT from_entity, to_entity, relation_type FROM relations")
for row in cursor.fetchall(): for row in cursor.fetchall():
if row[0].lower() in names and row[1].lower() in names: if row[0].lower() in names and row[1].lower() in names:
relations.append(Relation(from_=row[0], to=row[1], relationType=row[2])) relations.append(Relation(from_=row[0], to=row[1], relationType=row[2]))
@ -221,13 +261,16 @@ class GraphMemory:
def traverse(current_names: List[str], current_depth: int): def traverse(current_names: List[str], current_depth: int):
if current_depth > depth: if current_depth > depth:
return return
name_set = {n.lower() for n in current_names} {n.lower() for n in current_names}
new_entities = [] new_entities = []
conn = self.conn conn = self.conn
cursor = conn.cursor() cursor = conn.cursor()
placeholders = ','.join('?' * len(current_names)) placeholders = ",".join("?" * len(current_names))
params = [n.lower() for n in current_names] params = [n.lower() for n in current_names]
cursor.execute(f'SELECT name, entity_type, observations FROM entities WHERE LOWER(name) IN ({placeholders})', params) cursor.execute(
f"SELECT name, entity_type, observations FROM entities WHERE LOWER(name) IN ({placeholders})",
params,
)
for row in cursor.fetchall(): for row in cursor.fetchall():
name, etype, obs = row name, etype, obs = row
if name.lower() not in visited: if name.lower() not in visited:
@ -237,9 +280,12 @@ class GraphMemory:
new_entities.append(entity) new_entities.append(entity)
entities.append(entity) entities.append(entity)
# Find relations involving these entities # Find relations involving these entities
placeholders = ','.join('?' * len(new_entities)) placeholders = ",".join("?" * len(new_entities))
params = [e.name.lower() for e in new_entities] * 2 params = [e.name.lower() for e in new_entities] * 2
cursor.execute(f'SELECT from_entity, to_entity, relation_type FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})', params) cursor.execute(
f"SELECT from_entity, to_entity, relation_type FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})",
params,
)
for row in cursor.fetchall(): for row in cursor.fetchall():
rel = Relation(from_=row[0], to=row[1], relationType=row[2]) rel = Relation(from_=row[0], to=row[1], relationType=row[2])
if rel not in relations: if rel not in relations:
@ -254,24 +300,24 @@ class GraphMemory:
def populate_from_text(self, text: str): def populate_from_text(self, text: str):
# Algorithm: Extract entities as capitalized words, relations from patterns, observations from sentences mentioning entities # Algorithm: Extract entities as capitalized words, relations from patterns, observations from sentences mentioning entities
entities = set(re.findall(r'\b[A-Z][a-zA-Z]*\b', text)) entities = set(re.findall(r"\b[A-Z][a-zA-Z]*\b", text))
for entity in entities: for entity in entities:
self.create_entities([Entity(name=entity, entityType='unknown', observations=[])]) self.create_entities([Entity(name=entity, entityType="unknown", observations=[])])
# Add the text as observation if it mentions the entity # Add the text as observation if it mentions the entity
self.add_observations([ObservationItem(entityName=entity, contents=[text])]) self.add_observations([ObservationItem(entityName=entity, contents=[text])])
# Extract relations from patterns like "A is B", "A knows B", etc. # Extract relations from patterns like "A is B", "A knows B", etc.
patterns = [ patterns = [
(r'(\w+) is (a|an) (\w+)', 'is_a'), (r"(\w+) is (a|an) (\w+)", "is_a"),
(r'(\w+) knows (\w+)', 'knows'), (r"(\w+) knows (\w+)", "knows"),
(r'(\w+) works at (\w+)', 'works_at'), (r"(\w+) works at (\w+)", "works_at"),
(r'(\w+) lives in (\w+)', 'lives_in'), (r"(\w+) lives in (\w+)", "lives_in"),
(r'(\w+) is (\w+)', 'is'), # general (r"(\w+) is (\w+)", "is"), # general
] ]
for pattern, rel_type in patterns: for pattern, rel_type in patterns:
matches = re.findall(pattern, text, re.IGNORECASE) matches = re.findall(pattern, text, re.IGNORECASE)
for match in matches: for match in matches:
if len(match) == 3 and match[1].lower() in ['a', 'an']: if len(match) == 3 and match[1].lower() in ["a", "an"]:
from_e, _, to_e = match from_e, _, to_e = match
elif len(match) == 2: elif len(match) == 2:
from_e, to_e = match from_e, to_e = match
@ -280,9 +326,10 @@ class GraphMemory:
if from_e in entities and to_e in entities: if from_e in entities and to_e in entities:
self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)]) self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)])
elif from_e in entities: elif from_e in entities:
self.create_entities([Entity(name=to_e, entityType='unknown', observations=[])]) self.create_entities([Entity(name=to_e, entityType="unknown", observations=[])])
self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)]) self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)])
elif to_e in entities: elif to_e in entities:
self.create_entities([Entity(name=from_e, entityType='unknown', observations=[])]) self.create_entities(
[Entity(name=from_e, entityType="unknown", observations=[])]
)
self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)]) self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)])

View File

@ -20,7 +20,7 @@ class KnowledgeEntry:
importance_score: float = 1.0 importance_score: float = 1.0
def __str__(self): def __str__(self):
return json.dumps(self.to_dict(), indent=4, sort_keys=True,default=str) return json.dumps(self.to_dict(), indent=4, sort_keys=True, default=str)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {

View File

@ -1,4 +1,3 @@
import os
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional

View File

@ -4,10 +4,10 @@ from rp.ui.colors import Colors
def display_tool_call(tool_name, arguments, status="running", result=None): def display_tool_call(tool_name, arguments, status="running", result=None):
if status == "running": if status == "running":
return return
args_str = ", ".join([f"{k}={str(v)[:20]}" for k, v in list(arguments.items())[:2]]) args_str = ", ".join([f"{k}={repr(v)}" for k, v in arguments.items()])
line = f"{tool_name}({args_str})" line = f"{tool_name}({args_str})"
if len(line) > 80: if len(line) > 120:
line = line[:77] + "..." line = line[:117] + "..."
print(f"{Colors.GRAY}{line}{Colors.RESET}") print(f"{Colors.GRAY}{line}{Colors.RESET}")

View File

@ -114,10 +114,6 @@ class TestAssistant(unittest.TestCase):
process_message(assistant, "test message") process_message(assistant, "test message")
from rp.memory import KnowledgeEntry from rp.memory import KnowledgeEntry
import json
import time
import uuid
from unittest.mock import ANY
# Mock time.time() and uuid.uuid4() to return consistent values # Mock time.time() and uuid.uuid4() to return consistent values
expected_entry = KnowledgeEntry( expected_entry = KnowledgeEntry(
@ -132,7 +128,7 @@ class TestAssistant(unittest.TestCase):
created_at=1234567890.123456, created_at=1234567890.123456,
updated_at=1234567890.123456, updated_at=1234567890.123456,
) )
expected_content = str(expected_entry) str(expected_entry)
assistant.knowledge_store.add_entry.assert_called_once_with(expected_entry) assistant.knowledge_store.add_entry.assert_called_once_with(expected_entry)

View File

@ -696,4 +696,3 @@ class TestShowBackgroundEvents:
def test_show_events_exception(self, mock_get): def test_show_events_exception(self, mock_get):
mock_get.side_effect = Exception("test") mock_get.side_effect = Exception("test")
show_background_events(self.assistant) show_background_events(self.assistant)

View File

@ -1,4 +1,3 @@
import pytest
import tempfile import tempfile
import os import os
from rp.workflows.workflow_definition import ExecutionMode, Workflow, WorkflowStep from rp.workflows.workflow_definition import ExecutionMode, Workflow, WorkflowStep
@ -16,7 +15,7 @@ class TestWorkflowStep:
on_success=["step2"], on_success=["step2"],
on_failure=["step3"], on_failure=["step3"],
retry_count=2, retry_count=2,
timeout_seconds=600 timeout_seconds=600,
) )
assert step.tool_name == "test_tool" assert step.tool_name == "test_tool"
assert step.arguments == {"arg1": "value1"} assert step.arguments == {"arg1": "value1"}
@ -28,11 +27,7 @@ class TestWorkflowStep:
assert step.timeout_seconds == 600 assert step.timeout_seconds == 600
def test_to_dict(self): def test_to_dict(self):
step = WorkflowStep( step = WorkflowStep(tool_name="test_tool", arguments={"arg1": "value1"}, step_id="step1")
tool_name="test_tool",
arguments={"arg1": "value1"},
step_id="step1"
)
expected = { expected = {
"tool_name": "test_tool", "tool_name": "test_tool",
"arguments": {"arg1": "value1"}, "arguments": {"arg1": "value1"},
@ -77,7 +72,7 @@ class TestWorkflow:
steps=[step1, step2], steps=[step1, step2],
execution_mode=ExecutionMode.PARALLEL, execution_mode=ExecutionMode.PARALLEL,
variables={"var1": "value1"}, variables={"var1": "value1"},
tags=["tag1", "tag2"] tags=["tag1", "tag2"],
) )
assert workflow.name == "test_workflow" assert workflow.name == "test_workflow"
assert workflow.description == "A test workflow" assert workflow.description == "A test workflow"
@ -94,7 +89,7 @@ class TestWorkflow:
steps=[step1], steps=[step1],
execution_mode=ExecutionMode.SEQUENTIAL, execution_mode=ExecutionMode.SEQUENTIAL,
variables={"var1": "value1"}, variables={"var1": "value1"},
tags=["tag1"] tags=["tag1"],
) )
expected = { expected = {
"name": "test_workflow", "name": "test_workflow",
@ -110,7 +105,8 @@ class TestWorkflow:
data = { data = {
"name": "test_workflow", "name": "test_workflow",
"description": "A test workflow", "description": "A test workflow",
"steps": [{ "steps": [
{
"tool_name": "tool1", "tool_name": "tool1",
"arguments": {}, "arguments": {},
"step_id": "step1", "step_id": "step1",
@ -119,7 +115,8 @@ class TestWorkflow:
"on_failure": None, "on_failure": None,
"retry_count": 0, "retry_count": 0,
"timeout_seconds": 300, "timeout_seconds": 300,
}], }
],
"execution_mode": "parallel", "execution_mode": "parallel",
"variables": {"var1": "value1"}, "variables": {"var1": "value1"},
"tags": ["tag1"], "tags": ["tag1"],
@ -134,11 +131,7 @@ class TestWorkflow:
assert workflow.tags == ["tag1"] assert workflow.tags == ["tag1"]
def test_add_step(self): def test_add_step(self):
workflow = Workflow( workflow = Workflow(name="test_workflow", description="A test workflow", steps=[])
name="test_workflow",
description="A test workflow",
steps=[]
)
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow.add_step(step) workflow.add_step(step)
assert workflow.steps == [step] assert workflow.steps == [step]
@ -147,9 +140,7 @@ class TestWorkflow:
step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2") step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
workflow = Workflow( workflow = Workflow(
name="test_workflow", name="test_workflow", description="A test workflow", steps=[step1, step2]
description="A test workflow",
steps=[step1, step2]
) )
assert workflow.get_step("step1") == step1 assert workflow.get_step("step1") == step1
assert workflow.get_step("step2") == step2 assert workflow.get_step("step2") == step2
@ -162,7 +153,7 @@ class TestWorkflow:
name="test_workflow", name="test_workflow",
description="A test workflow", description="A test workflow",
steps=[step1, step2], steps=[step1, step2],
execution_mode=ExecutionMode.SEQUENTIAL execution_mode=ExecutionMode.SEQUENTIAL,
) )
assert workflow.get_initial_steps() == [step1] assert workflow.get_initial_steps() == [step1]
@ -173,7 +164,7 @@ class TestWorkflow:
name="test_workflow", name="test_workflow",
description="A test workflow", description="A test workflow",
steps=[step1, step2], steps=[step1, step2],
execution_mode=ExecutionMode.PARALLEL execution_mode=ExecutionMode.PARALLEL,
) )
assert workflow.get_initial_steps() == [step1, step2] assert workflow.get_initial_steps() == [step1, step2]
@ -184,7 +175,7 @@ class TestWorkflow:
name="test_workflow", name="test_workflow",
description="A test workflow", description="A test workflow",
steps=[step1, step2], steps=[step1, step2],
execution_mode=ExecutionMode.CONDITIONAL execution_mode=ExecutionMode.CONDITIONAL,
) )
assert workflow.get_initial_steps() == [step2] # Only step without condition assert workflow.get_initial_steps() == [step2] # Only step without condition
@ -200,11 +191,7 @@ class TestWorkflowStorage:
def test_save_and_load_workflow(self): def test_save_and_load_workflow(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow( workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
name="test_workflow",
description="A test workflow",
steps=[step]
)
workflow_id = self.storage.save_workflow(workflow) workflow_id = self.storage.save_workflow(workflow)
loaded = self.storage.load_workflow(workflow_id) loaded = self.storage.load_workflow(workflow_id)
assert loaded is not None assert loaded is not None
@ -219,11 +206,7 @@ class TestWorkflowStorage:
def test_load_workflow_by_name(self): def test_load_workflow_by_name(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow( workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
name="test_workflow",
description="A test workflow",
steps=[step]
)
self.storage.save_workflow(workflow) self.storage.save_workflow(workflow)
loaded = self.storage.load_workflow_by_name("test_workflow") loaded = self.storage.load_workflow_by_name("test_workflow")
assert loaded is not None assert loaded is not None
@ -236,10 +219,7 @@ class TestWorkflowStorage:
def test_list_workflows(self): def test_list_workflows(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow( workflow = Workflow(
name="test_workflow", name="test_workflow", description="A test workflow", steps=[step], tags=["tag1"]
description="A test workflow",
steps=[step],
tags=["tag1"]
) )
self.storage.save_workflow(workflow) self.storage.save_workflow(workflow)
workflows = self.storage.list_workflows() workflows = self.storage.list_workflows()
@ -250,16 +230,10 @@ class TestWorkflowStorage:
def test_list_workflows_with_tag(self): def test_list_workflows_with_tag(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow1 = Workflow( workflow1 = Workflow(
name="test_workflow1", name="test_workflow1", description="A test workflow", steps=[step], tags=["tag1"]
description="A test workflow",
steps=[step],
tags=["tag1"]
) )
workflow2 = Workflow( workflow2 = Workflow(
name="test_workflow2", name="test_workflow2", description="A test workflow", steps=[step], tags=["tag2"]
description="A test workflow",
steps=[step],
tags=["tag2"]
) )
self.storage.save_workflow(workflow1) self.storage.save_workflow(workflow1)
self.storage.save_workflow(workflow2) self.storage.save_workflow(workflow2)
@ -269,11 +243,7 @@ class TestWorkflowStorage:
def test_delete_workflow(self): def test_delete_workflow(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow( workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
name="test_workflow",
description="A test workflow",
steps=[step]
)
workflow_id = self.storage.save_workflow(workflow) workflow_id = self.storage.save_workflow(workflow)
deleted = self.storage.delete_workflow(workflow_id) deleted = self.storage.delete_workflow(workflow_id)
assert deleted is True assert deleted is True
@ -286,11 +256,7 @@ class TestWorkflowStorage:
def test_save_execution(self): def test_save_execution(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow( workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
name="test_workflow",
description="A test workflow",
steps=[step]
)
workflow_id = self.storage.save_workflow(workflow) workflow_id = self.storage.save_workflow(workflow)
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
context.set_step_result("step1", "result") context.set_step_result("step1", "result")
@ -303,11 +269,7 @@ class TestWorkflowStorage:
def test_get_execution_history(self): def test_get_execution_history(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow( workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
name="test_workflow",
description="A test workflow",
steps=[step]
)
workflow_id = self.storage.save_workflow(workflow) workflow_id = self.storage.save_workflow(workflow)
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
context.set_step_result("step1", "result") context.set_step_result("step1", "result")
@ -318,11 +280,7 @@ class TestWorkflowStorage:
def test_get_execution_history_limit(self): def test_get_execution_history_limit(self):
step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
workflow = Workflow( workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
name="test_workflow",
description="A test workflow",
steps=[step]
)
workflow_id = self.storage.save_workflow(workflow) workflow_id = self.storage.save_workflow(workflow)
for i in range(5): for i in range(5):
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
@ -367,6 +325,7 @@ class TestWorkflowEngine:
def test_init(self): def test_init(self):
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
return f"executed {tool_name} with {args}" return f"executed {tool_name} with {args}"
engine = WorkflowEngine(tool_executor, max_workers=10) engine = WorkflowEngine(tool_executor, max_workers=10)
assert engine.tool_executor == tool_executor assert engine.tool_executor == tool_executor
assert engine.max_workers == 10 assert engine.max_workers == 10
@ -374,6 +333,7 @@ class TestWorkflowEngine:
def test_evaluate_condition_true(self): def test_evaluate_condition_true(self):
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
return "result" return "result"
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
assert engine._evaluate_condition("True", context) is True assert engine._evaluate_condition("True", context) is True
@ -381,6 +341,7 @@ class TestWorkflowEngine:
def test_evaluate_condition_false(self): def test_evaluate_condition_false(self):
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
return "result" return "result"
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
assert engine._evaluate_condition("False", context) is False assert engine._evaluate_condition("False", context) is False
@ -388,6 +349,7 @@ class TestWorkflowEngine:
def test_evaluate_condition_with_variables(self): def test_evaluate_condition_with_variables(self):
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
return "result" return "result"
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
context.set_variable("test_var", "test_value") context.set_variable("test_var", "test_value")
@ -396,15 +358,12 @@ class TestWorkflowEngine:
def test_substitute_variables(self): def test_substitute_variables(self):
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
return "result" return "result"
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
context.set_variable("var1", "value1") context.set_variable("var1", "value1")
context.set_step_result("step1", "result1") context.set_step_result("step1", "result1")
arguments = { arguments = {"arg1": "${var.var1}", "arg2": "${step.step1}", "arg3": "plain_value"}
"arg1": "${var.var1}",
"arg2": "${step.step1}",
"arg3": "plain_value"
}
substituted = engine._substitute_variables(arguments, context) substituted = engine._substitute_variables(arguments, context)
assert substituted["arg1"] == "value1" assert substituted["arg1"] == "value1"
assert substituted["arg2"] == "result1" assert substituted["arg2"] == "result1"
@ -412,16 +371,14 @@ class TestWorkflowEngine:
def test_execute_step_success(self): def test_execute_step_success(self):
executed = [] executed = []
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
executed.append((tool_name, args)) executed.append((tool_name, args))
return "success_result" return "success_result"
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
step = WorkflowStep( step = WorkflowStep(tool_name="test_tool", arguments={"arg": "value"}, step_id="step1")
tool_name="test_tool",
arguments={"arg": "value"},
step_id="step1"
)
result = engine._execute_step(step, context) result = engine._execute_step(step, context)
assert result["status"] == "success" assert result["status"] == "success"
assert result["step_id"] == "step1" assert result["step_id"] == "step1"
@ -431,16 +388,15 @@ class TestWorkflowEngine:
def test_execute_step_skipped(self): def test_execute_step_skipped(self):
executed = [] executed = []
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
executed.append((tool_name, args)) executed.append((tool_name, args))
return "result" return "result"
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
step = WorkflowStep( step = WorkflowStep(
tool_name="test_tool", tool_name="test_tool", arguments={"arg": "value"}, step_id="step1", condition="False"
arguments={"arg": "value"},
step_id="step1",
condition="False"
) )
result = engine._execute_step(step, context) result = engine._execute_step(step, context)
assert result["status"] == "skipped" assert result["status"] == "skipped"
@ -449,18 +405,17 @@ class TestWorkflowEngine:
def test_execute_step_failed_with_retry(self): def test_execute_step_failed_with_retry(self):
executed = [] executed = []
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
executed.append((tool_name, args)) executed.append((tool_name, args))
if len(executed) < 2: if len(executed) < 2:
raise Exception("Temporary failure") raise Exception("Temporary failure")
return "success_result" return "success_result"
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
step = WorkflowStep( step = WorkflowStep(
tool_name="test_tool", tool_name="test_tool", arguments={"arg": "value"}, step_id="step1", retry_count=1
arguments={"arg": "value"},
step_id="step1",
retry_count=1
) )
result = engine._execute_step(step, context) result = engine._execute_step(step, context)
assert result["status"] == "success" assert result["status"] == "success"
@ -469,16 +424,15 @@ class TestWorkflowEngine:
def test_execute_step_failed(self): def test_execute_step_failed(self):
executed = [] executed = []
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
executed.append((tool_name, args)) executed.append((tool_name, args))
raise Exception("Permanent failure") raise Exception("Permanent failure")
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
step = WorkflowStep( step = WorkflowStep(
tool_name="test_tool", tool_name="test_tool", arguments={"arg": "value"}, step_id="step1", retry_count=1
arguments={"arg": "value"},
step_id="step1",
retry_count=1
) )
result = engine._execute_step(step, context) result = engine._execute_step(step, context)
assert result["status"] == "failed" assert result["status"] == "failed"
@ -488,6 +442,7 @@ class TestWorkflowEngine:
def test_get_next_steps_sequential(self): def test_get_next_steps_sequential(self):
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
return "result" return "result"
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2") step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
@ -495,7 +450,7 @@ class TestWorkflowEngine:
name="test", name="test",
description="test", description="test",
steps=[step1, step2], steps=[step1, step2],
execution_mode=ExecutionMode.SEQUENTIAL execution_mode=ExecutionMode.SEQUENTIAL,
) )
result = {"status": "success"} result = {"status": "success"}
next_steps = engine._get_next_steps(step1, result, workflow) next_steps = engine._get_next_steps(step1, result, workflow)
@ -504,28 +459,22 @@ class TestWorkflowEngine:
def test_get_next_steps_on_success(self): def test_get_next_steps_on_success(self):
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
return "result" return "result"
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
step1 = WorkflowStep( step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1", on_success=["step2"])
tool_name="tool1",
arguments={},
step_id="step1",
on_success=["step2"]
)
step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2") step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
workflow = Workflow( workflow = Workflow(name="test", description="test", steps=[step1, step2])
name="test",
description="test",
steps=[step1, step2]
)
result = {"status": "success"} result = {"status": "success"}
next_steps = engine._get_next_steps(step1, result, workflow) next_steps = engine._get_next_steps(step1, result, workflow)
assert next_steps == [step2] assert next_steps == [step2]
def test_execute_workflow_sequential(self): def test_execute_workflow_sequential(self):
executed = [] executed = []
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
executed.append(tool_name) executed.append(tool_name)
return f"result_{tool_name}" return f"result_{tool_name}"
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2") step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
@ -533,7 +482,7 @@ class TestWorkflowEngine:
name="test", name="test",
description="test", description="test",
steps=[step1, step2], steps=[step1, step2],
execution_mode=ExecutionMode.SEQUENTIAL execution_mode=ExecutionMode.SEQUENTIAL,
) )
context = engine.execute_workflow(workflow) context = engine.execute_workflow(workflow)
assert executed == ["tool1", "tool2"] assert executed == ["tool1", "tool2"]
@ -542,9 +491,11 @@ class TestWorkflowEngine:
def test_execute_workflow_parallel(self): def test_execute_workflow_parallel(self):
executed = [] executed = []
def tool_executor(tool_name, args): def tool_executor(tool_name, args):
executed.append(tool_name) executed.append(tool_name)
return f"result_{tool_name}" return f"result_{tool_name}"
engine = WorkflowEngine(tool_executor) engine = WorkflowEngine(tool_executor)
step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1") step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2") step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
@ -552,7 +503,7 @@ class TestWorkflowEngine:
name="test", name="test",
description="test", description="test",
steps=[step1, step2], steps=[step1, step2],
execution_mode=ExecutionMode.PARALLEL execution_mode=ExecutionMode.PARALLEL,
) )
context = engine.execute_workflow(workflow) context = engine.execute_workflow(workflow)
assert set(executed) == {"tool1", "tool2"} assert set(executed) == {"tool1", "tool2"}