import asyncio
import hashlib
import json
import os
import re
import uuid
from datetime import datetime
from typing import Optional, List
from contextlib import asynccontextmanager
from fastapi import FastAPI, UploadFile, File, WebSocket, WebSocketDisconnect, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from sqlmodel import Field, SQLModel, create_engine, Session, select
import chromadb
from chromadb.config import Settings
import tiktoken
import openai
import pypdf
# Configuration
UPLOAD_DIR = "uploads"
DB_URL = "sqlite:///./database.db"
CHROMA_DIR = "./chroma_db"
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "your-api-key-here")
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50
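# CHUNK_SIZE and CHUNK_OVERLAP are token counts (cl100k_base), not characters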
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(CHROMA_DIR, exist_ok=True)
# Database Models
class Document(SQLModel, table=True):
id: Optional[str] = Field(default=None, primary_key=True)
name: str
filename: str
markdown_content: str = Field(default="")
upload_time: datetime
status: str = "pending"
downloads: int = 0
class SearchResult(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
query: str
slug: str = Field(index=True)
results_json: str
created_at: datetime
tokens_used: int = 0
cost_eur: float = 0.0
class PromptResult(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
query: str
slug: str = Field(index=True)
prompt_response: str
search_results_json: str
created_at: datetime
input_tokens: int = 0
output_tokens: int = 0
cost_eur: float = 0.0
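# SearchResult and PromptResult act as response caches keyed by the deterministic query slug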
# Database Setup
engine = create_engine(DB_URL, connect_args={"check_same_thread": False})
def init_db():
SQLModel.metadata.create_all(engine)
# ChromaDB Setup
chroma_client = chromadb.PersistentClient(path=CHROMA_DIR, settings=Settings(anonymized_telemetry=False))
# Reuse the existing Chroma collection if it exists, otherwise create it
collection = chroma_client.get_or_create_collection("documents")
# OpenAI Setup
openai.api_key = OPENAI_API_KEY
tokenizer = tiktoken.get_encoding("cl100k_base")
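# cl100k_base matches the tokenization of text-embedding-3-small, so chunk token counts line up with embedding usage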
# WebSocket Manager
class ConnectionManager:
def __init__(self):
self.active_connections: dict[str, WebSocket] = {}
async def connect(self, document_id: str, websocket: WebSocket):
await websocket.accept()
self.active_connections[document_id] = websocket
def disconnect(self, document_id: str):
if document_id in self.active_connections:
del self.active_connections[document_id]
async def send_message(self, document_id: str, message: dict):
if document_id in self.active_connections:
try:
await self.active_connections[document_id].send_json(message)
            except Exception:
                # Socket already closed on the client side; drop the stale reference
                self.disconnect(document_id)
manager = ConnectionManager()
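# One WebSocket is tracked per document_id; a reconnect for the same id replaces the previous reference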
# Lifespan
@asynccontextmanager
async def lifespan(app: FastAPI):
init_db()
yield
# FastAPI App
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Helper Functions
def generate_slug(text: str) -> str:
"""Generate URL-friendly slug from text"""
slug = re.sub(r'[^\w\s-]', '', text.lower())
slug = re.sub(r'[-\s]+', '-', slug)
return slug[:50] + "-" + hashlib.md5(text.encode()).hexdigest()[:8]
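# Only PDFs get real conversion (per-page "## Page N" headers via pypdf); .md files pass through unchanged,
# and anything else is read as plain text with best-effort UTF-8 decoding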
def convert_to_markdown(filepath: str) -> str:
"""Convert various document formats to markdown"""
ext = os.path.splitext(filepath)[1].lower()
if ext == '.pdf':
markdown_text = ""
with open(filepath, 'rb') as f:
pdf_reader = pypdf.PdfReader(f)
for page_num, page in enumerate(pdf_reader.pages):
text = page.extract_text()
markdown_text += f"\n\n## Page {page_num + 1}\n\n{text}"
return markdown_text
elif ext == '.md':
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
return f.read()
else:
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
lines = content.split('\n')
markdown_lines = []
for line in lines:
if line.strip():
markdown_lines.append(line)
else:
markdown_lines.append('')
return '\n'.join(markdown_lines)
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[dict]:
    """Split text into overlapping token windows, tagging each chunk with the page it belongs to."""
    tokens = tokenizer.encode(text)
    chunks = []
    start = 0
    current_page = 1
    # Page numbers come from the "## Page N" headers injected by convert_to_markdown
    page_pattern = re.compile(r'##\s*Page\s+(\d+)', re.IGNORECASE)
    while start < len(tokens):
        end = start + chunk_size
        chunk_tokens = tokens[start:end]
        chunk = tokenizer.decode(chunk_tokens)
        # Chunks without their own page header inherit the last page number seen
        page_match = page_pattern.search(chunk)
        if page_match:
            current_page = int(page_match.group(1))
        chunks.append({
            "text": chunk,
            "page": current_page
        })
        start = end - overlap
    return chunks
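# The OpenAI SDK call below is synchronous, so it runs in a worker thread via asyncio.to_thread
# to keep the event loop responsive; returns the embedding plus the tokens billed for it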
async def get_embedding(text: str) -> tuple[List[float], int]:
response = await asyncio.to_thread(
openai.embeddings.create,
input=text,
model="text-embedding-3-small"
)
tokens_used = response.usage.total_tokens
return response.data[0].embedding, tokens_used
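# Keyword heuristic for telling a generative prompt apart from a plain search; nothing in this file
# calls it, so the choice between /api/search and /api/prompt is presumably made by the frontend (assumption)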
async def detect_prompt_intent(query: str) -> bool:
"""Detect if query is a prompt vs simple search"""
prompt_keywords = ['make', 'create', 'list', 'summarize', 'explain', 'compare', 'analyze', 'generate', 'write', 'show me', 'give me', 'find all', 'extract', 'what are', 'how many']
query_lower = query.lower()
# Check for question words or action verbs
if any(keyword in query_lower for keyword in prompt_keywords):
return True
if query_lower.endswith('?') and len(query.split()) > 3:
return True
return False
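# RAG step: the top search hits become context for a gpt-4o-mini chat completion; returns the answer
# text plus prompt/completion token counts and an estimated cost in EUR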
async def execute_prompt(query: str, search_results: List[dict]) -> tuple[str, int, int, float]:
"""Execute prompt using GPT-4 with search results as context"""
# Prepare context from search results
context = "\n\n".join([
f"Document: {r['name']}\nPage: {r.get('page', 'N/A')}\nContent: {r['snippet']}"
for r in search_results[:10] # Use top 10 results
])
messages = [
{"role": "system", "content": "You are a helpful assistant that answers questions based on the provided document context. Be concise and accurate."},
{"role": "user", "content": f"Context from documents:\n\n{context}\n\nUser query: {query}\n\nPlease answer based on the context provided."}
]
response = await asyncio.to_thread(
openai.chat.completions.create,
model="gpt-4o-mini",
messages=messages,
temperature=0.7,
max_tokens=1000
)
input_tokens = response.usage.prompt_tokens
output_tokens = response.usage.completion_tokens
# Calculate cost: GPT-4o-mini pricing
# $0.150 per 1M input tokens, $0.600 per 1M output tokens
cost_usd = (input_tokens / 1_000_000 * 0.150) + (output_tokens / 1_000_000 * 0.600)
cost_eur = cost_usd * 0.92
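    # NOTE: the 0.92 USD->EUR rate and the per-token prices above are hardcoded snapshots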
return response.choices[0].message.content, input_tokens, output_tokens, cost_eur
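# Background pipeline for an uploaded file: convert -> chunk -> embed -> index, pushing progress,
# token usage and estimated cost to the client over the document's WebSocket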
async def process_document(document_id: str, filepath: str):
total_tokens = 0
await manager.send_message(document_id, {
"step": "reading",
"progress": 5,
"message": "Reading file...",
"tokens": 0,
"cost_eur": 0.0
})
try:
await manager.send_message(document_id, {
"step": "converting",
"progress": 15,
"message": "Converting to markdown...",
"tokens": 0,
"cost_eur": 0.0
})
content = await asyncio.to_thread(convert_to_markdown, filepath)
with Session(engine) as session:
doc = session.get(Document, document_id)
if doc:
doc.markdown_content = content
session.add(doc)
session.commit()
await manager.send_message(document_id, {
"step": "chunking",
"progress": 25,
"message": "Splitting text into chunks...",
"tokens": 0,
"cost_eur": 0.0
})
chunks = chunk_text(content)
await manager.send_message(document_id, {
"step": "embedding",
"progress": 35,
"message": f"Processing {len(chunks)} chunks...",
"tokens": 0,
"cost_eur": 0.0
})
for i, chunk_data in enumerate(chunks):
embedding, tokens = await get_embedding(chunk_data["text"])
total_tokens += tokens
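            # text-embedding-3-small: $0.02 per 1M tokens, converted to EUR at a fixed 0.92 rate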
cost_usd = (total_tokens / 1000) * 0.00002
cost_eur = cost_usd * 0.92
collection.add(
ids=[f"{document_id}_chunk_{i}"],
embeddings=[embedding],
documents=[chunk_data["text"]],
metadatas=[{
"document_id": document_id,
"chunk_index": i,
"total_chunks": len(chunks),
"page": chunk_data["page"]
}]
)
progress = 35 + int((i + 1) / len(chunks) * 55)
await manager.send_message(document_id, {
"step": "embedding",
"progress": progress,
"message": f"Embedded chunk {i + 1}/{len(chunks)}",
"tokens": total_tokens,
"cost_eur": cost_eur
})
await manager.send_message(document_id, {
"step": "indexing",
"progress": 95,
"message": "Finalizing index...",
"tokens": total_tokens,
"cost_eur": (total_tokens / 1000) * 0.00002 * 0.92
})
with Session(engine) as session:
doc = session.get(Document, document_id)
if doc:
doc.status = "completed"
session.add(doc)
session.commit()
await manager.send_message(document_id, {
"step": "completed",
"progress": 100,
"message": "Processing complete!",
"tokens": total_tokens,
"cost_eur": (total_tokens / 1000) * 0.00002 * 0.92
})
except Exception as e:
with Session(engine) as session:
doc = session.get(Document, document_id)
if doc:
doc.status = "failed"
session.add(doc)
session.commit()
await manager.send_message(document_id, {
"step": "error",
"progress": 0,
"message": f"Error: {str(e)}",
"tokens": total_tokens,
"cost_eur": 0.0
})
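# If process_document fails part-way through, the document is marked "failed" and any chunks
# already written to Chroma for it are left in place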
# API Endpoints
@app.post("/api/upload")
async def upload_documents(
background_tasks: BackgroundTasks,
files: List[UploadFile] = File(...)
):
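    # Accepts one or more files as multipart form data, e.g.:
    #   curl -F "files=@report.pdf" -F "files=@notes.md" http://localhost:9900/api/upload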
uploaded_docs = []
for file in files:
doc_id = str(uuid.uuid4())
file_ext = os.path.splitext(file.filename)[1]
filename = f"{doc_id}{file_ext}"
filepath = os.path.join(UPLOAD_DIR, filename)
content = await file.read()
with open(filepath, 'wb') as f:
f.write(content)
document = Document(
id=doc_id,
name=file.filename,
filename=filename,
upload_time=datetime.now(),
status="processing"
)
with Session(engine) as session:
session.add(document)
session.commit()
background_tasks.add_task(process_document, doc_id, filepath)
uploaded_docs.append({
"document_id": doc_id,
"name": file.filename,
"status": "processing"
})
return {"documents": uploaded_docs}
@app.get("/api/search")
async def search_documents(query: str, page: int = 1, page_size: int = 10):
try:
slug = generate_slug(query)
# Check cache first
with Session(engine) as session:
cached = session.exec(select(SearchResult).where(SearchResult.slug == slug)).first()
            if cached:
                results_data = json.loads(cached.results_json)
                # Apply the requested pagination to the cached result set as well
                start_idx = (page - 1) * page_size
                end_idx = page * page_size
                return {
                    "results": results_data["results"][start_idx:end_idx],
                    "total": results_data["total"],
                    "page": page,
                    "page_size": page_size,
                    "slug": slug,
                    "tokens": cached.tokens_used,
                    "cost_eur": cached.cost_eur,
                    "cached": True
                }
collection_count = collection.count()
if collection_count == 0:
return {
"results": [],
"total": 0,
"page": page,
"page_size": page_size,
"slug": slug,
"tokens": 0,
"cost_eur": 0.0
}
query_embedding, tokens = await get_embedding(query)
cost_eur = (tokens / 1000) * 0.00002 * 0.92
results = collection.query(
query_embeddings=[query_embedding],
n_results=min(50, collection_count)
)
        search_results = []
        if results['ids'] and len(results['ids'][0]) > 0:
            # Open one session for the whole loop instead of one per hit
            with Session(engine) as session:
                for i in range(len(results['ids'][0])):
                    document_id = results['metadatas'][0][i]['document_id']
                    snippet = results['documents'][0][i]
                    distance = results['distances'][0][i] if 'distances' in results else 0
                    page_num = results['metadatas'][0][i].get('page', 1)
                    doc = session.get(Document, document_id)
                    if doc:
                        search_results.append({
                            "document_id": document_id,
                            "name": doc.name,
                            "snippet": snippet[:300] + "..." if len(snippet) > 300 else snippet,
                            "chunk_index": results['metadatas'][0][i]['chunk_index'],
                            "page": page_num,
                            "upload_date": doc.upload_time.isoformat(),
                            "score": 1 - distance
                        })
# Cache results
with Session(engine) as session:
search_cache = SearchResult(
query=query,
slug=slug,
results_json=json.dumps({"results": search_results, "total": len(search_results)}),
created_at=datetime.now(),
tokens_used=tokens,
cost_eur=cost_eur
)
session.add(search_cache)
session.commit()
start_idx = (page - 1) * page_size
end_idx = page * page_size
return {
"results": search_results[start_idx:end_idx],
"total": len(search_results),
"page": page,
"page_size": page_size,
"slug": slug,
"tokens": tokens,
"cost_eur": cost_eur
}
except Exception as e:
print(f"Search error: {str(e)}")
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
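# Re-fetch a previously cached search by its slug without re-embedding the query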
@app.get("/api/search/{slug}")
async def get_cached_search(slug: str):
with Session(engine) as session:
cached = session.exec(select(SearchResult).where(SearchResult.slug == slug)).first()
if not cached:
raise HTTPException(status_code=404, detail="Search not found")
results_data = json.loads(cached.results_json)
return {
"query": cached.query,
"results": results_data["results"],
"total": results_data["total"],
"slug": slug,
"tokens": cached.tokens_used,
"cost_eur": cached.cost_eur
}
@app.post("/api/prompt")
async def execute_prompt_endpoint(query: str):
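    # NOTE: FastAPI treats a bare scalar parameter on a POST route as a query parameter,
    # so clients call e.g. POST http://localhost:9900/api/prompt?query=summarize%20chapter%201
    # (the example query here is illustrative only)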
try:
slug = generate_slug(query)
# Check cache
with Session(engine) as session:
cached = session.exec(select(PromptResult).where(PromptResult.slug == slug)).first()
if cached:
return {
"response": cached.prompt_response,
"search_results": json.loads(cached.search_results_json),
"slug": slug,
"input_tokens": cached.input_tokens,
"output_tokens": cached.output_tokens,
"cost_eur": cached.cost_eur,
"cached": True
}
# First get search results
search_response = await search_documents(query, page=1, page_size=20)
# Execute prompt with results
response_text, input_tokens, output_tokens, cost_eur = await execute_prompt(
query,
search_response["results"]
)
# Add search cost to total
total_cost = cost_eur + search_response.get("cost_eur", 0)
# Cache prompt result
with Session(engine) as session:
prompt_cache = PromptResult(
query=query,
slug=slug,
prompt_response=response_text,
search_results_json=json.dumps(search_response["results"]),
created_at=datetime.now(),
input_tokens=input_tokens,
output_tokens=output_tokens,
cost_eur=total_cost
)
session.add(prompt_cache)
session.commit()
return {
"response": response_text,
"search_results": search_response["results"],
"slug": slug,
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"cost_eur": total_cost
}
except Exception as e:
print(f"Prompt error: {str(e)}")
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
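# Retrieve a previously executed prompt result by slug; no new OpenAI calls are made here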
@app.get("/api/prompt/{slug}")
async def get_cached_prompt(slug: str):
with Session(engine) as session:
cached = session.exec(select(PromptResult).where(PromptResult.slug == slug)).first()
if not cached:
raise HTTPException(status_code=404, detail="Prompt result not found")
return {
"query": cached.query,
"response": cached.prompt_response,
"search_results": json.loads(cached.search_results_json),
"slug": slug,
"input_tokens": cached.input_tokens,
"output_tokens": cached.output_tokens,
"cost_eur": cached.cost_eur
}
@app.get("/api/documents")
async def list_documents():
with Session(engine) as session:
documents = session.exec(select(Document)).all()
return [{
"id": doc.id,
"name": doc.name,
"upload_time": doc.upload_time.isoformat(),
"status": doc.status,
"downloads": doc.downloads
} for doc in documents]
@app.get("/api/document/{document_id}")
async def get_document(document_id: str):
with Session(engine) as session:
doc = session.get(Document, document_id)
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
return {
"id": doc.id,
"name": doc.name,
"markdown_content": doc.markdown_content,
"upload_time": doc.upload_time.isoformat(),
"status": doc.status
}
@app.get("/api/download/{document_id}")
async def download_document(document_id: str):
with Session(engine) as session:
doc = session.get(Document, document_id)
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
doc.downloads += 1
session.add(doc)
session.commit()
filepath = os.path.join(UPLOAD_DIR, doc.filename)
if not os.path.exists(filepath):
raise HTTPException(status_code=404, detail="File not found")
return FileResponse(filepath, filename=doc.name)
@app.websocket("/ws/status/{document_id}")
async def websocket_endpoint(websocket: WebSocket, document_id: str):
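    # The client connects to ws://<host>/ws/status/{document_id} after uploading; progress updates are
    # pushed by process_document through ConnectionManager, and this receive loop just keeps the socket open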
await manager.connect(document_id, websocket)
try:
while True:
await websocket.receive_text()
except WebSocketDisconnect:
manager.disconnect(document_id)
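# Serves the single-page frontend; index.html is resolved relative to the working directory the app runs from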
@app.get("/")
async def read_root():
return FileResponse("index.html")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=9900)