sakshee05 committed (verified)
Commit 49bffdd · 1 Parent(s): 70ae50f

Update main.py

Files changed (1): main.py (+164 -69)
main.py CHANGED
@@ -1,13 +1,23 @@
 import os
+import logging
+import json
+import base64
+from typing import Dict, Any
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
 
 # Set Hugging Face cache directory to /tmp
 os.environ["HF_HOME"] = "/tmp/huggingface"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
 os.environ["TORCH_HOME"] = "/tmp/torch"
 
-from fastapi import FastAPI, File, UploadFile, Form
+from fastapi import FastAPI, Form, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import Response
 import uvicorn
 from PIL import Image
 import io
@@ -18,6 +28,13 @@ from sam2.build_sam import build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 import torch
 import cv2
+from dotenv import load_dotenv
+import openai
+import requests
+from io import BytesIO
+
+load_dotenv()
+client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
 app = FastAPI()
 
@@ -35,34 +52,50 @@ os.makedirs("/tmp/huggingface", exist_ok=True)
 os.makedirs("/tmp/torch", exist_ok=True)
 
 # Load the langSAM model
+logger.info("Loading LangSAM model...")
 langsam_model = LangSAM()
+logger.info("LangSAM model loaded successfully")
 
 # Load SAM2 Model
+logger.info("Loading SAM2 model...")
 sam2_checkpoint = "sam2.1_hiera_small.pt"
 model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
 device = torch.device("cpu")
 
 sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
 predictor = SAM2ImagePredictor(sam2_model)
+logger.info("SAM2 model loaded successfully")
 
 @app.get("/")
 async def root():
     return {"message": "LangSAM API is running!"}
 
-def apply_mask(image, mask):
-    """Overlay mask on image."""
-    mask = mask.astype(np.uint8) * 255  # Convert mask to 0-255 scale
-    mask_colored = np.zeros((*mask.shape, 3), dtype=np.uint8)
-    mask_colored[mask > 0] = [30, 144, 255]  # Blue color for the mask
+def create_mask_overlay(image: np.ndarray, mask: np.ndarray, alpha: float = 0.5) -> np.ndarray:
+    """Create a mask overlay on the original image."""
+    # Create a colored mask (blue color)
+    colored_mask = np.zeros_like(image)
+    colored_mask[mask > 0] = [30, 144, 255]  # Blue color
 
     # Add contour
-    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-    cv2.drawContours(mask_colored, contours, -1, (255, 255, 255), thickness=2)
+    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    cv2.drawContours(colored_mask, contours, -1, (255, 255, 255), thickness=2)
 
     # Blend with original image
-    overlay = cv2.addWeighted(image, 0.7, mask_colored, 0.3, 0)
+    overlay = cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)
     return overlay
 
+def create_mask_only(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
+    """Create an image showing only the masked region."""
+    # Create a black background
+    result = np.zeros_like(image)
+    # Copy only the masked region
+    result[mask > 0] = image[mask > 0]
+    return result
+
+def image_to_base64(image: np.ndarray) -> str:
+    """Convert numpy array image to base64 string."""
+    _, buffer = cv2.imencode('.png', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+    return base64.b64encode(buffer).decode('utf-8')
 
 def draw_image(image_rgb, masks, xyxy, probs, labels):
     mask_annotator = sv.MaskAnnotator()
@@ -81,72 +114,134 @@ def draw_image(image_rgb, masks, xyxy, probs, labels):
     annotated_image = mask_annotator.annotate(scene=image_rgb.copy(), detections=detections)
     return annotated_image
 
+def load_image_from_url(url):
+    """Fetch image from URL and load it into memory."""
+    try:
+        logger.info(f"Fetching image from URL: {url}")
+        response = requests.get(url)
+        response.raise_for_status()
+        return Image.open(BytesIO(response.content))
+    except Exception as e:
+        logger.error(f"Error loading image from URL: {str(e)}")
+        raise HTTPException(status_code=400, detail=f"Error loading image from URL: {str(e)}")
+
+prompt = """You will be provided with a complete product name, which may contain brand names, extra details, and categories. Your task is to extract only the core product name (apparel or accessory) while removing brand names, categories, and unnecessary words and convert it's meaning to a basic clothing or accessory category.
+
+Examples:
+Beachwood Luxe Paneled Unitard — Girlfriend Collective → Dress
+100 cotton strappy top · Black, White, Red, Peach · T-shirts And Polo Shirts | Massimo Dutti → Shirt
+Wide-leg co-ord trousers with pleats · Green · Dressy | Massimo Dutti → Pants
+BLANKNYC Wide Leg Jean in Radio Star | REVOLVE → Jeans
+
+Basically, you need to convert the product name to a basic clothing or accessory category like Shirt, Pants, Dress, Jeans, etc.
+Now, extract the core product name from the following:
+
+{product_name}"""
+
+@app.post("/openai/chat")
+async def chat(product_name: str = Form(...)):
+    try:
+        logger.info(f"Processing product name: {product_name}")
+        completion = client.chat.completions.create(
+            model="gpt-4o-mini",
+            messages=[{"role": "user", "content": prompt.format(product_name=product_name)}],
+        )
+        result = completion.choices[0].message
+        logger.info(f"OpenAI response: {result.content}")
+        return result
+    except Exception as e:
+        logger.error(f"Error in OpenAI chat: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Error processing product name: {str(e)}")
+
 @app.post("/segment/sam2")
 async def segment_image(
-    file: UploadFile = File(...),
+    image_url: str = Form(...),
     x: int = Form(...),
     y: int = Form(...)
 ):
     """Segment image using SAM2 with a single input point."""
-    image_bytes = await file.read()
-    image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-    image_array = np.array(image_pil)
-
-    predictor.set_image(image_array)
-
-    input_point = np.array([[x, y]])
-    input_label = np.array([1])  # Foreground point
-
-    # Run SAM2 model
-    masks, scores, logits = predictor.predict(
-        point_coords=input_point,
-        point_labels=input_label,
-        multimask_output=True,
-    )
-
-    # Get top mask
-    top_mask = masks[np.argmax(scores)]
-
-    # Apply mask overlay
-    output_image = apply_mask(image_array, top_mask)
-
-    # Convert to PNG
-    output_pil = Image.fromarray(output_image)
-    img_io = io.BytesIO()
-    output_pil.save(img_io, format="PNG")
-    img_io.seek(0)
-
-    return Response(content=img_io.getvalue(), media_type="image/png")
-
+    try:
+        logger.info(f"Starting SAM2 segmentation for image URL: {image_url}")
+        image_pil = load_image_from_url(image_url)
+        image_array = np.array(image_pil)
+
+        logger.info("Setting image in SAM2 predictor")
+        predictor.set_image(image_array)
+
+        input_point = np.array([[x, y]])
+        input_label = np.array([1])  # Foreground point
+
+        logger.info("Running SAM2 prediction")
+        masks, scores, logits = predictor.predict(
+            point_coords=input_point,
+            point_labels=input_label,
+            multimask_output=True,
+        )
+
+        # Get top mask
+        top_mask = masks[np.argmax(scores)]
+
+        # Create different versions of the result
+        overlay_image = create_mask_overlay(image_array, top_mask)
+        mask_only_image = create_mask_only(image_array, top_mask)
+
+        # Convert images to base64
+        original_b64 = image_to_base64(image_array)
+        overlay_b64 = image_to_base64(overlay_image)
+        mask_only_b64 = image_to_base64(mask_only_image)
+
+        # Create response
+        response = {
+            "original": original_b64,
+            "overlay": overlay_b64,
+            "mask_only": mask_only_b64,
+            "score": float(scores[np.argmax(scores)])
+        }
+
+        logger.info("SAM2 segmentation completed successfully")
+        return response
+    except Exception as e:
+        logger.error(f"Error in SAM2 segmentation: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Error in SAM2 segmentation: {str(e)}")
 
 @app.post("/segment/langsam")
-async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(...)):
-    image_bytes = await file.read()
-    image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-
-    # Run segmentation
-    results = langsam_model.predict([image_pil], [text_prompt])
-
-    # Convert to NumPy array
-    image_array = np.asarray(image_pil)
-    output_image = draw_image(
-        image_array,
-        results[0]["masks"],
-        results[0]["boxes"],
-        results[0]["scores"],
-        results[0]["labels"],
-    )
-
-    # Convert back to PIL Image
-    output_pil = Image.fromarray(np.uint8(output_image)).convert("RGB")
-
-    # Save to byte stream
-    img_io = io.BytesIO()
-    output_pil.save(img_io, format="PNG")
-    img_io.seek(0)
-
-    return Response(content=img_io.getvalue(), media_type="image/png")
-
+async def segment_image(image_url: str = Form(...), text_prompt: str = Form(...)):
+    try:
+        logger.info(f"Starting LangSAM segmentation for image URL: {image_url} with prompt: {text_prompt}")
+        image_pil = load_image_from_url(image_url)
+        image_array = np.array(image_pil)
+
+        # Run segmentation
+        logger.info("Running LangSAM prediction")
+        results = langsam_model.predict([image_pil], [text_prompt])
+
+        # Get the first (best) mask
+        mask = results[0]["masks"][0]
+
+        # Create different versions of the result
+        overlay_image = create_mask_overlay(image_array, mask)
+        mask_only_image = create_mask_only(image_array, mask)
+
+        # Convert images to base64
+        original_b64 = image_to_base64(image_array)
+        overlay_b64 = image_to_base64(overlay_image)
+        mask_only_b64 = image_to_base64(mask_only_image)
+
+        # Create response
+        response = {
+            "original": original_b64,
+            "overlay": overlay_b64,
+            "mask_only": mask_only_b64,
+            "boxes": results[0]["boxes"].tolist(),
+            "scores": results[0]["scores"].tolist(),
+            "labels": results[0]["labels"]
+        }
+
+        logger.info("LangSAM segmentation completed successfully")
+        return response
+    except Exception as e:
+        logger.error(f"Error in LangSAM segmentation: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Error in LangSAM segmentation: {str(e)}")
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
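
After this commit the segmentation endpoints take an image_url form field instead of a file upload and return JSON with base64-encoded PNGs rather than raw image bytes. Below is a minimal client sketch against the new routes; it is not part of the commit, and the base URL, image URL, click coordinates, and output filenames are placeholder assumptions.

# Hypothetical client for the updated API (not part of this commit).
# Assumes the server runs locally on port 7860 and the image URL is reachable.
import base64
import requests

BASE_URL = "http://localhost:7860"             # assumed deployment address
IMAGE_URL = "https://example.com/product.jpg"  # placeholder image URL

# Point-prompted segmentation: /segment/sam2 takes an image URL plus a click
# coordinate (x, y in pixel coordinates of the fetched image) and returns
# base64-encoded PNGs along with the best mask's confidence score.
resp = requests.post(
    f"{BASE_URL}/segment/sam2",
    data={"image_url": IMAGE_URL, "x": 250, "y": 300},
)
resp.raise_for_status()
payload = resp.json()
for key in ("original", "overlay", "mask_only"):
    with open(f"sam2_{key}.png", "wb") as f:
        f.write(base64.b64decode(payload[key]))  # decode each returned image
print("SAM2 mask score:", payload["score"])

# Text-prompted segmentation via LangSAM.
resp = requests.post(
    f"{BASE_URL}/segment/langsam",
    data={"image_url": IMAGE_URL, "text_prompt": "shirt"},
)
resp.raise_for_status()
print("LangSAM labels:", resp.json()["labels"])

# Product-name normalization via the new OpenAI-backed endpoint; the route
# returns the serialized chat message, so the category should be under "content".
resp = requests.post(
    f"{BASE_URL}/openai/chat",
    data={"product_name": "BLANKNYC Wide Leg Jean in Radio Star | REVOLVE"},
)
resp.raise_for_status()
print("Category:", resp.json().get("content"))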