import asyncio
import hashlib
import io
import json
import os
import re
import uuid
from datetime import datetime
from typing import Optional, List
from contextlib import asynccontextmanager

from fastapi import FastAPI, UploadFile, File, Form, WebSocket, WebSocketDisconnect, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from sqlmodel import Field, SQLModel, create_engine, Session, select
import chromadb
from chromadb.config import Settings
import tiktoken
import openai
import pypdf

# Configuration
UPLOAD_DIR = "uploads"
DB_URL = "sqlite:///./database.db"
CHROMA_DIR = "./chroma_db"
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "your-api-key-here")
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50

os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(CHROMA_DIR, exist_ok=True)

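# The OpenAI key is read from the environment; a typical way to provide it before starting
# the server (illustrative):
#   export OPENAI_API_KEY=sk-...
# The "your-api-key-here" fallback is only a placeholder and will not authenticate.
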
# Database Models
class Document(SQLModel, table=True):
    id: Optional[str] = Field(default=None, primary_key=True)
    name: str
    filename: str
    markdown_content: str = Field(default="")
    upload_time: datetime
    status: str = "pending"
    downloads: int = 0

class SearchResult(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    query: str
    slug: str = Field(index=True)
    results_json: str
    created_at: datetime
    tokens_used: int = 0
    cost_eur: float = 0.0

class PromptResult(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    query: str
    slug: str = Field(index=True)
    prompt_response: str
    search_results_json: str
    created_at: datetime
    input_tokens: int = 0
    output_tokens: int = 0
    cost_eur: float = 0.0

# Database Setup
engine = create_engine(DB_URL, connect_args={"check_same_thread": False})

def init_db():
    SQLModel.metadata.create_all(engine)

# ChromaDB Setup
chroma_client = chromadb.PersistentClient(path=CHROMA_DIR, settings=Settings(anonymized_telemetry=False))
collection = chroma_client.get_or_create_collection("documents")

# OpenAI Setup
openai.api_key = OPENAI_API_KEY
tokenizer = tiktoken.get_encoding("cl100k_base")

# WebSocket Manager
class ConnectionManager:
    def __init__(self):
        self.active_connections: dict[str, WebSocket] = {}

    async def connect(self, document_id: str, websocket: WebSocket):
        await websocket.accept()
        self.active_connections[document_id] = websocket

    def disconnect(self, document_id: str):
        if document_id in self.active_connections:
            del self.active_connections[document_id]

    async def send_message(self, document_id: str, message: dict):
        if document_id in self.active_connections:
            try:
                await self.active_connections[document_id].send_json(message)
            except Exception:
                # The client went away mid-send; drop the stale connection.
                self.disconnect(document_id)

manager = ConnectionManager()

# Lifespan
@asynccontextmanager
async def lifespan(app: FastAPI):
    init_db()
    yield

# FastAPI App
app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Helper Functions
def generate_slug(text: str) -> str:
    """Generate URL-friendly slug from text"""
    slug = re.sub(r'[^\w\s-]', '', text.lower())
    slug = re.sub(r'[-\s]+', '-', slug)
    return slug[:50] + "-" + hashlib.md5(text.encode()).hexdigest()[:8]

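# Illustrative example (hypothetical query): generate_slug("What are the key findings?")
# returns something like "what-are-the-key-findings-XXXXXXXX", i.e. the lowercased, hyphenated
# query truncated to 50 characters plus the first 8 hex characters of its MD5 hash, so the
# same query always maps to the same cache slug.
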
def convert_to_markdown(filepath: str) -> str:
    """Convert various document formats to markdown"""
    ext = os.path.splitext(filepath)[1].lower()

    if ext == '.pdf':
        markdown_text = ""
        with open(filepath, 'rb') as f:
            pdf_reader = pypdf.PdfReader(f)
            for page_num, page in enumerate(pdf_reader.pages):
                text = page.extract_text()
                markdown_text += f"\n\n## Page {page_num + 1}\n\n{text}"
        return markdown_text

    elif ext == '.md':
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            return f.read()

    else:
        # Plain-text fallback: pass lines through, reducing whitespace-only lines to empty lines.
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
        lines = content.split('\n')
        markdown_lines = []
        for line in lines:
            if line.strip():
                markdown_lines.append(line)
            else:
                markdown_lines.append('')
        return '\n'.join(markdown_lines)

def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[dict]:
    tokens = tokenizer.encode(text)
    chunks = []
    start = 0

    # Extract page numbers from markdown headers
    page_pattern = re.compile(r'##\s*Page\s+(\d+)', re.IGNORECASE)

    while start < len(tokens):
        end = start + chunk_size
        chunk_tokens = tokens[start:end]
        chunk_str = tokenizer.decode(chunk_tokens)

        # Find page number in chunk
        page_match = page_pattern.search(chunk_str)
        page_num = int(page_match.group(1)) if page_match else 1

        chunks.append({
            "text": chunk_str,
            "page": page_num
        })
        start = end - overlap

    return chunks

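# Sketch of the window arithmetic above (with the defaults CHUNK_SIZE=500, CHUNK_OVERLAP=50):
# chunk 0 covers tokens [0, 500), chunk 1 covers [450, 950), chunk 2 covers [900, 1400), ...
# so consecutive chunks share 50 tokens and a short sentence cut at a boundary still appears
# whole in at least one chunk.
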
async def get_embedding(text: str) -> tuple[List[float], int]:
    # The OpenAI client call is blocking, so run it in a worker thread to keep the event loop free.
    response = await asyncio.to_thread(
        openai.embeddings.create,
        input=text,
        model="text-embedding-3-small"
    )
    tokens_used = response.usage.total_tokens
    return response.data[0].embedding, tokens_used

async def detect_prompt_intent(query: str) -> bool:
    """Detect if query is a prompt vs simple search"""
    prompt_keywords = ['make', 'create', 'list', 'summarize', 'explain', 'compare', 'analyze', 'generate', 'write', 'show me', 'give me', 'find all', 'extract', 'what are', 'how many']
    query_lower = query.lower()

    # Check for question words or action verbs
    if any(keyword in query_lower for keyword in prompt_keywords):
        return True
    if query_lower.endswith('?') and len(query.split()) > 3:
        return True
    return False

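# Illustrative behaviour of the heuristic above (hypothetical queries):
#   "summarize the quarterly report"  -> True   (contains the keyword "summarize")
#   "when is the deadline mentioned?" -> True   (ends with "?" and has more than 3 words)
#   "invoice 2023"                    -> False  (treated as a plain search)
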
async def execute_prompt(query: str, search_results: List[dict]) -> tuple[str, int, int, float]:
    """Execute prompt using GPT-4o-mini with search results as context"""

    # Prepare context from search results
    context = "\n\n".join([
        f"Document: {r['name']}\nPage: {r.get('page', 'N/A')}\nContent: {r['snippet']}"
        for r in search_results[:10]  # Use top 10 results
    ])

    messages = [
        {"role": "system", "content": "You are a helpful assistant that answers questions based on the provided document context. Be concise and accurate."},
        {"role": "user", "content": f"Context from documents:\n\n{context}\n\nUser query: {query}\n\nPlease answer based on the context provided."}
    ]

    response = await asyncio.to_thread(
        openai.chat.completions.create,
        model="gpt-4o-mini",
        messages=messages,
        temperature=0.7,
        max_tokens=1000
    )

    input_tokens = response.usage.prompt_tokens
    output_tokens = response.usage.completion_tokens

    # Calculate cost: GPT-4o-mini pricing
    # $0.150 per 1M input tokens, $0.600 per 1M output tokens
    cost_usd = (input_tokens / 1_000_000 * 0.150) + (output_tokens / 1_000_000 * 0.600)
    cost_eur = cost_usd * 0.92

    return response.choices[0].message.content, input_tokens, output_tokens, cost_eur

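# Worked example of the cost formula above (illustrative numbers): 2,000 input tokens and
# 500 output tokens cost (2000 / 1e6) * 0.150 + (500 / 1e6) * 0.600 = $0.0006, which is
# about 0.00055 EUR at the fixed 0.92 USD-to-EUR rate used throughout this file.
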
async def process_document(document_id: str, filepath: str):
    total_tokens = 0

    await manager.send_message(document_id, {
        "step": "reading",
        "progress": 5,
        "message": "Reading file...",
        "tokens": 0,
        "cost_eur": 0.0
    })

    try:
        await manager.send_message(document_id, {
            "step": "converting",
            "progress": 15,
            "message": "Converting to markdown...",
            "tokens": 0,
            "cost_eur": 0.0
        })

        content = await asyncio.to_thread(convert_to_markdown, filepath)

        with Session(engine) as session:
            doc = session.get(Document, document_id)
            if doc:
                doc.markdown_content = content
                session.add(doc)
                session.commit()

        await manager.send_message(document_id, {
            "step": "chunking",
            "progress": 25,
            "message": "Splitting text into chunks...",
            "tokens": 0,
            "cost_eur": 0.0
        })

        chunks = chunk_text(content)

        await manager.send_message(document_id, {
            "step": "embedding",
            "progress": 35,
            "message": f"Processing {len(chunks)} chunks...",
            "tokens": 0,
            "cost_eur": 0.0
        })

        for i, chunk_data in enumerate(chunks):
            embedding, tokens = await get_embedding(chunk_data["text"])
            total_tokens += tokens

            # text-embedding-3-small: $0.02 per 1M tokens ($0.00002 per 1K), converted at a fixed 0.92 USD-to-EUR rate
            cost_usd = (total_tokens / 1000) * 0.00002
            cost_eur = cost_usd * 0.92

            collection.add(
                ids=[f"{document_id}_chunk_{i}"],
                embeddings=[embedding],
                documents=[chunk_data["text"]],
                metadatas=[{
                    "document_id": document_id,
                    "chunk_index": i,
                    "total_chunks": len(chunks),
                    "page": chunk_data["page"]
                }]
            )

            progress = 35 + int((i + 1) / len(chunks) * 55)
            await manager.send_message(document_id, {
                "step": "embedding",
                "progress": progress,
                "message": f"Embedded chunk {i + 1}/{len(chunks)}",
                "tokens": total_tokens,
                "cost_eur": cost_eur
            })

        await manager.send_message(document_id, {
            "step": "indexing",
            "progress": 95,
            "message": "Finalizing index...",
            "tokens": total_tokens,
            "cost_eur": (total_tokens / 1000) * 0.00002 * 0.92
        })

        with Session(engine) as session:
            doc = session.get(Document, document_id)
            if doc:
                doc.status = "completed"
                session.add(doc)
                session.commit()

        await manager.send_message(document_id, {
            "step": "completed",
            "progress": 100,
            "message": "Processing complete!",
            "tokens": total_tokens,
            "cost_eur": (total_tokens / 1000) * 0.00002 * 0.92
        })

    except Exception as e:
        with Session(engine) as session:
            doc = session.get(Document, document_id)
            if doc:
                doc.status = "failed"
                session.add(doc)
                session.commit()

        await manager.send_message(document_id, {
            "step": "error",
            "progress": 0,
            "message": f"Error: {str(e)}",
            "tokens": total_tokens,
            "cost_eur": 0.0
        })

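# For reference, the progress milestones pushed over the WebSocket by process_document:
#   reading 5% -> converting 15% -> chunking 25% -> embedding 35-90% -> indexing 95% -> completed 100%
# (or a terminal "error" message if any step raises).
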
# API Endpoints
@app.post("/api/upload")
async def upload_documents(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...)
):
    uploaded_docs = []

    for file in files:
        doc_id = str(uuid.uuid4())
        file_ext = os.path.splitext(file.filename)[1]
        filename = f"{doc_id}{file_ext}"
        filepath = os.path.join(UPLOAD_DIR, filename)

        content = await file.read()
        with open(filepath, 'wb') as f:
            f.write(content)

        document = Document(
            id=doc_id,
            name=file.filename,
            filename=filename,
            upload_time=datetime.now(),
            status="processing"
        )

        with Session(engine) as session:
            session.add(document)
            session.commit()

        background_tasks.add_task(process_document, doc_id, filepath)

        uploaded_docs.append({
            "document_id": doc_id,
            "name": file.filename,
            "status": "processing"
        })

    return {"documents": uploaded_docs}

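# Example upload (hypothetical client, assuming the default host/port from the bottom of this
# file): each file goes in a separate "files" form field and is processed in the background:
#   curl -F "files=@report.pdf" -F "files=@notes.md" http://localhost:9900/api/upload
# The response carries one {"document_id", "name", "status"} entry per file; progress can then
# be followed on the /ws/status/{document_id} WebSocket.
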
@app.get("/api/search")
async def search_documents(query: str, page: int = 1, page_size: int = 10):
    try:
        slug = generate_slug(query)

        # Check cache first
        with Session(engine) as session:
            cached = session.exec(select(SearchResult).where(SearchResult.slug == slug)).first()
            if cached:
                results_data = json.loads(cached.results_json)
                # Apply the same pagination to cached results as to fresh ones.
                start_idx = (page - 1) * page_size
                end_idx = page * page_size
                return {
                    "results": results_data["results"][start_idx:end_idx],
                    "total": results_data["total"],
                    "page": page,
                    "page_size": page_size,
                    "slug": slug,
                    "tokens": cached.tokens_used,
                    "cost_eur": cached.cost_eur,
                    "cached": True
                }

        collection_count = collection.count()
        if collection_count == 0:
            return {
                "results": [],
                "total": 0,
                "page": page,
                "page_size": page_size,
                "slug": slug,
                "tokens": 0,
                "cost_eur": 0.0
            }

        query_embedding, tokens = await get_embedding(query)
        cost_eur = (tokens / 1000) * 0.00002 * 0.92

        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=min(50, collection_count)
        )

        search_results = []

        if results['ids'] and len(results['ids'][0]) > 0:
            for i in range(len(results['ids'][0])):
                document_id = results['metadatas'][0][i]['document_id']
                snippet = results['documents'][0][i]
                distance = results['distances'][0][i] if 'distances' in results else 0
                page_num = results['metadatas'][0][i].get('page', 1)

                with Session(engine) as session:
                    doc = session.get(Document, document_id)
                    if doc:
                        search_results.append({
                            "document_id": document_id,
                            "name": doc.name,
                            "snippet": snippet[:300] + "..." if len(snippet) > 300 else snippet,
                            "chunk_index": results['metadatas'][0][i]['chunk_index'],
                            "page": page_num,
                            "upload_date": doc.upload_time.isoformat(),
                            "score": 1 - distance
                        })

        # Cache results
        with Session(engine) as session:
            search_cache = SearchResult(
                query=query,
                slug=slug,
                results_json=json.dumps({"results": search_results, "total": len(search_results)}),
                created_at=datetime.now(),
                tokens_used=tokens,
                cost_eur=cost_eur
            )
            session.add(search_cache)
            session.commit()

        start_idx = (page - 1) * page_size
        end_idx = page * page_size

        return {
            "results": search_results[start_idx:end_idx],
            "total": len(search_results),
            "page": page,
            "page_size": page_size,
            "slug": slug,
            "tokens": tokens,
            "cost_eur": cost_eur
        }

    except Exception as e:
        print(f"Search error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))

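# Example response entry (illustrative shape, not real data): a hit looks like
#   {"document_id": "...", "name": "report.pdf", "snippet": "...", "chunk_index": 4,
#    "page": 3, "upload_date": "2024-01-01T12:00:00", "score": 0.42}
# where score is 1 minus the vector distance reported by Chroma, and the returned slug lets
# the same result set be reloaded later via GET /api/search/{slug} without re-embedding.
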
@app.get("/api/search/{slug}")
async def get_cached_search(slug: str):
    with Session(engine) as session:
        cached = session.exec(select(SearchResult).where(SearchResult.slug == slug)).first()
        if not cached:
            raise HTTPException(status_code=404, detail="Search not found")

        results_data = json.loads(cached.results_json)
        return {
            "query": cached.query,
            "results": results_data["results"],
            "total": results_data["total"],
            "slug": slug,
            "tokens": cached.tokens_used,
            "cost_eur": cached.cost_eur
        }

@app.post("/api/prompt")
async def execute_prompt_endpoint(query: str):
    try:
        slug = generate_slug(query)

        # Check cache
        with Session(engine) as session:
            cached = session.exec(select(PromptResult).where(PromptResult.slug == slug)).first()
            if cached:
                return {
                    "response": cached.prompt_response,
                    "search_results": json.loads(cached.search_results_json),
                    "slug": slug,
                    "input_tokens": cached.input_tokens,
                    "output_tokens": cached.output_tokens,
                    "cost_eur": cached.cost_eur,
                    "cached": True
                }

        # First get search results
        search_response = await search_documents(query, page=1, page_size=20)

        # Execute prompt with results
        response_text, input_tokens, output_tokens, cost_eur = await execute_prompt(
            query,
            search_response["results"]
        )

        # Add search cost to total
        total_cost = cost_eur + search_response.get("cost_eur", 0)

        # Cache prompt result
        with Session(engine) as session:
            prompt_cache = PromptResult(
                query=query,
                slug=slug,
                prompt_response=response_text,
                search_results_json=json.dumps(search_response["results"]),
                created_at=datetime.now(),
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cost_eur=total_cost
            )
            session.add(prompt_cache)
            session.commit()

        return {
            "response": response_text,
            "search_results": search_response["results"],
            "slug": slug,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost_eur": total_cost
        }

    except Exception as e:
        print(f"Prompt error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))

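# Example call (hypothetical): the prompt endpoint takes `query` as a query parameter, e.g.
#   curl -X POST "http://localhost:9900/api/prompt?query=summarize%20the%20uploaded%20reports"
# It reuses the /api/search pipeline for context, sends the top hits to GPT-4o-mini, and
# caches the answer under a slug retrievable from GET /api/prompt/{slug}.
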
@app.get("/api/prompt/{slug}")
async def get_cached_prompt(slug: str):
    with Session(engine) as session:
        cached = session.exec(select(PromptResult).where(PromptResult.slug == slug)).first()
        if not cached:
            raise HTTPException(status_code=404, detail="Prompt result not found")

        return {
            "query": cached.query,
            "response": cached.prompt_response,
            "search_results": json.loads(cached.search_results_json),
            "slug": slug,
            "input_tokens": cached.input_tokens,
            "output_tokens": cached.output_tokens,
            "cost_eur": cached.cost_eur
        }

@app.get("/api/documents")
async def list_documents():
    with Session(engine) as session:
        documents = session.exec(select(Document)).all()
        return [{
            "id": doc.id,
            "name": doc.name,
            "upload_time": doc.upload_time.isoformat(),
            "status": doc.status,
            "downloads": doc.downloads
        } for doc in documents]

@app.get("/api/document/{document_id}")
async def get_document(document_id: str):
    with Session(engine) as session:
        doc = session.get(Document, document_id)
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")

        return {
            "id": doc.id,
            "name": doc.name,
            "markdown_content": doc.markdown_content,
            "upload_time": doc.upload_time.isoformat(),
            "status": doc.status
        }

@app.get("/api/download/{document_id}")
async def download_document(document_id: str):
    with Session(engine) as session:
        doc = session.get(Document, document_id)
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")

        doc.downloads += 1
        session.add(doc)
        session.commit()

        filepath = os.path.join(UPLOAD_DIR, doc.filename)
        if not os.path.exists(filepath):
            raise HTTPException(status_code=404, detail="File not found")

        return FileResponse(filepath, filename=doc.name)

@app.websocket("/ws/status/{document_id}")
async def websocket_endpoint(websocket: WebSocket, document_id: str):
    await manager.connect(document_id, websocket)
    try:
        # Keep the socket open; status updates are pushed by process_document via the manager.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        manager.disconnect(document_id)

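# Minimal client sketch (hypothetical, using the third-party `websockets` package) for watching
# processing progress; the server only pushes JSON status messages and ignores client input:
#
#   import asyncio, json, websockets
#
#   async def watch(document_id: str):
#       async with websockets.connect(f"ws://localhost:9900/ws/status/{document_id}") as ws:
#           while True:
#               status = json.loads(await ws.recv())
#               print(status["step"], status["progress"], status["message"])
#               if status["step"] in ("completed", "error"):
#                   break
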
@app.get("/")
async def read_root():
    return FileResponse("index.html")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=9900)