Redmind committed
Commit 4e117fe · verified · 1 Parent(s): 2e4f58f

Update app.py

Files changed (1)
  1. app.py +9 -25
app.py CHANGED
@@ -8,7 +8,6 @@ from transformers import CLIPProcessor, CLIPModel
 from PIL import Image
 import chromadb
 import numpy as np
-from sklearn.decomposition import PCA
 
 app = FastAPI()
 
@@ -21,7 +20,7 @@ pdf_file = "Sutures and Suturing techniques.pdf"
 pptx_file = "impalnt 1.pptx"
 
 # Initialize Embedding Models
-text_model = SentenceTransformer('all-MiniLM-L6-v2')
+text_model = SentenceTransformer('paraphrase-MiniLM-L12-v2') # 512D embeddings
 model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
@@ -86,23 +85,17 @@ def extract_images_from_pptx(pptx_path):
         print(f"Error extracting images from PPTX: {e}")
         return []
 
-# Convert Text to Embeddings
+# Convert Text to Embeddings (512D)
 def get_text_embedding(text):
     return text_model.encode(text).tolist()
 
-# Extract Image Embeddings and Reduce to 384 Dimensions
+# Extract Image Embeddings (512D)
 def get_image_embedding(image_path):
     try:
         image = Image.open(image_path)
         inputs = processor(images=image, return_tensors="pt")
         with torch.no_grad():
             image_embedding = model.get_image_features(**inputs).numpy().flatten()
-
-        # Ensure embedding is 384-dimensional
-        if len(image_embedding) != 384:
-            pca = PCA(n_components=384)
-            image_embedding = pca.fit_transform(image_embedding.reshape(1, -1)).flatten()
-
         return image_embedding.tolist()
     except Exception as e:
         print(f"Error generating image embedding: {e}")
@@ -113,21 +106,12 @@ def store_data(texts, image_paths):
     for i, text in enumerate(texts):
         if text:
             text_embedding = get_text_embedding(text)
-            if len(text_embedding) == 384:
-                collection.add(ids=[f"text_{i}"], embeddings=[text_embedding], documents=[text])
+            collection.add(ids=[f"text_{i}"], embeddings=[text_embedding], documents=[text])
 
-    all_embeddings = [get_image_embedding(img_path) for img_path in image_paths if get_image_embedding(img_path) is not None]
-
-    if all_embeddings:
-        all_embeddings = np.array(all_embeddings)
-
-        # Apply PCA only if necessary
-        if all_embeddings.shape[1] != 384:
-            pca = PCA(n_components=384)
-            all_embeddings = pca.fit_transform(all_embeddings)
-
-        for j, img_path in enumerate(image_paths):
-            collection.add(ids=[f"image_{j}"], embeddings=[all_embeddings[j].tolist()], documents=[img_path])
+    for j, img_path in enumerate(image_paths):
+        img_embedding = get_image_embedding(img_path)
+        if img_embedding:
+            collection.add(ids=[f"image_{j}"], embeddings=[img_embedding], documents=[img_path])
 
     print("Data stored successfully!")
 
@@ -148,7 +132,7 @@ def process_and_store(pdf_path=None, pptx_path=None):
 
 # FastAPI Endpoints
 @app.get("/")
-def create_vector():
+def greet_json():
     # Run Data Processing
     process_and_store(pdf_path=pdf_file, pptx_path=pptx_file)
     return {"Document store": "created!"}
 