Redmind committed · Commit a244d5b · verified · Parent: bbe1084

Update app.py

Files changed (1): app.py (+25 -26)
app.py CHANGED
@@ -1,6 +1,6 @@
 from fastapi import FastAPI
 import os
-import pymupdf
+import fitz  # pymupdf
 from pptx import Presentation  # PowerPoint
 from sentence_transformers import SentenceTransformer  # Text embeddings
 import torch
@@ -11,33 +11,34 @@ import numpy as np
 from sklearn.decomposition import PCA

 app = FastAPI()
+
+# Initialize ChromaDB
 client = chromadb.PersistentClient(path="/data/chroma_db")
-collection = client.get_or_create_collection(name="knowledge_base")
+collection = client.get_or_create_collection(name="knowledge_base", metadata={"hnsw:space": "cosine"})

 pdf_file = "Sutures and Suturing techniques.pdf"
 pptx_file = "impalnt 1.pptx"

 # Initialize models
-text_model = SentenceTransformer('all-MiniLM-L6-v2')
-model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+text_model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L3-v2')  # 384-dim text model
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

 IMAGE_FOLDER = "/data/extracted_images"
 os.makedirs(IMAGE_FOLDER, exist_ok=True)

 # Extract text from PDF
 def extract_text_from_pdf(pdf_path):
-    text = "".join([page.get_text() for page in pymupdf.open(pdf_path)])
-    return text.strip()
+    return " ".join([page.get_text() for page in fitz.open(pdf_path)]).strip()

 # Extract text from PowerPoint
 def extract_text_from_pptx(pptx_path):
-    return "".join([shape.text for slide in Presentation(pptx_path).slides for shape in slide.shapes if hasattr(shape, "text")]).strip()
+    return " ".join([shape.text for slide in Presentation(pptx_path).slides for shape in slide.shapes if hasattr(shape, "text")]).strip()

 # Extract images from PDF
 def extract_images_from_pdf(pdf_path):
     images = []
-    doc = pymupdf.open(pdf_path)
+    doc = fitz.open(pdf_path)
     for i, page in enumerate(doc):
         for img_index, img in enumerate(page.get_images(full=True)):
             xref = img[0]
@@ -63,34 +64,32 @@ def extract_images_from_pptx(pptx_path):

 # Convert text to embeddings
 def get_text_embedding(text):
-    return text_model.encode(text).tolist()
+    return text_model.encode(text).tolist()  # 384-dim output

 # Extract image embeddings
 def get_image_embedding(image_path):
     image = Image.open(image_path)
-    inputs = processor(images=image, return_tensors="pt")
+    inputs = clip_processor(images=image, return_tensors="pt")
     with torch.no_grad():
-        image_embedding = model.get_image_features(**inputs).numpy().flatten()
+        image_embedding = clip_model.get_image_features(**inputs).numpy().flatten()  # 512-dim output
     return image_embedding.tolist()

+# Reduce image embedding dimensionality (512 → 384)
+def reduce_embedding_dim(embeddings):
+    pca = PCA(n_components=384)
+    return pca.fit_transform(np.array(embeddings))
+
 # Store Data in ChromaDB
 def store_data(texts, image_paths):
     for i, text in enumerate(texts):
         collection.add(ids=[f"text_{i}"], embeddings=[get_text_embedding(text)], documents=[text])

-    # Collect image embeddings first
-    all_embeddings = [get_image_embedding(img_path) for img_path in image_paths]
-    all_embeddings = np.array(all_embeddings)
-
-    # Apply PCA if enough images exist
-    if all_embeddings.shape[0] >= 384:
-        pca = PCA(n_components=384)
-        transformed_embeddings = pca.fit_transform(all_embeddings)
-    else:
-        transformed_embeddings = all_embeddings  # Use original embeddings
-
-    for j, img_path in enumerate(image_paths):
-        collection.add(ids=[f"image_{j}"], embeddings=[transformed_embeddings[j].tolist()], documents=[img_path])
+    if image_paths:
+        all_embeddings = np.array([get_image_embedding(img_path) for img_path in image_paths])
+        transformed_embeddings = reduce_embedding_dim(all_embeddings) if all_embeddings.shape[1] > 384 else all_embeddings
+
+        for j, img_path in enumerate(image_paths):
+            collection.add(ids=[f"image_{j}"], embeddings=[transformed_embeddings[j].tolist()], documents=[img_path])

     print("Data stored successfully!")

@@ -119,4 +118,4 @@ def greet_json():
 def search(query: str):
     query_embedding = get_text_embedding(query)
     results = collection.query(query_embeddings=[query_embedding], n_results=5)
-    return {"results": results["documents"]}
+    return {"results": results["documents"][0] if results["documents"] else []}  # Fix empty results handling
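The change hinges on matching embedding sizes in one collection: the swapped-in MiniLM text model emits 384-dimensional vectors, while CLIP ViT-B/32 image features are 512-dimensional and are projected down before storage. A minimal sketch to confirm those sizes locally, using only the model names that appear in the diff:

from sentence_transformers import SentenceTransformer
from transformers import CLIPModel

# Text model used after this commit: 384-dim sentence embeddings.
text_model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L3-v2")
print(text_model.get_sentence_embedding_dimension())   # 384

# CLIP ViT-B/32 image tower: 512-dim projected image features.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
print(clip_model.config.projection_dim)                 # 512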
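reduce_embedding_dim calls scikit-learn's PCA with n_components=384, and PCA can produce at most min(n_samples, n_features) components, so fitting it needs at least 384 images; the pre-commit code guarded on exactly that sample count. A minimal sketch of a guarded variant under the same 512 → 384 target (the helper name is hypothetical, not part of the commit):

import numpy as np
from sklearn.decomposition import PCA

def reduce_embedding_dim_safe(embeddings, target_dim=384):
    # PCA yields at most min(n_samples, n_features) components, so fall back
    # to the raw 512-dim CLIP vectors when fewer than 384 images are available.
    X = np.asarray(embeddings)
    if X.shape[0] >= target_dim and X.shape[1] > target_dim:
        return PCA(n_components=target_dim).fit_transform(X)
    return X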
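The search fix indexes results["documents"][0] because collection.query returns one inner list per query embedding; with a single query, the matching documents arrive wrapped in an outer list. A self-contained sketch with an in-memory client and toy 384-dim vectors (the collection name and vectors here are illustrative only):

import chromadb

client = chromadb.Client()  # in-memory client, just for illustration
col = client.get_or_create_collection(name="demo", metadata={"hnsw:space": "cosine"})

vec_a = [1.0] + [0.0] * 383        # toy 384-dim embeddings
vec_b = [0.0, 1.0] + [0.0] * 382
col.add(ids=["a", "b"], embeddings=[vec_a, vec_b], documents=["doc a", "doc b"])

results = col.query(query_embeddings=[vec_a], n_results=2)
print(results["documents"])        # [['doc a', 'doc b']] - one inner list per query
print(results["documents"][0])     # ['doc a', 'doc b']   - nearest match first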
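With the collection populated, the search route can be exercised end to end. A sketch using FastAPI's TestClient, assuming the function is exposed as a GET endpoint at /search (the route decorator is outside the visible hunks, so the path and method are assumptions):

from fastapi.testclient import TestClient
from app import app

client = TestClient(app)
response = client.get("/search", params={"query": "suturing techniques"})
print(response.status_code)   # 200 if the assumed route matches
print(response.json())        # {"results": [...up to 5 matching documents...]}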