user-agent commited on
Commit
aab5af1
·
verified ·
1 Parent(s): 35b0d39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -14
app.py CHANGED
@@ -24,7 +24,7 @@ image_transform = transforms.Compose([
24
 
25
  def load_image_from_url(url):
26
  try:
27
- response = requests.get(url, timeout=10)
28
  response.raise_for_status()
29
  return Image.open(BytesIO(response.content)).convert("RGB")
30
  except Exception:
@@ -34,7 +34,7 @@ def load_image_from_url(url):
34
  def predict_tags(image_url, threshold=0.5):
35
  image = load_image_from_url(image_url)
36
  if image is None:
37
- return [], "Could not load image from the provided URL."
38
 
39
  image_tensor = image_transform(image).unsqueeze(0).to(device)
40
  with torch.no_grad():
@@ -49,25 +49,36 @@ def predict_tags(image_url, threshold=0.5):
49
  results.sort(key=lambda x: x[1], reverse=True)
50
  return results, None
51
 
52
- def gradio_predict(url, threshold):
53
- tags, error = predict_tags(url, threshold)
54
- if error:
55
- return {"error": error}
56
- if not tags:
57
- return {"error": "No tags above threshold."}
58
-
59
- top_tag, top_score = tags[0]
60
- return {"tag_name": top_tag, "tag_score": round(top_score, 4)}
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  demo = gr.Interface(
63
  fn=gradio_predict,
64
  inputs=[
65
- gr.Textbox(label="Image URL"),
66
  gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold"),
67
  ],
68
  outputs=gr.Textbox(label="Tags"),
69
- title="Image Tagging with ViT",
70
- description="Paste an image URL and get predicted tags using thelabel/240903-image-tagging model.",
71
  )
72
 
73
  if __name__ == "__main__":
 
24
 
25
  def load_image_from_url(url):
26
  try:
27
+ response = requests.get(url.strip(), timeout=10)
28
  response.raise_for_status()
29
  return Image.open(BytesIO(response.content)).convert("RGB")
30
  except Exception:
 
34
  def predict_tags(image_url, threshold=0.5):
35
  image = load_image_from_url(image_url)
36
  if image is None:
37
+ return None, "Could not load image."
38
 
39
  image_tensor = image_transform(image).unsqueeze(0).to(device)
40
  with torch.no_grad():
 
49
  results.sort(key=lambda x: x[1], reverse=True)
50
  return results, None
51
 
52
+ def gradio_predict(urls, threshold):
53
+ url_list = [u.strip() for u in urls.split(",") if u.strip()]
54
+ output = []
55
+
56
+ for url in url_list:
57
+ tags, error = predict_tags(url, threshold)
58
+ if error or not tags:
59
+ output.append({
60
+ "image_url": url,
61
+ "error": error or "No tags above threshold."
62
+ })
63
+ else:
64
+ top_tag, top_score = tags[0]
65
+ output.append({
66
+ "image_url": url,
67
+ "tag_name": top_tag,
68
+ "tag_score": round(top_score, 4)
69
+ })
70
+
71
+ return str(output) # Return as string for textbox display
72
 
73
  demo = gr.Interface(
74
  fn=gradio_predict,
75
  inputs=[
76
+ gr.Textbox(label="Image URL(s) (comma-separated)"),
77
  gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold"),
78
  ],
79
  outputs=gr.Textbox(label="Tags"),
80
+ title="Batch Image Tagging with ViT",
81
+ description="Paste one or more image URLs separated by commas to get predicted tags using thelabel/240903-image-tagging model.",
82
  )
83
 
84
  if __name__ == "__main__":