Redmind committed · Commit c2710ab · verified · 1 Parent(s): 1371299

Update app.py

Files changed (1):
  1. app.py +32 -17
app.py CHANGED
@@ -8,24 +8,22 @@ from transformers import CLIPProcessor, CLIPModel
 from PIL import Image
 import chromadb
 import numpy as np
+from sklearn.decomposition import PCA
 
 app = FastAPI()
 
-# Initialize ChromaDB with 512 dimensions
+# Initialize ChromaDB
 client = chromadb.PersistentClient(path="/data/chroma_db")
-client.delete_collection(name="knowledge_base")
-collection = client.get_or_create_collection(name="knowledge_base", metadata={"dim": 512})
-
-#collection = client.get_or_create_collection(name="knowledge_base", metadata={"hnsw:space": "cosine"}, embedding_function=None)
+collection = client.get_or_create_collection(name="knowledge_base")
 
 # File Paths
 pdf_file = "Sutures and Suturing techniques.pdf"
 pptx_file = "impalnt 1.pptx"
 
 # Initialize Embedding Models
-text_model = SentenceTransformer('paraphrase-MiniLM-L12-v2')  # 512D text embeddings
-clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+text_model = SentenceTransformer('all-MiniLM-L6-v2')
+model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 # Image Storage Folder
 IMAGE_FOLDER = "/data/extracted_images"
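Note on the ChromaDB change: the old startup code called client.delete_collection on every launch, wiping stored vectors; the new code reuses the persistent collection. One caveat: a persisted collection keeps the dimensionality of the first embeddings added to it, so 512-d vectors left over from an earlier build will reject the new 384-d adds. A minimal guard sketch, assuming chromadb's standard client API; EXPECTED_DIM and reset_if_dim_changed are names introduced here, not part of this commit:

    EXPECTED_DIM = 384  # all-MiniLM-L6-v2 output size

    def reset_if_dim_changed(client, name, expected_dim):
        # Recreate the collection only when stored vectors have a stale dimensionality.
        collection = client.get_or_create_collection(name=name)
        if collection.count() > 0:
            sample = collection.get(limit=1, include=["embeddings"])  # peek at one stored vector
            if len(sample["embeddings"][0]) != expected_dim:
                client.delete_collection(name=name)
                collection = client.get_or_create_collection(name=name)
        return collection

    collection = reset_if_dim_changed(client, "knowledge_base", EXPECTED_DIM)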
@@ -88,17 +86,23 @@ def extract_images_from_pptx(pptx_path):
         print(f"Error extracting images from PPTX: {e}")
         return []
 
-# Convert Text to Embeddings (512D)
+# Convert Text to Embeddings
 def get_text_embedding(text):
     return text_model.encode(text).tolist()
 
-# Extract Image Embeddings (512D)
+# Extract Image Embeddings and Reduce to 384 Dimensions
 def get_image_embedding(image_path):
     try:
         image = Image.open(image_path)
-        inputs = clip_processor(images=image, return_tensors="pt")
+        inputs = processor(images=image, return_tensors="pt")
         with torch.no_grad():
-            image_embedding = clip_model.get_image_features(**inputs).squeeze().numpy()
+            image_embedding = model.get_image_features(**inputs).numpy().flatten()
+
+        # Ensure embedding is 384-dimensional
+        if len(image_embedding) != 384:
+            pca = PCA(n_components=384)
+            image_embedding = pca.fit_transform(image_embedding.reshape(1, -1)).flatten()
+
         return image_embedding.tolist()
     except Exception as e:
         print(f"Error generating image embedding: {e}")
@@ -109,12 +113,21 @@ def store_data(texts, image_paths):
     for i, text in enumerate(texts):
         if text:
             text_embedding = get_text_embedding(text)
-            collection.add(ids=[f"text_{i}"], embeddings=[text_embedding], documents=[text])
+            if len(text_embedding) == 384:
+                collection.add(ids=[f"text_{i}"], embeddings=[text_embedding], documents=[text])
+
+    all_embeddings = [get_image_embedding(img_path) for img_path in image_paths if get_image_embedding(img_path) is not None]
 
-    for j, img_path in enumerate(image_paths):
-        img_embedding = get_image_embedding(img_path)
-        if img_embedding:
-            collection.add(ids=[f"image_{j}"], embeddings=[img_embedding], documents=[img_path])
+    if all_embeddings:
+        all_embeddings = np.array(all_embeddings)
+
+        # Apply PCA only if necessary
+        if all_embeddings.shape[1] != 384:
+            pca = PCA(n_components=384)
+            all_embeddings = pca.fit_transform(all_embeddings)
+
+        for j, img_path in enumerate(image_paths):
+            collection.add(ids=[f"image_{j}"], embeddings=[all_embeddings[j].tolist()], documents=[img_path])
 
     print("Data stored successfully!")
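Note on store_data: the list comprehension calls get_image_embedding twice per path, and dropping None results desynchronizes all_embeddings[j] from image_paths[j] in the final loop; the batch PCA also needs at least 384 rows to produce 384 components. A sketch of the image-storage loop with those issues addressed, reusing the hypothetical project_to_384 helper above; this is one possible cleanup, not the committed code:

    def store_images(image_paths):
        kept = []
        for img_path in image_paths:
            emb = get_image_embedding(img_path)  # embed each image exactly once
            if emb is not None:
                kept.append((img_path, emb))     # keep path and vector paired
        for j, (img_path, emb) in enumerate(kept):
            emb = np.asarray(emb)
            if emb.shape[0] != 384:
                emb = project_to_384(emb)        # per-vector, no sample-count requirement
            collection.add(ids=[f"image_{j}"], embeddings=[emb.tolist()], documents=[img_path])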
 
@@ -133,6 +146,8 @@ def process_and_store(pdf_path=None, pptx_path=None):
         images.extend(extract_images_from_pptx(pptx_path))
     store_data(texts, images)
 
+
+
 # FastAPI Endpoints
 @app.get("/")
 def greet_json():
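Once texts and images sit in the same 384-d collection, retrieval goes through ChromaDB's standard query call; a small usage sketch against the app's collection (the query string is illustrative). Note that MiniLM text vectors and projected CLIP image vectors occupy different embedding spaces, so cross-modal hits from a shared collection should be treated with caution:

    query_embedding = get_text_embedding("How do I tie a surgeon's knot?")
    results = collection.query(query_embeddings=[query_embedding], n_results=3)
    for doc_id, doc in zip(results["ids"][0], results["documents"][0]):
        print(doc_id, doc[:80])  # id and a preview of the matched document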
 