aznasut commited on
Commit
538fbd6
·
1 Parent(s): e593c27

add endpoint for single url

Browse files
Files changed (2) hide show
  1. main.py +50 -49
  2. models.py +10 -0
main.py CHANGED
@@ -16,6 +16,7 @@ from models import (
16
  FileImageDetectionResponse,
17
  UrlImageDetectionResponse,
18
  ImageUrlsRequest,
 
19
  )
20
 
21
  app = FastAPI()
@@ -87,36 +88,9 @@ async def classify_image(file: UploadFile = File(None)):
87
  # inputs = image_processor(image, return_tensors="pt")
88
  inputs = model(image)
89
  logging.info("inputs %s", inputs)
90
- # with torch.no_grad():
91
- # outpus = model(**inputs)
92
- # logits = outpus.logits
93
- # logging.info("logits %s", logits)
94
- # probs = F.softmax(logits, dim=1)
95
- # logging.info("probs %s", probs)
96
- # predicted_label_id = probs.argmax(-1).item()
97
- # logging.info("predicted_label_id %s", predicted_label_id)
98
- # predicted_label = model.config.id2label[predicted_label_id]
99
- # logging.info("model.config.id2label %s", model.config.id2label)
100
- # confidence = probs.max().item()
101
- # outpus = model(**inputs)
102
- # logits = outpus.logits
103
- # probs = F.softmax(logits, dim=-1)
104
- # predicted_label_id = probs.argmax(-1).item()
105
- # predicted_label = model.config.id2label[predicted_label_id]
106
- # confidence = probs.max().item()
107
-
108
- # model predicts one of the 1000 ImageNet classes
109
- # predicted_label = logits.argmax(-1).item()
110
- # logging.info("predicted_label", predicted_label)
111
- # logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
112
- # # print(model.config.id2label[predicted_label])
113
- # Find the prediction with the highest confidence using the max() function
114
  predicted_label = max(inputs, key=lambda x: x["score"])
115
- # logging.info("best_prediction %s", best_prediction)
116
- # best_prediction2 = results[1]["label"]
117
- # logging.info("best_prediction2 %s", best_prediction2)
118
-
119
- # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
120
  confidence = round(predicted_label["score"] * 100, 1)
121
 
122
  # # Prepare the custom response data
@@ -125,32 +99,13 @@ async def classify_image(file: UploadFile = File(None)):
125
  "prediction": predicted_label["label"],
126
  "confidence": str(confidence),
127
  }
128
- # Use the model to classify the image
129
- # results = model(image)
130
-
131
- # Find the prediction with the highest confidence using the max() function
132
- # best_prediction = max(results, key=lambda x: x["score"])
133
-
134
- # Calculate the confidence score, rounded to the nearest tenth and as a percentage
135
- # confidence_percentage = round(best_prediction["score"] * 100, 1)
136
-
137
- # Prepare the custom response data
138
- # detection_result = {
139
- # "is_nsfw": best_prediction["label"] == "nsfw",
140
- # "confidence_percentage": confidence_percentage,
141
- # }
142
-
143
  # Populate hash
144
  cache[image_hash] = response_data.copy()
145
 
146
  # Add url to the API response
147
  response_data["file_name"] = file.filename
148
 
149
- # response_data.append(detection_result)
150
-
151
- # Add file_name to the API response
152
- # response_data["file_name"] = file.filename
153
-
154
  return FileImageDetectionResponse(**response_data)
155
 
156
  # except Exception as e:
@@ -161,6 +116,52 @@ async def classify_image(file: UploadFile = File(None)):
161
  ) from e
162
 
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  @app.post("/v1/detect/urls", response_model=list[UrlImageDetectionResponse])
165
  async def classify_images(request: ImageUrlsRequest):
166
  """Function analyzing images from URLs."""
 
16
  FileImageDetectionResponse,
17
  UrlImageDetectionResponse,
18
  ImageUrlsRequest,
19
+ ImageUrlRequest,
20
  )
21
 
22
  app = FastAPI()
 
88
  # inputs = image_processor(image, return_tensors="pt")
89
  inputs = model(image)
90
  logging.info("inputs %s", inputs)
91
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  predicted_label = max(inputs, key=lambda x: x["score"])
93
+
 
 
 
 
94
  confidence = round(predicted_label["score"] * 100, 1)
95
 
96
  # # Prepare the custom response data
 
99
  "prediction": predicted_label["label"],
100
  "confidence": str(confidence),
101
  }
102
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  # Populate hash
104
  cache[image_hash] = response_data.copy()
105
 
106
  # Add url to the API response
107
  response_data["file_name"] = file.filename
108
 
 
 
 
 
 
109
  return FileImageDetectionResponse(**response_data)
110
 
111
  # except Exception as e:
 
116
  ) from e
117
 
118
 
119
+ @app.post("/v1/detect/url", response_model=UrlImageDetectionResponse)
120
+ async def classify_images(request: ImageUrlRequest):
121
+
122
+ try:
123
+ image_url = request.url
124
+ logging.info("Downloading image from URL: %s", image_url)
125
+ image_data = await download_image(image_url)
126
+ image_hash = hash_data(image_data)
127
+
128
+ if image_hash in cache:
129
+ # Return cached entry
130
+ logging.info("Returning cached entry for %s", image_url)
131
+
132
+ cached_response = cache[image_hash]
133
+ response_data = {**cached_response, "url": image_url}
134
+
135
+ return UrlImageDetectionResponse(**response_data)
136
+
137
+ image = Image.open(io.BytesIO(image_data))
138
+ # inputs = image_processor(image, return_tensors="pt")
139
+ inputs = model(image)
140
+
141
+ predicted_label = max(inputs, key=lambda x: x["score"])
142
+ confidence = round(predicted_label["score"] * 100, 1)
143
+
144
+ response_data = {
145
+ "prediction": predicted_label["label"],
146
+ "confidence": str(confidence),
147
+ }
148
+
149
+ # Populate hash
150
+ cache[image_hash] = response_data.copy()
151
+
152
+ # Add url to the API response
153
+ response_data["url"] = image_url
154
+
155
+ return UrlImageDetectionResponse(**response_data)
156
+
157
+ # except Exception as e:
158
+ except PipelineException as e:
159
+ logging.error("Error processing image from %s: %s", image_url, str(e))
160
+ raise HTTPException(
161
+ status_code=500,
162
+ detail=f"Error processing image from {image_url}: {str(e)}",
163
+ ) from e
164
+
165
  @app.post("/v1/detect/urls", response_model=list[UrlImageDetectionResponse])
166
  async def classify_images(request: ImageUrlsRequest):
167
  """Function analyzing images from URLs."""
models.py CHANGED
@@ -3,6 +3,16 @@
3
  from pydantic import BaseModel
4
 
5
 
 
 
 
 
 
 
 
 
 
 
6
  class ImageUrlsRequest(BaseModel):
7
  """
8
  Model representing the request body for the /v1/detect/urls endpoint.
 
3
  from pydantic import BaseModel
4
 
5
 
6
+ class ImageUrlRequest(BaseModel):
7
+ """
8
+ Model representing the request body for the /v1/detect/urls endpoint.
9
+
10
+ Attributes:
11
+ urls (list[str]): List of image URLs to be processed.
12
+ """
13
+
14
+ url: str
15
+
16
  class ImageUrlsRequest(BaseModel):
17
  """
18
  Model representing the request body for the /v1/detect/urls endpoint.