import json
import sqlite3
import time
from contextlib import contextmanager
from typing import List, Optional
from rp.core.operations import (
TransactionManager,
Validator,
ValidationError,
managed_connection,
retry,
TRANSIENT_ERRORS,
)
from .workflow_definition import Workflow
class WorkflowStorage:
    """SQLite-backed persistence for workflow definitions and their execution records.

    Two tables are used: ``workflows`` (one row per workflow, keyed by a hash of
    the workflow name) and ``workflow_executions`` (one row per recorded run).
    Each public method opens its own connection through ``managed_connection``
    and is wrapped in the project's ``retry`` decorator
    (``max_attempts=3, base_delay=0.5``) for the errors in ``TRANSIENT_ERRORS``.
    """

    def __init__(self, db_path: str):
        """Remember *db_path* and create the schema if it does not exist yet."""
        self.db_path = db_path
        self._initialize_storage()

    @contextmanager
    def _get_connection(self):
        """Yield a connection to ``self.db_path`` obtained from ``managed_connection``."""
        with managed_connection(self.db_path) as conn:
            yield conn

    @retry(max_attempts=3, base_delay=0.5, transient_errors=TRANSIENT_ERRORS)
    def _initialize_storage(self):
        """Create the tables and indexes this class relies on (safe to re-run)."""
        with self._get_connection() as conn:
            tx = TransactionManager(conn)
            with tx.transaction():
                tx.execute("""
                    CREATE TABLE IF NOT EXISTS workflows (
                        workflow_id TEXT PRIMARY KEY,
                        name TEXT NOT NULL,
                        description TEXT,
                        workflow_data TEXT NOT NULL,
                        created_at INTEGER NOT NULL,
                        updated_at INTEGER NOT NULL,
                        execution_count INTEGER DEFAULT 0,
                        last_execution_at INTEGER,
                        tags TEXT
                    )
                """)
                # NOTE: SQLite only enforces this FOREIGN KEY when
                # ``PRAGMA foreign_keys`` is on; whether managed_connection
                # enables it is not visible here.
                tx.execute("""
                    CREATE TABLE IF NOT EXISTS workflow_executions (
                        execution_id TEXT PRIMARY KEY,
                        workflow_id TEXT NOT NULL,
                        started_at INTEGER NOT NULL,
                        completed_at INTEGER,
                        status TEXT NOT NULL,
                        execution_log TEXT,
                        variables TEXT,
                        step_results TEXT,
                        FOREIGN KEY (workflow_id) REFERENCES workflows(workflow_id)
                    )
                """)
                tx.execute("CREATE INDEX IF NOT EXISTS idx_workflow_name ON workflows(name)")
                tx.execute("CREATE INDEX IF NOT EXISTS idx_execution_workflow ON workflow_executions(workflow_id)")
                tx.execute("CREATE INDEX IF NOT EXISTS idx_execution_started ON workflow_executions(started_at)")

    @retry(max_attempts=3, base_delay=0.5, transient_errors=TRANSIENT_ERRORS)
    def save_workflow(self, workflow: Workflow) -> str:
        """Insert *workflow*, or update the stored copy, and return its id.

        The id is the first 16 hex characters of the SHA-256 of the workflow
        name, so saving a workflow whose name already exists overwrites that
        entry. On overwrite, the definition, name, description, tags and
        ``updated_at`` are refreshed. ``created_at``, ``execution_count`` and
        ``last_execution_at`` are preserved.

        Args:
            workflow: Workflow to persist. Its ``name``, ``description``,
                ``tags`` and ``to_dict()`` are read.

        Returns:
            The 16-character workflow id.

        Raises:
            Whatever ``Validator.string`` raises for an invalid name or
            description (presumably ``ValidationError``; confirm in
            rp.core.operations).
        """
        import hashlib

        name = Validator.string(workflow.name, "workflow.name", min_length=1, max_length=200)
        description = Validator.string(
            workflow.description or "", "workflow.description", max_length=2000, allow_none=True
        )
        workflow_data = json.dumps(workflow.to_dict())
        workflow_id = hashlib.sha256(name.encode()).hexdigest()[:16]
        current_time = int(time.time())
        tags_json = json.dumps(workflow.tags if workflow.tags else [])
        with self._get_connection() as conn:
            tx = TransactionManager(conn)
            with tx.transaction():
                # INSERT OR REPLACE deletes the old row and inserts a fresh one,
                # resetting created_at, execution_count and last_execution_at.
                # Instead, insert only when the row is absent, then update just
                # the mutable columns, so run statistics survive a re-save.
                tx.execute(
                    """
                    INSERT OR IGNORE INTO workflows
                    (workflow_id, name, description, workflow_data, created_at, updated_at, tags)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                    """,
                    (workflow_id, name, description, workflow_data, current_time, current_time, tags_json),
                )
                tx.execute(
                    """
                    UPDATE workflows
                    SET name = ?, description = ?, workflow_data = ?, updated_at = ?, tags = ?
                    WHERE workflow_id = ?
                    """,
                    (name, description, workflow_data, current_time, tags_json, workflow_id),
                )
        return workflow_id

    def _fetch_workflow(self, query: str, params: tuple) -> Optional[Workflow]:
        """Run *query* (which selects ``workflow_data``) and deserialize the first row, if any."""
        with self._get_connection() as conn:
            row = conn.execute(query, params).fetchone()
        if row is None:
            return None
        return Workflow.from_dict(json.loads(row[0]))

    @retry(max_attempts=3, base_delay=0.5, transient_errors=TRANSIENT_ERRORS)
    def load_workflow(self, workflow_id: str) -> Optional[Workflow]:
        """Return the workflow stored under *workflow_id*, or ``None`` if there is none."""
        workflow_id = Validator.string(workflow_id, "workflow_id", min_length=1, max_length=64)
        return self._fetch_workflow(
            "SELECT workflow_data FROM workflows WHERE workflow_id = ?", (workflow_id,)
        )

    @retry(max_attempts=3, base_delay=0.5, transient_errors=TRANSIENT_ERRORS)
    def load_workflow_by_name(self, name: str) -> Optional[Workflow]:
        """Return the workflow whose stored name equals *name*, or ``None`` if there is none."""
        name = Validator.string(name, "name", min_length=1, max_length=200)
        return self._fetch_workflow("SELECT workflow_data FROM workflows WHERE name = ?", (name,))

    @retry(max_attempts=3, base_delay=0.5, transient_errors=TRANSIENT_ERRORS)
    def list_workflows(self, tag: Optional[str] = None) -> List[dict]:
        """Return a summary of every stored workflow, ordered by name.

        Args:
            tag: When non-empty, only workflows whose tag list contains exactly
                this string (case-sensitive) are returned. ``None`` or ``""``
                disables filtering.

        Returns:
            One dict per workflow with the keys ``workflow_id``, ``name``,
            ``description``, ``execution_count``, ``last_execution_at`` and
            ``tags``. ``tags`` is the decoded list, or ``[]`` when none are stored.
        """
        if tag:
            tag = Validator.string(tag, "tag", max_length=100)
        with self._get_connection() as conn:
            rows = conn.execute("""
                SELECT workflow_id, name, description, execution_count, last_execution_at, tags
                FROM workflows
                ORDER BY name
            """).fetchall()
        workflows = []
        for row in rows:
            tags = json.loads(row[5]) if row[5] else []
            # Filter on the decoded list rather than with SQL LIKE. A LIKE pattern
            # treats '%'/'_' in the tag as wildcards and matches ASCII
            # case-insensitively. It also misses tags that json.dumps stored in
            # escaped form (non-ASCII, quotes, backslashes).
            if tag and tag not in tags:
                continue
            workflows.append({
                "workflow_id": row[0],
                "name": row[1],
                "description": row[2],
                "execution_count": row[3],
                "last_execution_at": row[4],
                "tags": tags,
            })
        return workflows

    @retry(max_attempts=3, base_delay=0.5, transient_errors=TRANSIENT_ERRORS)
    def delete_workflow(self, workflow_id: str) -> bool:
        """Delete a workflow and all of its execution records.

        Returns:
            ``True`` if a workflow row was removed, ``False`` if none matched.
        """
        workflow_id = Validator.string(workflow_id, "workflow_id", min_length=1, max_length=64)
        with self._get_connection() as conn:
            tx = TransactionManager(conn)
            with tx.transaction():
                # Delete the child rows first. Removing the parent while children
                # still reference it would violate the workflow_executions
                # foreign key whenever enforcement is enabled.
                tx.execute("DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,))
                cursor = tx.execute("DELETE FROM workflows WHERE workflow_id = ?", (workflow_id,))
                deleted = cursor.rowcount > 0
        return deleted

    @retry(max_attempts=3, base_delay=0.5, transient_errors=TRANSIENT_ERRORS)
    def save_execution(self, workflow_id: str, execution_context: "WorkflowExecutionContext") -> str:
        """Record one finished run of *workflow_id* and bump the workflow's run counter.

        ``started_at`` is taken from the ``"timestamp"`` of the first
        ``execution_log`` entry, falling back to now when the log is empty.
        ``completed_at`` is now. The status is always stored as ``"completed"``.
        ``execution_log``, ``variables`` and ``step_results`` are stored as JSON.

        Returns:
            The generated 16-character execution id.
        """
        import uuid

        workflow_id = Validator.string(workflow_id, "workflow_id", min_length=1, max_length=64)
        execution_id = str(uuid.uuid4())[:16]
        started_at = (
            int(execution_context.execution_log[0]["timestamp"])
            if execution_context.execution_log
            else int(time.time())
        )
        completed_at = int(time.time())
        with self._get_connection() as conn:
            tx = TransactionManager(conn)
            with tx.transaction():
                tx.execute("""
                    INSERT INTO workflow_executions
                    (execution_id, workflow_id, started_at, completed_at, status, execution_log, variables, step_results)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    execution_id,
                    workflow_id,
                    started_at,
                    completed_at,
                    "completed",
                    json.dumps(execution_context.execution_log),
                    json.dumps(execution_context.variables),
                    json.dumps(execution_context.step_results),
                ))
                tx.execute("""
                    UPDATE workflows
                    SET execution_count = execution_count + 1,
                        last_execution_at = ?
                    WHERE workflow_id = ?
                """, (completed_at, workflow_id))
        return execution_id

    @retry(max_attempts=3, base_delay=0.5, transient_errors=TRANSIENT_ERRORS)
    def get_execution_history(self, workflow_id: str, limit: int = 10) -> List[dict]:
        """Return up to *limit* (1-1000) most recent runs of *workflow_id*, newest first.

        Each dict has the keys ``execution_id``, ``started_at``,
        ``completed_at`` and ``status``.
        """
        workflow_id = Validator.string(workflow_id, "workflow_id", min_length=1, max_length=64)
        limit = Validator.integer(limit, "limit", min_value=1, max_value=1000)
        with self._get_connection() as conn:
            rows = conn.execute("""
                SELECT execution_id, started_at, completed_at, status
                FROM workflow_executions
                WHERE workflow_id = ?
                ORDER BY started_at DESC
                LIMIT ?
            """, (workflow_id, limit)).fetchall()
        return [
            {
                "execution_id": row[0],
                "started_at": row[1],
                "completed_at": row[2],
                "status": row[3],
            }
            for row in rows
        ]