import asyncio
import hashlib
import io
import json
import os
import re
import uuid
from datetime import datetime
from typing import Optional, List
from contextlib import asynccontextmanager

from fastapi import (
    FastAPI, UploadFile, File, Form, WebSocket, WebSocketDisconnect,
    HTTPException, BackgroundTasks,
)
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from sqlmodel import Field, SQLModel, create_engine, Session, select
import chromadb
from chromadb.config import Settings
import tiktoken
import openai
import pypdf

# Configuration
UPLOAD_DIR = "uploads"
DB_URL = "sqlite:///./database.db"
CHROMA_DIR = "./chroma_db"
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "your-api-key-here")
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50

os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(CHROMA_DIR, exist_ok=True)


# Database Models
class Document(SQLModel, table=True):
    id: Optional[str] = Field(default=None, primary_key=True)
    name: str
    filename: str
    markdown_content: str = Field(default="")
    upload_time: datetime
    status: str = "pending"
    downloads: int = 0


class SearchResult(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    query: str
    slug: str = Field(index=True)
    results_json: str
    created_at: datetime
    tokens_used: int = 0
    cost_eur: float = 0.0


class PromptResult(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    query: str
    slug: str = Field(index=True)
    prompt_response: str
    search_results_json: str
    created_at: datetime
    input_tokens: int = 0
    output_tokens: int = 0
    cost_eur: float = 0.0


# Database Setup
engine = create_engine(DB_URL, connect_args={"check_same_thread": False})


def init_db():
    SQLModel.metadata.create_all(engine)


# ChromaDB Setup
chroma_client = chromadb.PersistentClient(
    path=CHROMA_DIR, settings=Settings(anonymized_telemetry=False)
)
collection = chroma_client.get_or_create_collection("documents")

# OpenAI Setup
openai.api_key = OPENAI_API_KEY
tokenizer = tiktoken.get_encoding("cl100k_base")


# WebSocket Manager
class ConnectionManager:
    def __init__(self):
        self.active_connections: dict[str, WebSocket] = {}

    async def connect(self, document_id: str, websocket: WebSocket):
        await websocket.accept()
        self.active_connections[document_id] = websocket

    def disconnect(self, document_id: str):
        if document_id in self.active_connections:
            del self.active_connections[document_id]

    async def send_message(self, document_id: str, message: dict):
        if document_id in self.active_connections:
            try:
                await self.active_connections[document_id].send_json(message)
            except Exception:
                self.disconnect(document_id)


manager = ConnectionManager()


# Lifespan
@asynccontextmanager
async def lifespan(app: FastAPI):
    init_db()
    yield


# FastAPI App
app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Helper Functions
def generate_slug(text: str) -> str:
    """Generate a URL-friendly slug from text."""
    slug = re.sub(r'[^\w\s-]', '', text.lower())
    slug = re.sub(r'[-\s]+', '-', slug)
    return slug[:50] + "-" + hashlib.md5(text.encode()).hexdigest()[:8]


def convert_to_markdown(filepath: str) -> str:
    """Convert various document formats to markdown."""
    ext = os.path.splitext(filepath)[1].lower()
    if ext == '.pdf':
        markdown_text = ""
        with open(filepath, 'rb') as f:
            pdf_reader = pypdf.PdfReader(f)
            for page_num, page in enumerate(pdf_reader.pages):
                text = page.extract_text()
                markdown_text += f"\n\n## Page {page_num + 1}\n\n{text}"
        return markdown_text
    elif ext == '.md':
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            return f.read()
    else:
        # Plain text: pass through line by line
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
        lines = content.split('\n')
        markdown_lines = []
        for line in lines:
            if line.strip():
                markdown_lines.append(line)
            else:
                markdown_lines.append('')
        return '\n'.join(markdown_lines)
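# Illustrative sketch of the window arithmetic used by chunk_text below (not
# part of the request flow): with CHUNK_SIZE = 500 and CHUNK_OVERLAP = 50,
# consecutive chunks start every 450 tokens, so a 1,200-token document yields
# windows starting at tokens 0, 450 and 900, each (except possibly the last)
# 500 tokens long:
#
#   starts = list(range(0, 1_200, CHUNK_SIZE - CHUNK_OVERLAP))  # [0, 450, 900]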
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[dict]:
    tokens = tokenizer.encode(text)
    chunks = []
    start = 0
    # Extract page numbers from markdown headers
    page_pattern = re.compile(r'##\s*Page\s+(\d+)', re.IGNORECASE)
    while start < len(tokens):
        end = start + chunk_size
        chunk_tokens = tokens[start:end]
        chunk_str = tokenizer.decode(chunk_tokens)
        # Find page number in chunk
        page_match = page_pattern.search(chunk_str)
        page_num = int(page_match.group(1)) if page_match else 1
        chunks.append({
            "text": chunk_str,
            "page": page_num
        })
        start = end - overlap
    return chunks


async def get_embedding(text: str) -> tuple[List[float], int]:
    response = await asyncio.to_thread(
        openai.embeddings.create,
        input=text,
        model="text-embedding-3-small"
    )
    tokens_used = response.usage.total_tokens
    return response.data[0].embedding, tokens_used


async def detect_prompt_intent(query: str) -> bool:
    """Detect whether a query is a prompt rather than a simple search."""
    prompt_keywords = ['make', 'create', 'list', 'summarize', 'explain', 'compare',
                       'analyze', 'generate', 'write', 'show me', 'give me',
                       'find all', 'extract', 'what are', 'how many']
    query_lower = query.lower()
    # Check for question words or action verbs
    if any(keyword in query_lower for keyword in prompt_keywords):
        return True
    if query_lower.endswith('?') and len(query.split()) > 3:
        return True
    return False


async def execute_prompt(query: str, search_results: List[dict]) -> tuple[str, int, int, float]:
    """Execute a prompt using GPT-4o-mini with search results as context."""
    # Prepare context from search results
    context = "\n\n".join([
        f"Document: {r['name']}\nPage: {r.get('page', 'N/A')}\nContent: {r['snippet']}"
        for r in search_results[:10]  # Use top 10 results
    ])
    messages = [
        {"role": "system", "content": "You are a helpful assistant that answers questions based on the provided document context. Be concise and accurate."},
        {"role": "user", "content": f"Context from documents:\n\n{context}\n\nUser query: {query}\n\nPlease answer based on the context provided."}
    ]
    response = await asyncio.to_thread(
        openai.chat.completions.create,
        model="gpt-4o-mini",
        messages=messages,
        temperature=0.7,
        max_tokens=1000
    )
    input_tokens = response.usage.prompt_tokens
    output_tokens = response.usage.completion_tokens
    # Calculate cost: GPT-4o-mini pricing
    # $0.150 per 1M input tokens, $0.600 per 1M output tokens
    cost_usd = (input_tokens / 1_000_000 * 0.150) + (output_tokens / 1_000_000 * 0.600)
    cost_eur = cost_usd * 0.92
    return response.choices[0].message.content, input_tokens, output_tokens, cost_eur
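# Worked example of the cost conversion above (assuming the listed prices and
# the flat 0.92 USD->EUR rate used throughout): a call with 2,000 prompt
# tokens and 500 completion tokens costs
#   (2_000 / 1_000_000 * 0.150) + (500 / 1_000_000 * 0.600) = 0.0006 USD
# which is reported as 0.0006 * 0.92 ≈ 0.00055 EUR.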
async def process_document(document_id: str, filepath: str):
    total_tokens = 0
    await manager.send_message(document_id, {
        "step": "reading", "progress": 5, "message": "Reading file...",
        "tokens": 0, "cost_eur": 0.0
    })
    try:
        await manager.send_message(document_id, {
            "step": "converting", "progress": 15, "message": "Converting to markdown...",
            "tokens": 0, "cost_eur": 0.0
        })
        content = await asyncio.to_thread(convert_to_markdown, filepath)

        with Session(engine) as session:
            doc = session.get(Document, document_id)
            if doc:
                doc.markdown_content = content
                session.add(doc)
                session.commit()

        await manager.send_message(document_id, {
            "step": "chunking", "progress": 25, "message": "Splitting text into chunks...",
            "tokens": 0, "cost_eur": 0.0
        })
        chunks = chunk_text(content)

        await manager.send_message(document_id, {
            "step": "embedding", "progress": 35,
            "message": f"Processing {len(chunks)} chunks...",
            "tokens": 0, "cost_eur": 0.0
        })
        for i, chunk_data in enumerate(chunks):
            embedding, tokens = await get_embedding(chunk_data["text"])
            total_tokens += tokens
            cost_usd = (total_tokens / 1000) * 0.00002
            cost_eur = cost_usd * 0.92
            collection.add(
                ids=[f"{document_id}_chunk_{i}"],
                embeddings=[embedding],
                documents=[chunk_data["text"]],
                metadatas=[{
                    "document_id": document_id,
                    "chunk_index": i,
                    "total_chunks": len(chunks),
                    "page": chunk_data["page"]
                }]
            )
            progress = 35 + int((i + 1) / len(chunks) * 55)
            await manager.send_message(document_id, {
                "step": "embedding", "progress": progress,
                "message": f"Embedded chunk {i + 1}/{len(chunks)}",
                "tokens": total_tokens, "cost_eur": cost_eur
            })

        await manager.send_message(document_id, {
            "step": "indexing", "progress": 95, "message": "Finalizing index...",
            "tokens": total_tokens,
            "cost_eur": (total_tokens / 1000) * 0.00002 * 0.92
        })

        with Session(engine) as session:
            doc = session.get(Document, document_id)
            if doc:
                doc.status = "completed"
                session.add(doc)
                session.commit()

        await manager.send_message(document_id, {
            "step": "completed", "progress": 100, "message": "Processing complete!",
            "tokens": total_tokens,
            "cost_eur": (total_tokens / 1000) * 0.00002 * 0.92
        })
    except Exception as e:
        with Session(engine) as session:
            doc = session.get(Document, document_id)
            if doc:
                doc.status = "failed"
                session.add(doc)
                session.commit()
        await manager.send_message(document_id, {
            "step": "error", "progress": 0, "message": f"Error: {str(e)}",
            "tokens": total_tokens, "cost_eur": 0.0
        })
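# Shape of the progress messages process_document pushes over the WebSocket
# (values shown are illustrative):
#
#   {"step": "embedding", "progress": 62, "message": "Embedded chunk 12/24",
#    "tokens": 5130, "cost_eur": 0.0001}
#
# "step" moves through reading -> converting -> chunking -> embedding ->
# indexing -> completed, or "error" on failure.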
# API Endpoints
@app.post("/api/upload")
async def upload_documents(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...)
):
    uploaded_docs = []
    for file in files:
        doc_id = str(uuid.uuid4())
        file_ext = os.path.splitext(file.filename)[1]
        filename = f"{doc_id}{file_ext}"
        filepath = os.path.join(UPLOAD_DIR, filename)

        content = await file.read()
        with open(filepath, 'wb') as f:
            f.write(content)

        document = Document(
            id=doc_id,
            name=file.filename,
            filename=filename,
            upload_time=datetime.now(),
            status="processing"
        )
        with Session(engine) as session:
            session.add(document)
            session.commit()

        background_tasks.add_task(process_document, doc_id, filepath)
        uploaded_docs.append({
            "document_id": doc_id,
            "name": file.filename,
            "status": "processing"
        })
    return {"documents": uploaded_docs}
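# The upload endpoint responds immediately while embedding continues in the
# background; a typical response (with an illustrative ID) looks like:
#
#   {"documents": [{"document_id": "<uuid4>", "name": "report.pdf",
#                   "status": "processing"}]}
#
# Clients can then follow progress on /ws/status/<document_id>.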
@app.get("/api/search")
async def search_documents(query: str, page: int = 1, page_size: int = 10):
    try:
        slug = generate_slug(query)

        # Check cache first
        with Session(engine) as session:
            cached = session.exec(select(SearchResult).where(SearchResult.slug == slug)).first()
            if cached:
                results_data = json.loads(cached.results_json)
                return {
                    "results": results_data["results"],
                    "total": results_data["total"],
                    "page": page,
                    "page_size": page_size,
                    "slug": slug,
                    "tokens": cached.tokens_used,
                    "cost_eur": cached.cost_eur,
                    "cached": True
                }

        collection_count = collection.count()
        if collection_count == 0:
            return {
                "results": [],
                "total": 0,
                "page": page,
                "page_size": page_size,
                "slug": slug,
                "tokens": 0,
                "cost_eur": 0.0
            }

        query_embedding, tokens = await get_embedding(query)
        cost_eur = (tokens / 1000) * 0.00002 * 0.92

        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=min(50, collection_count)
        )

        search_results = []
        if results['ids'] and len(results['ids'][0]) > 0:
            for i in range(len(results['ids'][0])):
                document_id = results['metadatas'][0][i]['document_id']
                snippet = results['documents'][0][i]
                distance = results['distances'][0][i] if 'distances' in results else 0
                page_num = results['metadatas'][0][i].get('page', 1)
                with Session(engine) as session:
                    doc = session.get(Document, document_id)
                    if doc:
                        search_results.append({
                            "document_id": document_id,
                            "name": doc.name,
                            "snippet": (snippet[:300] + "...") if len(snippet) > 300 else snippet,
                            "chunk_index": results['metadatas'][0][i]['chunk_index'],
                            "page": page_num,
                            "upload_date": doc.upload_time.isoformat(),
                            # Convert distance to a similarity-style score
                            "score": 1 - distance
                        })

        # Cache results
        with Session(engine) as session:
            search_cache = SearchResult(
                query=query,
                slug=slug,
                results_json=json.dumps({"results": search_results, "total": len(search_results)}),
                created_at=datetime.now(),
                tokens_used=tokens,
                cost_eur=cost_eur
            )
            session.add(search_cache)
            session.commit()

        start_idx = (page - 1) * page_size
        end_idx = page * page_size
        return {
            "results": search_results[start_idx:end_idx],
            "total": len(search_results),
            "page": page,
            "page_size": page_size,
            "slug": slug,
            "tokens": tokens,
            "cost_eur": cost_eur
        }
    except Exception as e:
        print(f"Search error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/search/{slug}")
async def get_cached_search(slug: str):
    with Session(engine) as session:
        cached = session.exec(select(SearchResult).where(SearchResult.slug == slug)).first()
        if not cached:
            raise HTTPException(status_code=404, detail="Search not found")
        results_data = json.loads(cached.results_json)
        return {
            "query": cached.query,
            "results": results_data["results"],
            "total": results_data["total"],
            "slug": slug,
            "tokens": cached.tokens_used,
            "cost_eur": cached.cost_eur
        }


@app.post("/api/prompt")
async def execute_prompt_endpoint(query: str):
    try:
        slug = generate_slug(query)

        # Check cache
        with Session(engine) as session:
            cached = session.exec(select(PromptResult).where(PromptResult.slug == slug)).first()
            if cached:
                return {
                    "response": cached.prompt_response,
                    "search_results": json.loads(cached.search_results_json),
                    "slug": slug,
                    "input_tokens": cached.input_tokens,
                    "output_tokens": cached.output_tokens,
                    "cost_eur": cached.cost_eur,
                    "cached": True
                }

        # First get search results
        search_response = await search_documents(query, page=1, page_size=20)

        # Execute prompt with results
        response_text, input_tokens, output_tokens, cost_eur = await execute_prompt(
            query, search_response["results"]
        )

        # Add search cost to total
        total_cost = cost_eur + search_response.get("cost_eur", 0)

        # Cache prompt result
        with Session(engine) as session:
            prompt_cache = PromptResult(
                query=query,
                slug=slug,
                prompt_response=response_text,
                search_results_json=json.dumps(search_response["results"]),
                created_at=datetime.now(),
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cost_eur=total_cost
            )
            session.add(prompt_cache)
            session.commit()

        return {
            "response": response_text,
            "search_results": search_response["results"],
            "slug": slug,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost_eur": total_cost
        }
    except Exception as e:
        print(f"Prompt error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/prompt/{slug}")
async def get_cached_prompt(slug: str):
    with Session(engine) as session:
        cached = session.exec(select(PromptResult).where(PromptResult.slug == slug)).first()
        if not cached:
            raise HTTPException(status_code=404, detail="Prompt result not found")
        return {
            "query": cached.query,
            "response": cached.prompt_response,
            "search_results": json.loads(cached.search_results_json),
            "slug": slug,
            "input_tokens": cached.input_tokens,
            "output_tokens": cached.output_tokens,
            "cost_eur": cached.cost_eur
        }


@app.get("/api/documents")
async def list_documents():
    with Session(engine) as session:
        documents = session.exec(select(Document)).all()
        return [{
            "id": doc.id,
            "name": doc.name,
            "upload_time": doc.upload_time.isoformat(),
            "status": doc.status,
            "downloads": doc.downloads
        } for doc in documents]
"upload_time": doc.upload_time.isoformat(), "status": doc.status, "downloads": doc.downloads } for doc in documents] @app.get("/api/document/{document_id}") async def get_document(document_id: str): with Session(engine) as session: doc = session.get(Document, document_id) if not doc: raise HTTPException(status_code=404, detail="Document not found") return { "id": doc.id, "name": doc.name, "markdown_content": doc.markdown_content, "upload_time": doc.upload_time.isoformat(), "status": doc.status } @app.get("/api/download/{document_id}") async def download_document(document_id: str): with Session(engine) as session: doc = session.get(Document, document_id) if not doc: raise HTTPException(status_code=404, detail="Document not found") doc.downloads += 1 session.add(doc) session.commit() filepath = os.path.join(UPLOAD_DIR, doc.filename) if not os.path.exists(filepath): raise HTTPException(status_code=404, detail="File not found") return FileResponse(filepath, filename=doc.name) @app.websocket("/ws/status/{document_id}") async def websocket_endpoint(websocket: WebSocket, document_id: str): await manager.connect(document_id, websocket) try: while True: await websocket.receive_text() except WebSocketDisconnect: manager.disconnect(document_id) @app.get("/") async def read_root(): return FileResponse("index.html") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=9900)