feat: update reasoning and task completion markers
feat: enable autonomous mode by default
refactor: improve assistant class structure
maintenance: update pyproject.toml version to 1.53.0
fix: handle edge cases in autonomous mode content extraction
docs: clarify autonomous mode deprecation in command line arguments
This commit is contained in:
parent 63c2f52885
commit 20668d9086

CHANGELOG.md | 47
@@ -1,6 +1,53 @@
 # Changelog
 
+## Version 1.52.0 - 2025-11-10
+
+This release updates the project version to 1.52.0. No new features or changes are introduced for users or developers.
+
+**Changes:** 1 file, 2 lines
+**Languages:** TOML (2 lines)
+
+## Version 1.51.0 - 2025-11-10
+
+The system can now extract and clean reasoning steps during task completion. Autonomous mode has been updated to recognize these reasoning steps and task completion markers, improving overall performance.
+
+**Changes:** 5 files, 65 lines
+**Languages:** Python (63 lines), TOML (2 lines)
+
+## Version 1.50.0 - 2025-11-09
+
+### Added
+- **LLM Reasoning Display**: The assistant now displays its reasoning process before each response
+  - Added `REASONING:` prefix instruction in system prompt
+  - Reasoning is extracted and displayed with a blue thought bubble icon
+  - Provides transparency into the assistant's decision-making process
+
+- **Task Completion Marker**: Implemented `[TASK_COMPLETE]` marker for explicit task completion signaling
+  - LLM can now mark tasks as complete with a special marker
+  - Marker is stripped from user-facing output
+  - Autonomous mode detection recognizes the marker for faster completion
+  - Reduces unnecessary iterations when tasks are finished
+
+### Changed
+- Updated system prompt in `context.py` to include response format instructions
+- Enhanced `process_response_autonomous()` to extract and display reasoning
+- Modified `is_task_complete()` to recognize `[TASK_COMPLETE]` marker
+- Both autonomous and regular modes now support reasoning display
+
+**Changes:** 3 files, 52 lines
+**Languages:** Python (52 lines)
+
+## Version 1.49.0 - 2025-11-09
+
+Autonomous mode is now enabled by default, improving performance. Identical messages are now removed in autonomous mode to prevent redundancy.
+
+**Changes:** 3 files, 28 lines
+**Languages:** Markdown (18 lines), Python (8 lines), TOML (2 lines)
+
 ## Version 1.48.1 - 2025-11-09
 
 ### Fixed
 
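The 1.49.0 entry above says identical messages are removed in autonomous mode to prevent redundancy, but that code is not part of this commit. The following is only a minimal sketch of the idea, assuming OpenAI-style message dicts; the helper name `dedupe_messages` is hypothetical and the real implementation in rp may differ.

```python
def dedupe_messages(messages):
    """Drop messages whose (role, content) exactly repeats an earlier one."""
    seen = set()
    deduped = []
    for msg in messages:
        key = (msg.get("role"), msg.get("content"))
        if key in seen:
            continue  # identical message already present, skip it
        seen.add(key)
        deduped.append(msg)
    return deduped


# Example: the repeated user message is dropped.
history = [
    {"role": "user", "content": "list files"},
    {"role": "assistant", "content": "done"},
    {"role": "user", "content": "list files"},
]
assert len(dedupe_messages(history)) == 2
```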
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "rp"
-version = "1.51.0"
+version = "1.52.0"
 description = "R python edition. The ultimate autonomous AI CLI."
 readme = "README.md"
 requires-python = ">=3.10"
 
@@ -45,7 +45,12 @@ Commands in interactive mode:
     parser.add_argument("-u", "--api-url", help="API endpoint URL")
     parser.add_argument("--model-list-url", help="Model list endpoint URL")
     parser.add_argument("-i", "--interactive", action="store_true", help="Interactive mode")
-    parser.add_argument("-a", "--autonomous", action="store_true", help="Autonomous mode (now default, this flag is deprecated)")
+    parser.add_argument(
+        "-a",
+        "--autonomous",
+        action="store_true",
+        help="Autonomous mode (now default, this flag is deprecated)",
+    )
     parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
     parser.add_argument(
         "--debug", action="store_true", help="Enable debug mode with detailed logging"
 
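Since autonomous mode is now the default, the `-a`/`--autonomous` flag above survives only for backward compatibility. This commit does not show how (or whether) the flag is warned about at runtime; the snippet below is a hedged sketch of one common pattern, not code from this repository.

```python
import argparse
import warnings

parser = argparse.ArgumentParser()
parser.add_argument(
    "-a",
    "--autonomous",
    action="store_true",
    help="Autonomous mode (now default, this flag is deprecated)",
)
args = parser.parse_args(["-a"])

if args.autonomous:
    # The flag no longer changes behavior; warn so callers can drop it.
    warnings.warn(
        "-a/--autonomous is deprecated; autonomous mode is the default",
        DeprecationWarning,
    )
```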
@@ -21,17 +21,17 @@ def extract_reasoning_and_clean_content(content):
         tuple: (reasoning, cleaned_content)
     """
     reasoning = None
-    lines = content.split('\n')
+    lines = content.split("\n")
     cleaned_lines = []
 
     for line in lines:
-        if line.strip().startswith('REASONING:'):
+        if line.strip().startswith("REASONING:"):
             reasoning = line.strip()[10:].strip()
         else:
             cleaned_lines.append(line)
 
-    cleaned_content = '\n'.join(cleaned_lines)
-    cleaned_content = cleaned_content.replace('[TASK_COMPLETE]', '').strip()
+    cleaned_content = "\n".join(cleaned_lines)
+    cleaned_content = cleaned_content.replace("[TASK_COMPLETE]", "").strip()
 
     return reasoning, cleaned_content
 
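Given the function above, a response that follows the v1.50.0 format (a `REASONING:` line plus a trailing `[TASK_COMPLETE]` marker) splits as shown below. The sample response text is illustrative only; the function and its behavior come straight from the hunk above.

```python
from rp.autonomous.mode import extract_reasoning_and_clean_content

response = (
    "REASONING: The file list is already shown, so the task is done.\n"
    "All files have been listed.\n"
    "[TASK_COMPLETE]"
)
reasoning, cleaned = extract_reasoning_and_clean_content(response)

# The REASONING: prefix is stripped and the marker removed from user output.
assert reasoning == "The file list is already shown, so the task is done."
assert cleaned == "All files have been listed."
```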
@@ -128,7 +128,11 @@ def process_response_autonomous(assistant, response):
             display_tool_call(func_name, arguments, status, result)
             sanitized_result = sanitize_for_json(result)
             tool_results.append(
-                {"tool_call_id": tool_call["id"], "role": "tool", "content": json.dumps(sanitized_result)}
+                {
+                    "tool_call_id": tool_call["id"],
+                    "role": "tool",
+                    "content": json.dumps(sanitized_result),
+                }
             )
         for result in tool_results:
             assistant.messages.append(result)
 
@@ -2,4 +2,11 @@ from rp.core.api import call_api, list_models
 from rp.core.assistant import Assistant
 from rp.core.context import init_system_message, manage_context_window, get_context_content
 
-__all__ = ["Assistant", "call_api", "list_models", "init_system_message", "manage_context_window", "get_context_content"]
+__all__ = [
+    "Assistant",
+    "call_api",
+    "list_models",
+    "init_system_message",
+    "manage_context_window",
+    "get_context_content",
+]
 
@@ -6,9 +6,7 @@ import readline
 import signal
 import sqlite3
 import sys
 import time
 import traceback
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 
 from rp.commands import handle_command
 
@@ -68,7 +66,7 @@ from rp.tools.memory import (
 from rp.tools.patch import apply_patch, create_diff, display_file_diff
 from rp.tools.python_exec import python_exec
 from rp.tools.web import http_fetch, web_search, web_search_news
-from rp.ui import Colors, Spinner, render_markdown
+from rp.ui import Colors, render_markdown
 from rp.ui.progress import ProgressIndicator
 
 logger = logging.getLogger("rp")
 
@@ -112,6 +110,7 @@ class Assistant:
         self.last_result = None
         self.init_database()
         from rp.memory import KnowledgeStore, FactExtractor, GraphMemory
+
         self.knowledge_store = KnowledgeStore(DB_PATH, db_conn=self.db_conn)
         self.fact_extractor = FactExtractor()
         self.graph_memory = GraphMemory(DB_PATH, db_conn=self.db_conn)
 
@@ -127,7 +126,10 @@ class Assistant:
         self.enhanced = None
 
         from rp.config import BACKGROUND_MONITOR_ENABLED
-        bg_enabled = os.environ.get("BACKGROUND_MONITOR", str(BACKGROUND_MONITOR_ENABLED)).lower() in ("1", "true", "yes")
+
+        bg_enabled = os.environ.get(
+            "BACKGROUND_MONITOR", str(BACKGROUND_MONITOR_ENABLED)
+        ).lower() in ("1", "true", "yes")
 
         if bg_enabled:
             try:
 
@@ -338,6 +340,7 @@ class Assistant:
         content = message.get("content", "")
 
         from rp.autonomous.mode import extract_reasoning_and_clean_content
+
         reasoning, cleaned_content = extract_reasoning_and_clean_content(content)
 
         if reasoning:
 
@@ -438,6 +441,7 @@ class Assistant:
             # If cmd_result is None, it's not a special command, process with autonomous mode.
             elif cmd_result is None:
                 from rp.autonomous import run_autonomous_mode
+
                 run_autonomous_mode(self, user_input)
         except EOFError:
             break
 
@@ -453,6 +457,7 @@ class Assistant:
         else:
             message = sys.stdin.read()
         from rp.autonomous import run_autonomous_mode
+
         run_autonomous_mode(self, message)
 
     def run_autonomous(self):
 
@@ -70,6 +70,7 @@ def get_context_content():
         logging.error(f"Error reading context file {knowledge_file}: {e}")
     return "\n\n".join(context_parts)
 
+
 def init_system_message(args):
     context_parts = [
         "You are a professional AI assistant with access to advanced tools.",
 
@@ -17,7 +17,9 @@ def inject_knowledge_context(assistant, user_message):
                 break
         try:
             # Run all search methods
-            knowledge_results = assistant.enhanced.knowledge_store.search_entries(user_message, top_k=5)  # Hybrid semantic + keyword + category
+            knowledge_results = assistant.enhanced.knowledge_store.search_entries(
+                user_message, top_k=5
+            )  # Hybrid semantic + keyword + category
             # Additional keyword search if needed (but already in hybrid)
             # Category-specific: preferences and general
             pref_results = assistant.enhanced.knowledge_store.get_by_category("preferences", limit=5)
 
@@ -25,12 +27,14 @@ def inject_knowledge_context(assistant, user_message):
             category_results = []
             for entry in pref_results + general_results:
                 if any(word in entry.content.lower() for word in user_message.lower().split()):
-                    category_results.append({
-                        "content": entry.content,
-                        "score": 0.6,
-                        "source": f"Knowledge Base ({entry.category})",
-                        "type": "knowledge_category",
-                    })
+                    category_results.append(
+                        {
+                            "content": entry.content,
+                            "score": 0.6,
+                            "source": f"Knowledge Base ({entry.category})",
+                            "type": "knowledge_category",
+                        }
+                    )
 
             conversation_results = []
             if hasattr(assistant.enhanced, "conversation_memory"):
 
@@ -4,87 +4,107 @@ import sqlite3
 from typing import List, Optional, Set
 from dataclasses import dataclass, field
 
 
 @dataclass
 class Entity:
     name: str
     entityType: str
     observations: List[str]
 
 
 @dataclass
 class Relation:
-    from_: str = field(metadata={'alias': 'from'})
+    from_: str = field(metadata={"alias": "from"})
     to: str
     relationType: str
 
 
 @dataclass
 class KnowledgeGraph:
     entities: List[Entity]
     relations: List[Relation]
 
 
 @dataclass
 class CreateEntitiesRequest:
     entities: List[Entity]
 
 
 @dataclass
 class CreateRelationsRequest:
     relations: List[Relation]
 
 
 @dataclass
 class ObservationItem:
     entityName: str
     contents: List[str]
 
 
 @dataclass
 class AddObservationsRequest:
     observations: List[ObservationItem]
 
 
 @dataclass
 class DeletionItem:
     entityName: str
     observations: List[str]
 
 
 @dataclass
 class DeleteObservationsRequest:
     deletions: List[DeletionItem]
 
 
 @dataclass
 class DeleteEntitiesRequest:
     entityNames: List[str]
 
 
 @dataclass
 class DeleteRelationsRequest:
     relations: List[Relation]
 
 
 @dataclass
 class SearchNodesRequest:
     query: str
 
 
 @dataclass
 class OpenNodesRequest:
     names: List[str]
     depth: int = 1
 
 
 @dataclass
 class PopulateRequest:
     text: str
 
 
 class GraphMemory:
-    def __init__(self, db_path: str = 'graph_memory.db', db_conn: Optional[sqlite3.Connection] = None):
+    def __init__(
+        self, db_path: str = "graph_memory.db", db_conn: Optional[sqlite3.Connection] = None
+    ):
         self.db_path = db_path
         self.conn = db_conn if db_conn else sqlite3.connect(self.db_path, check_same_thread=False)
         self.init_db()
 
     def init_db(self):
         cursor = self.conn.cursor()
-        cursor.execute('''
+        cursor.execute(
+            """
             CREATE TABLE IF NOT EXISTS entities (
                 id INTEGER PRIMARY KEY,
                 name TEXT UNIQUE,
                 entity_type TEXT,
                 observations TEXT
             )
-        ''')
-        cursor.execute('''
+        """
+        )
+        cursor.execute(
+            """
             CREATE TABLE IF NOT EXISTS relations (
                 id INTEGER PRIMARY KEY,
                 from_entity TEXT,
 
@@ -92,7 +112,8 @@ class GraphMemory:
                 relation_type TEXT,
                 UNIQUE(from_entity, to_entity, relation_type)
             )
-        ''')
+        """
+        )
         self.conn.commit()
 
     def create_entities(self, entities: List[Entity]) -> List[Entity]:
 
@@ -101,8 +122,10 @@ class GraphMemory:
         cursor = conn.cursor()
         for e in entities:
             try:
-                cursor.execute('INSERT INTO entities (name, entity_type, observations) VALUES (?, ?, ?)',
-                               (e.name, e.entityType, json.dumps(e.observations)))
+                cursor.execute(
+                    "INSERT INTO entities (name, entity_type, observations) VALUES (?, ?, ?)",
+                    (e.name, e.entityType, json.dumps(e.observations)),
+                )
                 new_entities.append(e)
             except sqlite3.IntegrityError:
                 pass  # already exists
 
@@ -115,8 +138,10 @@ class GraphMemory:
         cursor = conn.cursor()
         for r in relations:
             try:
-                cursor.execute('INSERT INTO relations (from_entity, to_entity, relation_type) VALUES (?, ?, ?)',
-                               (r.from_, r.to, r.relationType))
+                cursor.execute(
+                    "INSERT INTO relations (from_entity, to_entity, relation_type) VALUES (?, ?, ?)",
+                    (r.from_, r.to, r.relationType),
+                )
                 new_relations.append(r)
             except sqlite3.IntegrityError:
                 pass  # already exists
 
@@ -130,16 +155,19 @@ class GraphMemory:
         for obs in observations:
             name = obs.entityName.lower()
             contents = obs.contents
-            cursor.execute('SELECT observations FROM entities WHERE LOWER(name) = ?', (name,))
+            cursor.execute("SELECT observations FROM entities WHERE LOWER(name) = ?", (name,))
             row = cursor.fetchone()
             if not row:
                 # Log the error instead of raising an exception
                 print(f"Error: Entity {name} not found when adding observations.")
-                return [] # Return an empty list or appropriate failure indicator
+                return []  # Return an empty list or appropriate failure indicator
             current_obs = json.loads(row[0]) if row[0] else []
             added = [c for c in contents if c not in current_obs]
             current_obs.extend(added)
-            cursor.execute('UPDATE entities SET observations = ? WHERE LOWER(name) = ?', (json.dumps(current_obs), name))
+            cursor.execute(
+                "UPDATE entities SET observations = ? WHERE LOWER(name) = ?",
+                (json.dumps(current_obs), name),
+            )
             results.append({"entityName": name, "addedObservations": added})
         conn.commit()
         return results
 
@@ -148,11 +176,16 @@ class GraphMemory:
         conn = self.conn
         cursor = conn.cursor()
         # delete entities
-        cursor.executemany('DELETE FROM entities WHERE LOWER(name) = ?', [(n.lower(),) for n in entity_names])
+        cursor.executemany(
+            "DELETE FROM entities WHERE LOWER(name) = ?", [(n.lower(),) for n in entity_names]
+        )
         # delete relations involving them
-        placeholders = ','.join('?' * len(entity_names))
+        placeholders = ",".join("?" * len(entity_names))
         params = [n.lower() for n in entity_names] * 2
-        cursor.execute(f'DELETE FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})', params)
+        cursor.execute(
+            f"DELETE FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})",
+            params,
+        )
         conn.commit()
 
     def delete_observations(self, deletions: List[DeletionItem]):
 
@@ -161,20 +194,25 @@ class GraphMemory:
         for del_item in deletions:
             name = del_item.entityName.lower()
             to_delete = del_item.observations
-            cursor.execute('SELECT observations FROM entities WHERE LOWER(name) = ?', (name,))
+            cursor.execute("SELECT observations FROM entities WHERE LOWER(name) = ?", (name,))
             row = cursor.fetchone()
             if row:
                 current_obs = json.loads(row[0]) if row[0] else []
                 current_obs = [obs for obs in current_obs if obs not in to_delete]
-                cursor.execute('UPDATE entities SET observations = ? WHERE LOWER(name) = ?', (json.dumps(current_obs), name))
+                cursor.execute(
+                    "UPDATE entities SET observations = ? WHERE LOWER(name) = ?",
+                    (json.dumps(current_obs), name),
+                )
         conn.commit()
 
     def delete_relations(self, relations: List[Relation]):
         conn = self.conn
         cursor = conn.cursor()
         for r in relations:
-            cursor.execute('DELETE FROM relations WHERE LOWER(from_entity) = ? AND LOWER(to_entity) = ? AND LOWER(relation_type) = ?',
-                           (r.from_.lower(), r.to.lower(), r.relationType.lower()))
+            cursor.execute(
+                "DELETE FROM relations WHERE LOWER(from_entity) = ? AND LOWER(to_entity) = ? AND LOWER(relation_type) = ?",
+                (r.from_.lower(), r.to.lower(), r.relationType.lower()),
+            )
         conn.commit()
 
     def read_graph(self) -> KnowledgeGraph:
 
@@ -182,12 +220,12 @@ class GraphMemory:
         relations = []
         conn = self.conn
         cursor = conn.cursor()
-        cursor.execute('SELECT name, entity_type, observations FROM entities')
+        cursor.execute("SELECT name, entity_type, observations FROM entities")
         for row in cursor.fetchall():
             name, etype, obs = row
             observations = json.loads(obs) if obs else []
             entities.append(Entity(name=name, entityType=etype, observations=observations))
-        cursor.execute('SELECT from_entity, to_entity, relation_type FROM relations')
+        cursor.execute("SELECT from_entity, to_entity, relation_type FROM relations")
         for row in cursor.fetchall():
             relations.append(Relation(from_=row[0], to=row[1], relationType=row[2]))
         return KnowledgeGraph(entities=entities, relations=relations)
 
@@ -197,17 +235,19 @@ class GraphMemory:
         conn = self.conn
         cursor = conn.cursor()
         query_lower = query.lower()
-        cursor.execute('SELECT name, entity_type, observations FROM entities')
+        cursor.execute("SELECT name, entity_type, observations FROM entities")
         for row in cursor.fetchall():
             name, etype, obs = row
             observations = json.loads(obs) if obs else []
-            if (query_lower in name.lower() or
-                query_lower in etype.lower() or
-                any(query_lower in o.lower() for o in observations)):
+            if (
+                query_lower in name.lower()
+                or query_lower in etype.lower()
+                or any(query_lower in o.lower() for o in observations)
+            ):
                 entities.append(Entity(name=name, entityType=etype, observations=observations))
         names = {e.name.lower() for e in entities}
         relations = []
-        cursor.execute('SELECT from_entity, to_entity, relation_type FROM relations')
+        cursor.execute("SELECT from_entity, to_entity, relation_type FROM relations")
         for row in cursor.fetchall():
             if row[0].lower() in names and row[1].lower() in names:
                 relations.append(Relation(from_=row[0], to=row[1], relationType=row[2]))
 
@@ -221,13 +261,16 @@ class GraphMemory:
         def traverse(current_names: List[str], current_depth: int):
             if current_depth > depth:
                 return
-            name_set = {n.lower() for n in current_names}
+            {n.lower() for n in current_names}
             new_entities = []
             conn = self.conn
             cursor = conn.cursor()
-            placeholders = ','.join('?' * len(current_names))
+            placeholders = ",".join("?" * len(current_names))
             params = [n.lower() for n in current_names]
-            cursor.execute(f'SELECT name, entity_type, observations FROM entities WHERE LOWER(name) IN ({placeholders})', params)
+            cursor.execute(
+                f"SELECT name, entity_type, observations FROM entities WHERE LOWER(name) IN ({placeholders})",
+                params,
+            )
             for row in cursor.fetchall():
                 name, etype, obs = row
                 if name.lower() not in visited:
 
@@ -237,9 +280,12 @@ class GraphMemory:
                     new_entities.append(entity)
                     entities.append(entity)
             # Find relations involving these entities
-            placeholders = ','.join('?' * len(new_entities))
+            placeholders = ",".join("?" * len(new_entities))
             params = [e.name.lower() for e in new_entities] * 2
-            cursor.execute(f'SELECT from_entity, to_entity, relation_type FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})', params)
+            cursor.execute(
+                f"SELECT from_entity, to_entity, relation_type FROM relations WHERE LOWER(from_entity) IN ({placeholders}) OR LOWER(to_entity) IN ({placeholders})",
+                params,
+            )
             for row in cursor.fetchall():
                 rel = Relation(from_=row[0], to=row[1], relationType=row[2])
                 if rel not in relations:
 
@@ -254,24 +300,24 @@ class GraphMemory:
 
     def populate_from_text(self, text: str):
         # Algorithm: Extract entities as capitalized words, relations from patterns, observations from sentences mentioning entities
-        entities = set(re.findall(r'\b[A-Z][a-zA-Z]*\b', text))
+        entities = set(re.findall(r"\b[A-Z][a-zA-Z]*\b", text))
         for entity in entities:
-            self.create_entities([Entity(name=entity, entityType='unknown', observations=[])])
+            self.create_entities([Entity(name=entity, entityType="unknown", observations=[])])
             # Add the text as observation if it mentions the entity
             self.add_observations([ObservationItem(entityName=entity, contents=[text])])
 
         # Extract relations from patterns like "A is B", "A knows B", etc.
         patterns = [
-            (r'(\w+) is (a|an) (\w+)', 'is_a'),
-            (r'(\w+) knows (\w+)', 'knows'),
-            (r'(\w+) works at (\w+)', 'works_at'),
-            (r'(\w+) lives in (\w+)', 'lives_in'),
-            (r'(\w+) is (\w+)', 'is'),  # general
+            (r"(\w+) is (a|an) (\w+)", "is_a"),
+            (r"(\w+) knows (\w+)", "knows"),
+            (r"(\w+) works at (\w+)", "works_at"),
+            (r"(\w+) lives in (\w+)", "lives_in"),
+            (r"(\w+) is (\w+)", "is"),  # general
         ]
         for pattern, rel_type in patterns:
             matches = re.findall(pattern, text, re.IGNORECASE)
             for match in matches:
-                if len(match) == 3 and match[1].lower() in ['a', 'an']:
+                if len(match) == 3 and match[1].lower() in ["a", "an"]:
                     from_e, _, to_e = match
                 elif len(match) == 2:
                     from_e, to_e = match
 
@@ -280,9 +326,10 @@ class GraphMemory:
                 if from_e in entities and to_e in entities:
                     self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)])
                 elif from_e in entities:
-                    self.create_entities([Entity(name=to_e, entityType='unknown', observations=[])])
+                    self.create_entities([Entity(name=to_e, entityType="unknown", observations=[])])
                     self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)])
                 elif to_e in entities:
-                    self.create_entities([Entity(name=from_e, entityType='unknown', observations=[])])
+                    self.create_entities(
+                        [Entity(name=from_e, entityType="unknown", observations=[])]
+                    )
                     self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)])
 
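For orientation, here is a minimal usage sketch of the `GraphMemory` class reformatted above. It relies only on methods visible in this diff (`create_entities`, `create_relations`, `read_graph`, `populate_from_text`); the exact import path for `Entity` and `Relation` is an assumption, and the in-memory database is used just so the sketch leaves no file behind.

```python
import sqlite3

from rp.memory import GraphMemory  # imported this way in assistant.py above
from rp.memory.graph_memory import Entity, Relation  # module path assumed

gm = GraphMemory(db_conn=sqlite3.connect(":memory:", check_same_thread=False))

gm.create_entities([Entity(name="Alice", entityType="person", observations=["likes graphs"])])
gm.create_entities([Entity(name="Acme", entityType="company", observations=[])])
gm.create_relations([Relation(from_="Alice", to="Acme", relationType="works_at")])

graph = gm.read_graph()
print([e.name for e in graph.entities])            # ['Alice', 'Acme']
print([(r.from_, r.to) for r in graph.relations])  # [('Alice', 'Acme')]

# The pattern-based extractor builds the same structures from raw text:
# "works at" matches the works_at pattern, linking Bob to Acme.
gm.populate_from_text("Bob works at Acme")
```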
@@ -20,7 +20,7 @@ class KnowledgeEntry:
     importance_score: float = 1.0
 
     def __str__(self):
-        return json.dumps(self.to_dict(), indent=4, sort_keys=True,default=str)
+        return json.dumps(self.to_dict(), indent=4, sort_keys=True, default=str)
 
     def to_dict(self) -> Dict[str, Any]:
         return {
 
@@ -1,4 +1,3 @@
 import os
 from pathlib import Path
 from typing import Optional
-
 
@@ -4,10 +4,10 @@ from rp.ui.colors import Colors
 def display_tool_call(tool_name, arguments, status="running", result=None):
     if status == "running":
         return
-    args_str = ", ".join([f"{k}={str(v)[:20]}" for k, v in list(arguments.items())[:2]])
+    args_str = ", ".join([f"{k}={repr(v)}" for k, v in arguments.items()])
     line = f"{tool_name}({args_str})"
-    if len(line) > 80:
-        line = line[:77] + "..."
+    if len(line) > 120:
+        line = line[:117] + "..."
    print(f"{Colors.GRAY}{line}{Colors.RESET}")
 
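A quick usage sketch of `display_tool_call` as changed above: calls in the "running" state stay silent, and any other status prints the call in gray, now with the full `repr()` of every argument and the wider 120-character truncation. The module path and the "completed" status value are assumptions; only `"running"` appears in this diff.

```python
from rp.ui.display import display_tool_call  # module path assumed

# Nothing is printed while the tool is still running.
display_tool_call("web_search", {"query": "python"}, status="running")

# Any non-running status prints, e.g.:
# web_search(query='python', max_results=5)
display_tool_call(
    "web_search",
    {"query": "python", "max_results": 5},
    status="completed",  # hypothetical status value
    result={"hits": 5},
)
```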
@@ -114,10 +114,6 @@ class TestAssistant(unittest.TestCase):
         process_message(assistant, "test message")
 
         from rp.memory import KnowledgeEntry
-        import json
-        import time
-        import uuid
-        from unittest.mock import ANY
 
         # Mock time.time() and uuid.uuid4() to return consistent values
         expected_entry = KnowledgeEntry(
 
@@ -132,7 +128,7 @@ class TestAssistant(unittest.TestCase):
             created_at=1234567890.123456,
             updated_at=1234567890.123456,
         )
-        expected_content = str(expected_entry)
+        str(expected_entry)
 
         assistant.knowledge_store.add_entry.assert_called_once_with(expected_entry)
 
@@ -696,4 +696,3 @@ class TestShowBackgroundEvents:
     def test_show_events_exception(self, mock_get):
         mock_get.side_effect = Exception("test")
         show_background_events(self.assistant)
-
 
@@ -1,4 +1,3 @@
 import pytest
 import tempfile
 import os
 from rp.workflows.workflow_definition import ExecutionMode, Workflow, WorkflowStep
 
@@ -16,7 +15,7 @@ class TestWorkflowStep:
             on_success=["step2"],
             on_failure=["step3"],
             retry_count=2,
-            timeout_seconds=600
+            timeout_seconds=600,
         )
         assert step.tool_name == "test_tool"
         assert step.arguments == {"arg1": "value1"}
 
@@ -28,11 +27,7 @@ class TestWorkflowStep:
         assert step.timeout_seconds == 600
 
     def test_to_dict(self):
-        step = WorkflowStep(
-            tool_name="test_tool",
-            arguments={"arg1": "value1"},
-            step_id="step1"
-        )
+        step = WorkflowStep(tool_name="test_tool", arguments={"arg1": "value1"}, step_id="step1")
         expected = {
             "tool_name": "test_tool",
             "arguments": {"arg1": "value1"},
 
@@ -77,7 +72,7 @@ class TestWorkflow:
             steps=[step1, step2],
             execution_mode=ExecutionMode.PARALLEL,
             variables={"var1": "value1"},
-            tags=["tag1", "tag2"]
+            tags=["tag1", "tag2"],
         )
         assert workflow.name == "test_workflow"
         assert workflow.description == "A test workflow"
 
@@ -94,7 +89,7 @@ class TestWorkflow:
             steps=[step1],
             execution_mode=ExecutionMode.SEQUENTIAL,
             variables={"var1": "value1"},
-            tags=["tag1"]
+            tags=["tag1"],
         )
         expected = {
             "name": "test_workflow",
 
@@ -110,16 +105,18 @@ class TestWorkflow:
         data = {
             "name": "test_workflow",
             "description": "A test workflow",
-            "steps": [{
-                "tool_name": "tool1",
-                "arguments": {},
-                "step_id": "step1",
-                "condition": None,
-                "on_success": None,
-                "on_failure": None,
-                "retry_count": 0,
-                "timeout_seconds": 300,
-            }],
+            "steps": [
+                {
+                    "tool_name": "tool1",
+                    "arguments": {},
+                    "step_id": "step1",
+                    "condition": None,
+                    "on_success": None,
+                    "on_failure": None,
+                    "retry_count": 0,
+                    "timeout_seconds": 300,
+                }
+            ],
             "execution_mode": "parallel",
             "variables": {"var1": "value1"},
             "tags": ["tag1"],
 
@@ -134,11 +131,7 @@ class TestWorkflow:
         assert workflow.tags == ["tag1"]
 
     def test_add_step(self):
-        workflow = Workflow(
-            name="test_workflow",
-            description="A test workflow",
-            steps=[]
-        )
+        workflow = Workflow(name="test_workflow", description="A test workflow", steps=[])
         step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
         workflow.add_step(step)
         assert workflow.steps == [step]
 
@@ -147,9 +140,7 @@ class TestWorkflow:
         step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
         step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
         workflow = Workflow(
-            name="test_workflow",
-            description="A test workflow",
-            steps=[step1, step2]
+            name="test_workflow", description="A test workflow", steps=[step1, step2]
         )
         assert workflow.get_step("step1") == step1
         assert workflow.get_step("step2") == step2
 
@@ -162,7 +153,7 @@ class TestWorkflow:
             name="test_workflow",
             description="A test workflow",
             steps=[step1, step2],
-            execution_mode=ExecutionMode.SEQUENTIAL
+            execution_mode=ExecutionMode.SEQUENTIAL,
         )
         assert workflow.get_initial_steps() == [step1]
 
@@ -173,7 +164,7 @@ class TestWorkflow:
             name="test_workflow",
             description="A test workflow",
             steps=[step1, step2],
-            execution_mode=ExecutionMode.PARALLEL
+            execution_mode=ExecutionMode.PARALLEL,
         )
         assert workflow.get_initial_steps() == [step1, step2]
 
@@ -184,7 +175,7 @@ class TestWorkflow:
             name="test_workflow",
             description="A test workflow",
             steps=[step1, step2],
-            execution_mode=ExecutionMode.CONDITIONAL
+            execution_mode=ExecutionMode.CONDITIONAL,
         )
         assert workflow.get_initial_steps() == [step2]  # Only step without condition
 
@@ -200,11 +191,7 @@ class TestWorkflowStorage:
 
     def test_save_and_load_workflow(self):
         step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
-        workflow = Workflow(
-            name="test_workflow",
-            description="A test workflow",
-            steps=[step]
-        )
+        workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
         workflow_id = self.storage.save_workflow(workflow)
         loaded = self.storage.load_workflow(workflow_id)
         assert loaded is not None
 
@@ -219,11 +206,7 @@ class TestWorkflowStorage:
 
     def test_load_workflow_by_name(self):
         step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
-        workflow = Workflow(
-            name="test_workflow",
-            description="A test workflow",
-            steps=[step]
-        )
+        workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
         self.storage.save_workflow(workflow)
         loaded = self.storage.load_workflow_by_name("test_workflow")
         assert loaded is not None
 
@@ -236,10 +219,7 @@ class TestWorkflowStorage:
     def test_list_workflows(self):
         step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
         workflow = Workflow(
-            name="test_workflow",
-            description="A test workflow",
-            steps=[step],
-            tags=["tag1"]
+            name="test_workflow", description="A test workflow", steps=[step], tags=["tag1"]
         )
         self.storage.save_workflow(workflow)
         workflows = self.storage.list_workflows()
 
@@ -250,16 +230,10 @@ class TestWorkflowStorage:
     def test_list_workflows_with_tag(self):
         step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
         workflow1 = Workflow(
-            name="test_workflow1",
-            description="A test workflow",
-            steps=[step],
-            tags=["tag1"]
+            name="test_workflow1", description="A test workflow", steps=[step], tags=["tag1"]
         )
         workflow2 = Workflow(
-            name="test_workflow2",
-            description="A test workflow",
-            steps=[step],
-            tags=["tag2"]
+            name="test_workflow2", description="A test workflow", steps=[step], tags=["tag2"]
         )
         self.storage.save_workflow(workflow1)
         self.storage.save_workflow(workflow2)
 
@@ -269,11 +243,7 @@ class TestWorkflowStorage:
 
     def test_delete_workflow(self):
         step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
-        workflow = Workflow(
-            name="test_workflow",
-            description="A test workflow",
-            steps=[step]
-        )
+        workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
         workflow_id = self.storage.save_workflow(workflow)
         deleted = self.storage.delete_workflow(workflow_id)
         assert deleted is True
 
@@ -286,11 +256,7 @@ class TestWorkflowStorage:
 
     def test_save_execution(self):
         step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
-        workflow = Workflow(
-            name="test_workflow",
-            description="A test workflow",
-            steps=[step]
-        )
+        workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
         workflow_id = self.storage.save_workflow(workflow)
         context = WorkflowExecutionContext()
         context.set_step_result("step1", "result")
 
@@ -303,11 +269,7 @@ class TestWorkflowStorage:
 
     def test_get_execution_history(self):
         step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
-        workflow = Workflow(
-            name="test_workflow",
-            description="A test workflow",
-            steps=[step]
-        )
+        workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
         workflow_id = self.storage.save_workflow(workflow)
         context = WorkflowExecutionContext()
         context.set_step_result("step1", "result")
 
@@ -318,11 +280,7 @@ class TestWorkflowStorage:
 
     def test_get_execution_history_limit(self):
         step = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
-        workflow = Workflow(
-            name="test_workflow",
-            description="A test workflow",
-            steps=[step]
-        )
+        workflow = Workflow(name="test_workflow", description="A test workflow", steps=[step])
         workflow_id = self.storage.save_workflow(workflow)
         for i in range(5):
             context = WorkflowExecutionContext()
 
@@ -367,6 +325,7 @@ class TestWorkflowEngine:
     def test_init(self):
         def tool_executor(tool_name, args):
             return f"executed {tool_name} with {args}"
+
         engine = WorkflowEngine(tool_executor, max_workers=10)
         assert engine.tool_executor == tool_executor
         assert engine.max_workers == 10
 
@@ -374,6 +333,7 @@ class TestWorkflowEngine:
     def test_evaluate_condition_true(self):
         def tool_executor(tool_name, args):
             return "result"
+
         engine = WorkflowEngine(tool_executor)
         context = WorkflowExecutionContext()
         assert engine._evaluate_condition("True", context) is True
 
@@ -381,6 +341,7 @@ class TestWorkflowEngine:
     def test_evaluate_condition_false(self):
         def tool_executor(tool_name, args):
             return "result"
+
         engine = WorkflowEngine(tool_executor)
         context = WorkflowExecutionContext()
         assert engine._evaluate_condition("False", context) is False
 
@@ -388,6 +349,7 @@ class TestWorkflowEngine:
     def test_evaluate_condition_with_variables(self):
         def tool_executor(tool_name, args):
             return "result"
+
         engine = WorkflowEngine(tool_executor)
         context = WorkflowExecutionContext()
         context.set_variable("test_var", "test_value")
 
@@ -396,15 +358,12 @@ class TestWorkflowEngine:
     def test_substitute_variables(self):
         def tool_executor(tool_name, args):
             return "result"
+
         engine = WorkflowEngine(tool_executor)
         context = WorkflowExecutionContext()
         context.set_variable("var1", "value1")
         context.set_step_result("step1", "result1")
-        arguments = {
-            "arg1": "${var.var1}",
-            "arg2": "${step.step1}",
-            "arg3": "plain_value"
-        }
+        arguments = {"arg1": "${var.var1}", "arg2": "${step.step1}", "arg3": "plain_value"}
         substituted = engine._substitute_variables(arguments, context)
         assert substituted["arg1"] == "value1"
         assert substituted["arg2"] == "result1"
 
@@ -412,16 +371,14 @@ class TestWorkflowEngine:
 
     def test_execute_step_success(self):
         executed = []
+
         def tool_executor(tool_name, args):
             executed.append((tool_name, args))
             return "success_result"
+
         engine = WorkflowEngine(tool_executor)
         context = WorkflowExecutionContext()
-        step = WorkflowStep(
-            tool_name="test_tool",
-            arguments={"arg": "value"},
-            step_id="step1"
-        )
+        step = WorkflowStep(tool_name="test_tool", arguments={"arg": "value"}, step_id="step1")
         result = engine._execute_step(step, context)
         assert result["status"] == "success"
         assert result["step_id"] == "step1"
 
@@ -431,16 +388,15 @@ class TestWorkflowEngine:
 
     def test_execute_step_skipped(self):
         executed = []
+
         def tool_executor(tool_name, args):
             executed.append((tool_name, args))
             return "result"
+
         engine = WorkflowEngine(tool_executor)
         context = WorkflowExecutionContext()
         step = WorkflowStep(
-            tool_name="test_tool",
-            arguments={"arg": "value"},
-            step_id="step1",
-            condition="False"
+            tool_name="test_tool", arguments={"arg": "value"}, step_id="step1", condition="False"
         )
         result = engine._execute_step(step, context)
         assert result["status"] == "skipped"
 
@@ -449,18 +405,17 @@ class TestWorkflowEngine:
 
     def test_execute_step_failed_with_retry(self):
        executed = []
+
         def tool_executor(tool_name, args):
             executed.append((tool_name, args))
             if len(executed) < 2:
                 raise Exception("Temporary failure")
             return "success_result"
+
         engine = WorkflowEngine(tool_executor)
         context = WorkflowExecutionContext()
         step = WorkflowStep(
-            tool_name="test_tool",
-            arguments={"arg": "value"},
-            step_id="step1",
-            retry_count=1
+            tool_name="test_tool", arguments={"arg": "value"}, step_id="step1", retry_count=1
         )
         result = engine._execute_step(step, context)
         assert result["status"] == "success"
 
@@ -469,16 +424,15 @@ class TestWorkflowEngine:
 
     def test_execute_step_failed(self):
         executed = []
+
         def tool_executor(tool_name, args):
             executed.append((tool_name, args))
             raise Exception("Permanent failure")
+
         engine = WorkflowEngine(tool_executor)
         context = WorkflowExecutionContext()
         step = WorkflowStep(
-            tool_name="test_tool",
-            arguments={"arg": "value"},
-            step_id="step1",
-            retry_count=1
+            tool_name="test_tool", arguments={"arg": "value"}, step_id="step1", retry_count=1
         )
         result = engine._execute_step(step, context)
         assert result["status"] == "failed"
 
@@ -488,6 +442,7 @@ class TestWorkflowEngine:
     def test_get_next_steps_sequential(self):
         def tool_executor(tool_name, args):
             return "result"
+
         engine = WorkflowEngine(tool_executor)
         step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
         step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
 
@@ -495,7 +450,7 @@ class TestWorkflowEngine:
             name="test",
             description="test",
             steps=[step1, step2],
-            execution_mode=ExecutionMode.SEQUENTIAL
+            execution_mode=ExecutionMode.SEQUENTIAL,
         )
         result = {"status": "success"}
         next_steps = engine._get_next_steps(step1, result, workflow)
 
@@ -504,28 +459,22 @@ class TestWorkflowEngine:
     def test_get_next_steps_on_success(self):
         def tool_executor(tool_name, args):
             return "result"
+
         engine = WorkflowEngine(tool_executor)
-        step1 = WorkflowStep(
-            tool_name="tool1",
-            arguments={},
-            step_id="step1",
-            on_success=["step2"]
-        )
+        step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1", on_success=["step2"])
         step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
-        workflow = Workflow(
-            name="test",
-            description="test",
-            steps=[step1, step2]
-        )
+        workflow = Workflow(name="test", description="test", steps=[step1, step2])
         result = {"status": "success"}
         next_steps = engine._get_next_steps(step1, result, workflow)
         assert next_steps == [step2]
 
     def test_execute_workflow_sequential(self):
         executed = []
+
         def tool_executor(tool_name, args):
             executed.append(tool_name)
             return f"result_{tool_name}"
+
         engine = WorkflowEngine(tool_executor)
         step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
         step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
 
@@ -533,7 +482,7 @@ class TestWorkflowEngine:
             name="test",
             description="test",
             steps=[step1, step2],
-            execution_mode=ExecutionMode.SEQUENTIAL
+            execution_mode=ExecutionMode.SEQUENTIAL,
         )
         context = engine.execute_workflow(workflow)
         assert executed == ["tool1", "tool2"]
 
@@ -542,9 +491,11 @@ class TestWorkflowEngine:
 
     def test_execute_workflow_parallel(self):
         executed = []
+
         def tool_executor(tool_name, args):
             executed.append(tool_name)
             return f"result_{tool_name}"
+
         engine = WorkflowEngine(tool_executor)
         step1 = WorkflowStep(tool_name="tool1", arguments={}, step_id="step1")
         step2 = WorkflowStep(tool_name="tool2", arguments={}, step_id="step2")
 
@@ -552,7 +503,7 @@ class TestWorkflowEngine:
             name="test",
             description="test",
             steps=[step1, step2],
-            execution_mode=ExecutionMode.PARALLEL
+            execution_mode=ExecutionMode.PARALLEL,
         )
         context = engine.execute_workflow(workflow)
         assert set(executed) == {"tool1", "tool2"}