|
import json
|
|
import re
|
|
import sqlite3
|
|
from typing import List, Optional, Set
|
|
from dataclasses import dataclass, field
|
|
|
|
@dataclass
class Entity:
    """A node of the knowledge graph.

    Attributes:
        name: Identifier of the entity.
        entityType: Category label for the entity.
        observations: Free-text facts attached to the entity.
    """

    name: str
    entityType: str
    observations: List[str]
|
|
|
|
@dataclass
class Relation:
    """A directed, typed edge ``from_ --relationType--> to`` between entity names."""

    # ``from`` is a Python keyword, hence the trailing underscore. The metadata
    # records the external field name; stdlib dataclasses ignore metadata, so
    # presumably a serializer outside this file reads it -- TODO confirm.
    from_: str = field(metadata={'alias': 'from'})
    to: str
    relationType: str
|
|
|
|
@dataclass
class KnowledgeGraph:
    """A (sub)graph snapshot: a list of entities plus the relations returned with them."""

    entities: List[Entity]
    relations: List[Relation]
|
|
|
|
@dataclass
class CreateEntitiesRequest:
    """Holds a batch of entities to create (same shape as ``GraphMemory.create_entities``' argument)."""

    entities: List[Entity]
|
|
|
|
@dataclass
class CreateRelationsRequest:
    """Holds a batch of relations to create (same shape as ``GraphMemory.create_relations``' argument)."""

    relations: List[Relation]
|
|
|
|
@dataclass
class ObservationItem:
    """Observations to append to one entity.

    Attributes:
        entityName: Name of the target entity.
        contents: Observation strings to add.
    """

    entityName: str
    contents: List[str]
|
|
|
|
@dataclass
class AddObservationsRequest:
    """Holds a batch of ``ObservationItem`` records to add."""

    observations: List[ObservationItem]
|
|
|
|
@dataclass
class DeletionItem:
    """Observations to remove from one entity.

    Attributes:
        entityName: Name of the target entity.
        observations: Observation strings to remove.
    """

    entityName: str
    observations: List[str]
|
|
|
|
@dataclass
class DeleteObservationsRequest:
    """Holds a batch of ``DeletionItem`` records to apply."""

    deletions: List[DeletionItem]
|
|
|
|
@dataclass
class DeleteEntitiesRequest:
    """Holds the names of entities to delete."""

    entityNames: List[str]
|
|
|
|
@dataclass
class DeleteRelationsRequest:
    """Holds the relations to delete."""

    relations: List[Relation]
|
|
|
|
@dataclass
class SearchNodesRequest:
    """Holds the substring to search entities for."""

    query: str
|
|
|
|
@dataclass
class OpenNodesRequest:
    """Holds the entity names to open and how many relation hops to follow.

    Attributes:
        names: Starting entity names.
        depth: Number of hops to expand (defaults to 1).
    """

    names: List[str]
    depth: int = 1
|
|
|
|
@dataclass
class PopulateRequest:
    """Holds free text from which to extract entities and relations."""

    text: str
