add endpoint for single url
main.py
CHANGED
@@ -16,6 +16,7 @@ from models import (
     FileImageDetectionResponse,
     UrlImageDetectionResponse,
     ImageUrlsRequest,
+    ImageUrlRequest,
 )

 app = FastAPI()
@@ -87,36 +88,9 @@ async def classify_image(file: UploadFile = File(None)):
         # inputs = image_processor(image, return_tensors="pt")
         inputs = model(image)
         logging.info("inputs %s", inputs)
-
-        # outpus = model(**inputs)
-        # logits = outpus.logits
-        # logging.info("logits %s", logits)
-        # probs = F.softmax(logits, dim=1)
-        # logging.info("probs %s", probs)
-        # predicted_label_id = probs.argmax(-1).item()
-        # logging.info("predicted_label_id %s", predicted_label_id)
-        # predicted_label = model.config.id2label[predicted_label_id]
-        # logging.info("model.config.id2label %s", model.config.id2label)
-        # confidence = probs.max().item()
-        # outpus = model(**inputs)
-        # logits = outpus.logits
-        # probs = F.softmax(logits, dim=-1)
-        # predicted_label_id = probs.argmax(-1).item()
-        # predicted_label = model.config.id2label[predicted_label_id]
-        # confidence = probs.max().item()
-
-        # model predicts one of the 1000 ImageNet classes
-        # predicted_label = logits.argmax(-1).item()
-        # logging.info("predicted_label", predicted_label)
-        # logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
-        # # print(model.config.id2label[predicted_label])
-        # Find the prediction with the highest confidence using the max() function
+
         predicted_label = max(inputs, key=lambda x: x["score"])
-
-        # best_prediction2 = results[1]["label"]
-        # logging.info("best_prediction2 %s", best_prediction2)
-
-        # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
+
         confidence = round(predicted_label["score"] * 100, 1)

         # # Prepare the custom response data
@@ -125,32 +99,13 @@ async def classify_image(file: UploadFile = File(None)):
             "prediction": predicted_label["label"],
             "confidence": str(confidence),
         }
-
-        # results = model(image)
-
-        # Find the prediction with the highest confidence using the max() function
-        # best_prediction = max(results, key=lambda x: x["score"])
-
-        # Calculate the confidence score, rounded to the nearest tenth and as a percentage
-        # confidence_percentage = round(best_prediction["score"] * 100, 1)
-
-        # Prepare the custom response data
-        # detection_result = {
-        #     "is_nsfw": best_prediction["label"] == "nsfw",
-        #     "confidence_percentage": confidence_percentage,
-        # }
-
+
         # Populate hash
         cache[image_hash] = response_data.copy()

         # Add url to the API response
         response_data["file_name"] = file.filename

-        # response_data.append(detection_result)
-
-        # Add file_name to the API response
-        # response_data["file_name"] = file.filename
-
         return FileImageDetectionResponse(**response_data)

     # except Exception as e:
@@ -161,6 +116,52 @@ async def classify_image(file: UploadFile = File(None)):
         ) from e


+@app.post("/v1/detect/url", response_model=UrlImageDetectionResponse)
+async def classify_images(request: ImageUrlRequest):
+
+    try:
+        image_url = request.url
+        logging.info("Downloading image from URL: %s", image_url)
+        image_data = await download_image(image_url)
+        image_hash = hash_data(image_data)
+
+        if image_hash in cache:
+            # Return cached entry
+            logging.info("Returning cached entry for %s", image_url)
+
+            cached_response = cache[image_hash]
+            response_data = {**cached_response, "url": image_url}
+
+            return UrlImageDetectionResponse(**response_data)
+
+        image = Image.open(io.BytesIO(image_data))
+        # inputs = image_processor(image, return_tensors="pt")
+        inputs = model(image)
+
+        predicted_label = max(inputs, key=lambda x: x["score"])
+        confidence = round(predicted_label["score"] * 100, 1)
+
+        response_data = {
+            "prediction": predicted_label["label"],
+            "confidence": str(confidence),
+        }
+
+        # Populate hash
+        cache[image_hash] = response_data.copy()
+
+        # Add url to the API response
+        response_data["url"] = image_url
+
+        return UrlImageDetectionResponse(**response_data)
+
+    # except Exception as e:
+    except PipelineException as e:
+        logging.error("Error processing image from %s: %s", image_url, str(e))
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error processing image from {image_url}: {str(e)}",
+        ) from e
+
 @app.post("/v1/detect/urls", response_model=list[UrlImageDetectionResponse])
 async def classify_images(request: ImageUrlsRequest):
     """Function analyzing images from URLs."""
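The endpoint reuses the existing download_image and hash_data helpers and the module-level cache, none of which are shown in this diff. The snippet below is only a plausible sketch of that helper pattern under stated assumptions, not the project's actual implementation.

# Illustrative stand-ins for helpers referenced by the endpoint (assumptions).
import hashlib

import aiohttp  # assumption: any async HTTP client would work here


async def download_image(url: str) -> bytes:
    """Fetch raw image bytes from a URL."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            response.raise_for_status()
            return await response.read()


def hash_data(data: bytes) -> str:
    """Stable cache key derived from the image bytes."""
    return hashlib.sha256(data).hexdigest()

Keying the cache on the downloaded bytes rather than the URL means two different URLs serving identical images share a single cache entry.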
models.py
CHANGED
@@ -3,6 +3,16 @@
 from pydantic import BaseModel


+class ImageUrlRequest(BaseModel):
+    """
+    Model representing the request body for the /v1/detect/url endpoint.
+
+    Attributes:
+        url (str): Image URL to be processed.
+    """
+
+    url: str
+
 class ImageUrlsRequest(BaseModel):
     """
     Model representing the request body for the /v1/detect/urls endpoint.
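
Because ImageUrlRequest carries a single required url field, FastAPI will reject request bodies that omit it (a 422 response). A standalone validation sketch, with the model mirroring the diff and the surrounding script an assumption:

from pydantic import BaseModel, ValidationError


class ImageUrlRequest(BaseModel):
    """Request body for the /v1/detect/url endpoint."""

    url: str


print(ImageUrlRequest(url="https://example.com/image.png"))  # accepted

try:
    ImageUrlRequest()  # missing "url"; FastAPI surfaces this to clients as a 422
except ValidationError as exc:
    print(exc)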