Redmind's picture
Update app.py
c2710ab verified
raw
history blame
5.62 kB
from fastapi import FastAPI
import os
import pymupdf # PyMuPDF
from pptx import Presentation
from sentence_transformers import SentenceTransformer
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import chromadb
import numpy as np
from sklearn.decomposition import PCA
app = FastAPI()

# Vector store: one persistent ChromaDB collection that receives both the
# text vectors and the image vectors produced below.
client = chromadb.PersistentClient(path="/data/chroma_db")
collection = client.get_or_create_collection(name="knowledge_base")

# Source documents ingested by the "/" endpoint (names must match the files
# on disk exactly, including spelling).
pdf_file = "Sutures and Suturing techniques.pdf"
pptx_file = "impalnt 1.pptx"

# Embedding models: MiniLM for text; CLIP (model and processor loaded from
# the same checkpoint) for images.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
text_model = SentenceTransformer('all-MiniLM-L6-v2')
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Directory that receives images extracted from the PDF/PPTX; created at
# import time so the extractors can write into it unconditionally.
IMAGE_FOLDER = "/data/extracted_images"
os.makedirs(IMAGE_FOLDER, exist_ok=True)
# Extract Text from PDF
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of a PDF, joined by single spaces.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The stripped text, or ``None`` when the document contains no text
        (including whitespace-only output) or cannot be read. Read errors
        are printed rather than raised.
    """
    try:
        # The context manager closes the document (and its file handle)
        # even if extracting a page raises.
        with pymupdf.open(pdf_path) as doc:
            text = " ".join(page.get_text() for page in doc)
    except Exception as e:
        print(f"Error extracting text from PDF: {e}")
        return None
    # Whitespace-only output (e.g. a PDF without a text layer) counts as
    # "no text" instead of leaking through as an empty string.
    return text.strip() or None
# Extract Text from PPTX
def extract_text_from_pptx(pptx_path):
    """Return the text of every text-bearing shape in a PPTX deck.

    Shape texts are visited slide by slide, in shape order, and joined by
    single spaces.

    Args:
        pptx_path: Path to the .pptx file.

    Returns:
        The stripped text, or ``None`` when the deck contains no text
        (including whitespace-only output) or cannot be read. Read errors
        are printed rather than raised.
    """
    try:
        prs = Presentation(pptx_path)
        # Shapes with empty text (blank placeholders, etc.) are skipped so
        # they do not inject runs of separator spaces into the result.
        text = " ".join(
            shape.text
            for slide in prs.slides
            for shape in slide.shapes
            if hasattr(shape, "text") and shape.text
        )
    except Exception as e:
        print(f"Error extracting text from PPTX: {e}")
        return None
    return text.strip() or None
# Extract Images from PDF
def extract_images_from_pdf(pdf_path):
    """Write every embedded image of a PDF into ``IMAGE_FOLDER``.

    Files are named ``pdf_image_<page>_<index>.<ext>``, where ``<ext>`` is
    the format reported by PyMuPDF for that image.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        Paths of the image files that were written. An image that fails to
        extract or save is reported and skipped rather than discarding the
        whole document. If the PDF cannot be opened or iterated, the paths
        written so far (possibly none) are returned after printing the error.
    """
    images = []
    try:
        # Closing via the context manager releases the file handle even if
        # an exception escapes the loop.
        with pymupdf.open(pdf_path) as doc:
            for page_index, page in enumerate(doc):
                for img_index, img in enumerate(page.get_images(full=True)):
                    xref = img[0]
                    try:
                        image = doc.extract_image(xref)
                        if not image:
                            # Defensive: nothing usable for this xref.
                            continue
                        img_path = os.path.join(
                            IMAGE_FOLDER,
                            f"pdf_image_{page_index}_{img_index}.{image['ext']}",
                        )
                        with open(img_path, "wb") as f:
                            f.write(image["image"])
                    except Exception as e:
                        print(
                            f"Error extracting image {img_index} on page "
                            f"{page_index} from PDF: {e}"
                        )
                        continue
                    images.append(img_path)
    except Exception as e:
        print(f"Error extracting images from PDF: {e}")
    return images
# Extract Images from PPTX
def extract_images_from_pptx(pptx_path):
    """Write every picture shape of a PPTX deck into ``IMAGE_FOLDER``.

    Files are named ``pptx_image_<slide>_<shape>.<ext>``. The shape index
    is part of the name because several pictures on the same slide
    previously shared one file name, overwrote each other, and the same
    path was returned once per picture.

    Args:
        pptx_path: Path to the .pptx file.

    Returns:
        Paths of the image files that were written. A shape that fails
        (for example a picture without embedded image data) is reported and
        skipped. If the deck cannot be opened or iterated, the paths written
        so far (possibly none) are returned after printing the error.
    """
    picture_shape_type = 13  # python-pptx MSO_SHAPE_TYPE.PICTURE
    images = []
    try:
        prs = Presentation(pptx_path)
        for slide_index, slide in enumerate(prs.slides):
            for shape_index, shape in enumerate(slide.shapes):
                try:
                    if shape.shape_type != picture_shape_type:
                        continue
                    picture = shape.image
                    img_path = os.path.join(
                        IMAGE_FOLDER,
                        f"pptx_image_{slide_index}_{shape_index}.{picture.ext}",
                    )
                    with open(img_path, "wb") as f:
                        f.write(picture.blob)
                except Exception as e:
                    print(
                        f"Error extracting shape {shape_index} on slide "
                        f"{slide_index} from PPTX: {e}"
                    )
                    continue
                images.append(img_path)
    except Exception as e:
        print(f"Error extracting images from PPTX: {e}")
    return images
# Convert Text to Embeddings
def get_text_embedding(text):
    """Encode *text* with the module's sentence-transformer model.

    Returns:
        The embedding as a plain Python list of floats.
    """
    vector = text_model.encode(text)
    return vector.tolist()
# Extract Image Embeddings and Reduce to 384 Dimensions
def get_image_embedding(image_path, target_dim=384):
    """Embed an image with CLIP and map the vector to ``target_dim`` values.

    Image vectors are stored in the same collection as the 384-dimensional
    text vectors, so they must have the same length. The previous code fitted
    a fresh ``PCA(n_components=384)`` on a single sample. scikit-learn
    rejects that because ``n_components`` may not exceed the number of
    samples, so every image came back as ``None``. Instead, a fixed, seeded
    Gaussian random projection is applied. It works on one vector at a time
    and sends every image through the same matrix, so distances between
    image vectors are approximately preserved.

    NOTE(review): CLIP image space and the text model's space are unrelated,
    so text queries against these vectors are not semantically meaningful.
    Consider a separate collection, or CLIP text embeddings for queries.

    Args:
        image_path: Path to an image file readable by PIL.
        target_dim: Length of the returned vector (default 384).

    Returns:
        A unit-length list of ``target_dim`` floats, or ``None`` if the image
        cannot be loaded or embedded (the error is printed).
    """
    try:
        # PIL opens lazily, so the processor must read the pixels while the
        # file is still open; the context manager then closes it.
        with Image.open(image_path) as image:
            inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = model.get_image_features(**inputs)
        embedding = features.numpy().astype(np.float64).flatten()
        if embedding.shape[0] != target_dim:
            # Fixed seed: every call uses the identical projection matrix, so
            # all stored image vectors live in the same reduced space.
            rng = np.random.default_rng(0)
            projection = rng.standard_normal((embedding.shape[0], target_dim))
            embedding = embedding @ (projection / np.sqrt(target_dim))
        # Unit length makes L2 distance between vectors rank identically to
        # cosine similarity, the usual comparison for CLIP features.
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        return embedding.tolist()
    except Exception as e:
        print(f"Error generating image embedding: {e}")
        return None
# Store Data in ChromaDB
def store_data(texts, image_paths):
    """Embed *texts* and *image_paths* and upsert them into ``collection``.

    Text ``i`` is stored under id ``text_<i>`` with the text as its document.
    Image ``j`` is stored under id ``image_<j>`` with its path as the
    document, where ``j`` is the index in *image_paths*. Empty texts, and
    images whose embedding could not be produced, are skipped.

    ``upsert`` is used so that re-running ingestion (``process_and_store``
    calls this on every ``/`` request) refreshes existing ids instead of
    leaving stale entries.

    NOTE(review): each text is a whole document embedded as one vector.
    Sentence-transformer models truncate inputs beyond their maximum
    sequence length, so most of a long document likely never reaches the
    embedding. Chunking the text would improve retrieval (TODO confirm
    the limit for this model).
    """
    text_ids, text_vectors, text_docs = [], [], []
    for i, text in enumerate(texts):
        if not text:
            continue
        text_embedding = get_text_embedding(text)
        if len(text_embedding) == 384:
            text_ids.append(f"text_{i}")
            text_vectors.append(text_embedding)
            text_docs.append(text)
    if text_ids:
        collection.upsert(ids=text_ids, embeddings=text_vectors, documents=text_docs)

    # Embed each image exactly once, and keep its original index and path
    # next to the vector so pairing survives images that fail to embed.
    image_entries = []
    for j, img_path in enumerate(image_paths):
        embedding = get_image_embedding(img_path)
        if embedding is not None:
            image_entries.append((j, img_path, embedding))

    if image_entries:
        matrix = np.array([embedding for _, _, embedding in image_entries])
        if matrix.shape[1] != 384:
            # PCA to 384 components needs at least 384 samples and features;
            # fitting with fewer would raise and abort the whole ingestion.
            if min(matrix.shape) < 384:
                print(
                    f"Skipping {len(image_entries)} image embeddings: cannot "
                    f"reduce {matrix.shape[1]}-dim vectors to 384 with PCA on "
                    f"{matrix.shape[0]} samples"
                )
                image_entries = []
            else:
                matrix = PCA(n_components=384).fit_transform(matrix)
        if image_entries:
            collection.upsert(
                ids=[f"image_{j}" for j, _, _ in image_entries],
                embeddings=[row.tolist() for row in matrix],
                documents=[img_path for _, img_path, _ in image_entries],
            )
    print("Data stored successfully!")
# Process and Store from Files
def process_and_store(pdf_path=None, pptx_path=None):
    """Extract text and images from a PDF and/or PPTX, then store them.

    Either path may be ``None`` (or empty) to skip that source. The PDF is
    processed before the PPTX, and within each source text is extracted
    before images. Non-empty texts and all returned image paths are passed
    to ``store_data``.
    """
    collected_texts = []
    collected_images = []
    sources = (
        (pdf_path, extract_text_from_pdf, extract_images_from_pdf),
        (pptx_path, extract_text_from_pptx, extract_images_from_pptx),
    )
    for path, text_extractor, image_extractor in sources:
        if not path:
            continue
        extracted = text_extractor(path)
        if extracted:
            collected_texts.append(extracted)
        collected_images.extend(image_extractor(path))
    store_data(collected_texts, collected_images)
# FastAPI Endpoints
@app.get("/")
def greet_json():
    """Ingest the configured PDF and PPTX into the document store.

    NOTE(review): this GET handler runs the full ingestion as a side effect
    on every request; a POST route would be more conventional.
    """
    process_and_store(pdf_path=pdf_file, pptx_path=pptx_file)
    status = {"Document store": "created!"}
    return status
@app.get("/retrieval")
def retrieval(query: str, n_results: int = 5):
    """Return the stored documents nearest to *query*.

    Args:
        query: Free-text query, embedded with ``get_text_embedding``.
        n_results: Maximum number of matches to return. Exposed as an
            optional query parameter; the default of 5 matches the value
            that was previously hard-coded.

    Returns:
        ``{"results": documents}``, where ``documents`` is the ``documents``
        field of the Chroma query result (one inner list per query
        embedding; a single query here). If embedding or querying fails,
        ``{"error": message}`` is returned instead.
    """
    try:
        query_embedding = get_text_embedding(query)
        results = collection.query(
            query_embeddings=[query_embedding], n_results=n_results
        )
    except Exception as e:
        return {"error": str(e)}
    return {"results": results.get("documents", [])}