|
|
|
|
class GraphMemory:
    """Knowledge-graph store backed by SQLite.

    Entities are rows of the ``entities`` table, with observations kept as a
    JSON-encoded list of strings; relations are rows of the ``relations`` table.
    Name lookups compare ``LOWER(column) = LOWER(?)`` so both sides are folded
    by the same SQLite function (SQLite's built-in ``lower()`` folds ASCII
    letters only, so mixing it with Python's ``str.lower()`` breaks lookups of
    non-ASCII names).

    Instances can be used as context managers; ``close()`` only closes a
    connection that this instance opened itself.
    """

    # (regex, relation type) rules for populate_from_text. Every regex has
    # exactly two capturing groups: subject and object.
    _RELATION_PATTERNS = (
        (r'(\w+) is (?:a|an) (\w+)', 'is_a'),
        (r'(\w+) knows (\w+)', 'knows'),
        (r'(\w+) works at (\w+)', 'works_at'),
        (r'(\w+) lives in (\w+)', 'lives_in'),
        # General "X is Y". The lookahead keeps it from also matching
        # "X is a ..." / "X is an ...", which would create an entity named "a"/"an".
        (r'(\w+) is (?!(?:a|an)\b)(\w+)', 'is'),
    )

    def __init__(self, db_path: str = 'graph_memory.db', db_conn: Optional[sqlite3.Connection] = None):
        """Open (or adopt) a SQLite connection and make sure the schema exists.

        Args:
            db_path: Database file opened when ``db_conn`` is not supplied.
            db_conn: Existing connection to use instead of opening ``db_path``.
                It remains owned by the caller and is not closed by ``close()``.
        """
        self.db_path = db_path
        self._owns_conn = db_conn is None
        # Explicit ``is None`` check: any supplied connection object is used as-is.
        self.conn = (
            sqlite3.connect(self.db_path, check_same_thread=False)
            if db_conn is None
            else db_conn
        )
        self.init_db()

    def close(self) -> None:
        """Close the connection if this instance opened it; otherwise do nothing."""
        if getattr(self, '_owns_conn', False):
            self.conn.close()
            self._owns_conn = False

    def __enter__(self) -> 'GraphMemory':
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def init_db(self):
        """Create the ``entities`` and ``relations`` tables if they do not exist."""
        cursor = self.conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS entities (
                id INTEGER PRIMARY KEY,
                name TEXT UNIQUE,
                entity_type TEXT,
                observations TEXT
            )
        ''')
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS relations (
                id INTEGER PRIMARY KEY,
                from_entity TEXT,
                to_entity TEXT,
                relation_type TEXT,
                UNIQUE(from_entity, to_entity, relation_type)
            )
        ''')
        self.conn.commit()

    @staticmethod
    def _entity_from_row(row) -> Entity:
        """Build an Entity from a ``(name, entity_type, observations_json)`` row."""
        name, entity_type, raw_obs = row
        return Entity(
            name=name,
            entityType=entity_type,
            observations=json.loads(raw_obs) if raw_obs else [],
        )

    def create_entities(self, entities: List[Entity]) -> List[Entity]:
        """Insert entities whose name does not already exist (case-insensitively).

        Returns:
            The subset of ``entities`` that was actually inserted.
        """
        cursor = self.conn.cursor()
        created = []
        for entity in entities:
            # The UNIQUE constraint is case-sensitive, but every other method
            # matches names case-insensitively; a case-variant duplicate would
            # make those lookups ambiguous, so treat it as "already exists".
            cursor.execute('SELECT 1 FROM entities WHERE LOWER(name) = LOWER(?)', (entity.name,))
            if cursor.fetchone() is not None:
                continue
            try:
                cursor.execute(
                    'INSERT INTO entities (name, entity_type, observations) VALUES (?, ?, ?)',
                    (entity.name, entity.entityType, json.dumps(entity.observations)),
                )
            except sqlite3.IntegrityError:
                continue  # already exists
            created.append(entity)
        self.conn.commit()
        return created

    def create_relations(self, relations: List[Relation]) -> List[Relation]:
        """Insert relations not already present (exact-match uniqueness).

        Returns:
            The subset of ``relations`` that was actually inserted.
        """
        cursor = self.conn.cursor()
        created = []
        for relation in relations:
            try:
                cursor.execute(
                    'INSERT INTO relations (from_entity, to_entity, relation_type) VALUES (?, ?, ?)',
                    (relation.from_, relation.to, relation.relationType),
                )
            except sqlite3.IntegrityError:
                continue  # already exists
            created.append(relation)
        self.conn.commit()
        return created

    def add_observations(self, observations: List[ObservationItem]) -> List[dict]:
        """Append new, non-duplicate observation strings to existing entities.

        All target entities are resolved before anything is written. If any
        is missing, an error is printed and ``[]`` is returned with no changes
        made (previously, earlier updates were left pending in the open
        transaction and committed by whichever call came next).

        Returns:
            One ``{"entityName": ..., "addedObservations": [...]}`` dict per
            item, with ``entityName`` echoed as the caller supplied it.
        """
        cursor = self.conn.cursor()

        targets = []
        for item in observations:
            cursor.execute(
                'SELECT id FROM entities WHERE LOWER(name) = LOWER(?) ORDER BY id LIMIT 1',
                (item.entityName,),
            )
            row = cursor.fetchone()
            if row is None:
                # Log the error instead of raising an exception.
                print(f"Error: Entity {item.entityName} not found when adding observations.")
                return []
            targets.append((item, row[0]))

        results = []
        for item, entity_id in targets:
            # Re-read per item so repeated entities in one batch see earlier additions.
            cursor.execute('SELECT observations FROM entities WHERE id = ?', (entity_id,))
            raw = cursor.fetchone()[0]
            current = json.loads(raw) if raw else []
            present = set(current)
            added = []
            for content in item.contents:
                if content not in present:  # also drops duplicates within ``contents``
                    present.add(content)
                    added.append(content)
            current.extend(added)
            # Update by id so case-variant rows are never overwritten with this row's data.
            cursor.execute(
                'UPDATE entities SET observations = ? WHERE id = ?',
                (json.dumps(current), entity_id),
            )
            results.append({"entityName": item.entityName, "addedObservations": added})
        self.conn.commit()
        return results

    def delete_entities(self, entity_names: List[str]):
        """Delete entities by name (case-insensitive) and every relation touching them."""
        cursor = self.conn.cursor()
        cursor.executemany(
            'DELETE FROM entities WHERE LOWER(name) = LOWER(?)',
            [(name,) for name in entity_names],
        )
        # One statement per name avoids SQLite's bound-parameter limit for long lists.
        cursor.executemany(
            'DELETE FROM relations WHERE LOWER(from_entity) = LOWER(?) OR LOWER(to_entity) = LOWER(?)',
            [(name, name) for name in entity_names],
        )
        self.conn.commit()

    def delete_observations(self, deletions: List[DeletionItem]):
        """Remove the given observation strings from the named entities.

        Unknown entity names are silently ignored.
        """
        cursor = self.conn.cursor()
        for item in deletions:
            doomed = set(item.observations)
            cursor.execute(
                'SELECT id, observations FROM entities WHERE LOWER(name) = LOWER(?)',
                (item.entityName,),
            )
            # Filter each matching row independently (updating by id), so
            # case-variant rows keep their own observation lists.
            for entity_id, raw in cursor.fetchall():
                current = json.loads(raw) if raw else []
                kept = [obs for obs in current if obs not in doomed]
                cursor.execute(
                    'UPDATE entities SET observations = ? WHERE id = ?',
                    (json.dumps(kept), entity_id),
                )
        self.conn.commit()

    def delete_relations(self, relations: List[Relation]):
        """Delete relations matching all three fields case-insensitively."""
        cursor = self.conn.cursor()
        cursor.executemany(
            'DELETE FROM relations WHERE LOWER(from_entity) = LOWER(?) '
            'AND LOWER(to_entity) = LOWER(?) AND LOWER(relation_type) = LOWER(?)',
            [(r.from_, r.to, r.relationType) for r in relations],
        )
        self.conn.commit()

    def read_graph(self) -> KnowledgeGraph:
        """Return every entity and relation in the store."""
        cursor = self.conn.cursor()
        cursor.execute('SELECT name, entity_type, observations FROM entities')
        entities = [self._entity_from_row(row) for row in cursor.fetchall()]
        cursor.execute('SELECT from_entity, to_entity, relation_type FROM relations')
        relations = [
            Relation(from_=from_e, to=to_e, relationType=rel_type)
            for from_e, to_e, rel_type in cursor.fetchall()
        ]
        return KnowledgeGraph(entities=entities, relations=relations)

    def search_nodes(self, query: str) -> KnowledgeGraph:
        """Return entities whose name, type or any observation contains ``query``.

        Matching is a case-insensitive substring test. Only relations whose
        both endpoints are among the matched entities are included.
        """
        needle = query.lower()
        cursor = self.conn.cursor()
        cursor.execute('SELECT name, entity_type, observations FROM entities')
        entities = []
        for row in cursor.fetchall():
            entity = self._entity_from_row(row)
            # ``or ''`` guards against a NULL entity_type.
            haystacks = [entity.name, entity.entityType or '', *entity.observations]
            if any(needle in text.lower() for text in haystacks):
                entities.append(entity)

        names = {entity.name.lower() for entity in entities}
        cursor.execute('SELECT from_entity, to_entity, relation_type FROM relations')
        relations = [
            Relation(from_=from_e, to=to_e, relationType=rel_type)
            for from_e, to_e, rel_type in cursor.fetchall()
            if from_e.lower() in names and to_e.lower() in names
        ]
        return KnowledgeGraph(entities=entities, relations=relations)

    def open_nodes(self, names: List[str], depth: int = 1) -> KnowledgeGraph:
        """Return the named entities plus everything reachable within ``depth`` hops.

        Breadth-first: level 0 loads ``names``; each level fetches every
        relation touching the entities it newly loaded, and the other
        endpoints of those relations become the next level. Relations of
        the last level are included even if their far endpoint lies beyond
        ``depth``. A negative ``depth`` yields an empty graph.
        """
        cursor = self.conn.cursor()
        visited: Set[str] = set()
        seen_relations: Set[tuple] = set()
        entities: List[Entity] = []
        relations: List[Relation] = []

        frontier = list(dict.fromkeys(names))  # de-duplicate, keep order
        for _level in range(depth + 1):
            if not frontier:
                break
            marks = ','.join(['LOWER(?)'] * len(frontier))
            cursor.execute(
                f'SELECT name, entity_type, observations FROM entities WHERE LOWER(name) IN ({marks})',
                frontier,
            )
            level_names = []
            for row in cursor.fetchall():
                key = row[0].lower()
                if key in visited:
                    continue
                visited.add(key)
                entities.append(self._entity_from_row(row))
                level_names.append(row[0])
            if not level_names:
                break

            marks = ','.join(['LOWER(?)'] * len(level_names))
            cursor.execute(
                'SELECT from_entity, to_entity, relation_type FROM relations '
                f'WHERE LOWER(from_entity) IN ({marks}) OR LOWER(to_entity) IN ({marks})',
                level_names * 2,
            )
            next_frontier = {}
            for from_e, to_e, rel_type in cursor.fetchall():
                rel_key = (from_e, to_e, rel_type)
                if rel_key not in seen_relations:
                    seen_relations.add(rel_key)
                    relations.append(Relation(from_=from_e, to=to_e, relationType=rel_type))
                for endpoint in (from_e, to_e):
                    if endpoint.lower() not in visited:
                        next_frontier.setdefault(endpoint.lower(), endpoint)
            frontier = list(next_frontier.values())

        return KnowledgeGraph(entities=entities, relations=relations)

    def populate_from_text(self, text: str):
        """Heuristically extract entities and relations from free text.

        Every capitalized ASCII word becomes an entity of type ``'unknown'``
        and receives the full ``text`` as an observation. Relations come from
        the regexes in ``_RELATION_PATTERNS`` (case-insensitive). A match is
        kept only when at least one endpoint is an extracted entity; the other
        endpoint is created as an ``'unknown'`` entity if needed.
        """
        # Sorted for a deterministic insertion order (set order varies per run).
        names = sorted(set(re.findall(r'\b[A-Z][a-zA-Z]*\b', text)))
        if names:
            self.create_entities(
                [Entity(name=name, entityType='unknown', observations=[]) for name in names]
            )
            self.add_observations(
                [ObservationItem(entityName=name, contents=[text]) for name in names]
            )

        known = set(names)
        for pattern, rel_type in self._RELATION_PATTERNS:
            for from_e, to_e in re.findall(pattern, text, re.IGNORECASE):
                if from_e not in known and to_e not in known:
                    continue
                missing = [name for name in (from_e, to_e) if name not in known]
                if missing:
                    self.create_entities(
                        [Entity(name=name, entityType='unknown', observations=[]) for name in missing]
                    )
                self.create_relations([Relation(from_=from_e, to=to_e, relationType=rel_type)])
|
|