user-agent committed on
Commit
70b8f95
·
verified ·
1 Parent(s): ad52c93

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ import torch
5
+ from torchvision import transforms
6
+ from transformers import AutoModelForImageClassification, AutoConfig
7
+ import gradio as gr
8
+
9
# Hub repository that provides both the classifier weights and its config.
model_id = "thelabel/240903-image-tagging"

# The config is kept as a separate object because the prediction code reads
# the label mapping from it.
config = AutoConfig.from_pretrained(model_id)

# ``eval()`` returns the module itself, so loading and switching to inference
# mode can be chained into a single expression.
model = AutoModelForImageClassification.from_pretrained(model_id).eval()
13
+
14
# Preprocessing applied to every PIL image before it is fed to the model:
# resize to 224x224, convert to a tensor, then normalize each channel with
# mean 0.5 and std 0.5.
# NOTE(review): these values are hard-coded here rather than read from the
# checkpoint's own preprocessing settings -- confirm they match the model.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
image_transform = transforms.Compose(_preprocessing_steps)
20
+
21
def load_image_from_url(url):
    """Download an image and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.

    Returns:
        The decoded image converted to RGB mode, or ``None`` when *url* is
        empty / not a string, the request fails (including a non-2xx status
        or the 10 second timeout), or the payload cannot be decoded.
    """
    # Function-scope import so the helper stays self-contained without
    # touching the file's import block.
    import logging

    # Reject unusable input up front instead of letting the HTTP client fail.
    if not isinstance(url, str) or not url.strip():
        return None

    try:
        response = requests.get(url.strip(), timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception:
        # Deliberate best-effort boundary: callers only distinguish
        # "image" from "no image". Record the cause rather than discarding
        # it so failures can still be diagnosed from the logs.
        logging.getLogger(__name__).warning(
            "Could not load image from %r", url, exc_info=True
        )
        return None
28
+
29
def predict_tags(image_url, threshold=0.5):
    """Return the labels whose predicted probability reaches *threshold*.

    Args:
        image_url: URL of the image to classify.
        threshold: Minimum sigmoid probability for a label to be reported.

    Returns:
        A ``(results, error)`` pair. On success ``results`` is a list of
        ``(label, probability)`` tuples sorted by descending probability and
        ``error`` is ``None``. When the image cannot be loaded, ``results``
        is an empty list and ``error`` is a human-readable message.
    """
    image = load_image_from_url(image_url)
    if image is None:
        return [], "Could not load image from the provided URL."

    # Add a leading batch axis of size one for the model call.
    batch = image_transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch).logits

    # Remove only the batch axis. A bare ``squeeze()`` would also collapse a
    # label axis of size one into a 0-d tensor, which cannot be iterated.
    # (Assumes logits are laid out as (batch, num_labels) -- TODO confirm.)
    probs = torch.sigmoid(logits).squeeze(0)

    # Compare on the tensor so thresholding keeps the tensor's own precision,
    # then convert to plain Python values once instead of per element.
    selected = (probs >= threshold).tolist()
    scores = probs.tolist()

    results = [
        (config.idx_to_label[str(index)], score)
        for index, (score, is_selected) in enumerate(zip(scores, selected))
        if is_selected
    ]
    results.sort(key=lambda item: item[1], reverse=True)
    return results, None
46
+
47
def gradio_predict(url, threshold):
    """Adapt ``predict_tags`` to the two outputs of the Gradio interface.

    Args:
        url: Image URL entered by the user.
        threshold: Minimum probability for a tag to be listed.

    Returns:
        ``(text, preview)``. On failure ``text`` is the error message and
        ``preview`` is ``None``; otherwise ``text`` holds one
        ``"tag: score"`` line per prediction and ``preview`` is *url*.
    """
    predictions, failure = predict_tags(url, threshold)
    if failure:
        return failure, None

    lines = []
    for label, probability in predictions:
        lines.append(f"{label}: {probability:.2f}")
    return "\n".join(lines), url
52
+
53
# Sample image URLs offered as clickable examples; the first one also
# pre-fills the URL textbox.
_EXAMPLE_IMAGE_URLS = (
    "https://d2q1sfov6ca7my.cloudfront.net/eyJidWNrZXQiOiJoaWNjdXAtaW1hZ2UtaG9zdGluZyIsImtleSI6ImhpY2N1cC1wcm9kdWN0cy9GQVFZTFkyNzFGLmpwZWciLCJlZGl0cyI6eyJyZXNpemUiOnsid2lkdGgiOjI1NjAsImhlaWdodCI6Mzg0MCwiZml0IjoiY292ZXIifX19?v=1748968367",
    "https://d2q1sfov6ca7my.cloudfront.net/eyJidWNrZXQiOiJoaWNjdXAtaW1hZ2UtaG9zdGluZyIsImtleSI6ImhpY2N1cC1wcm9kdWN0cy9ON01aQkpUMDlFLmpwZWciLCJlZGl0cyI6eyJyZXNpemUiOnsid2lkdGgiOjI1NjAsImhlaWdodCI6Mzg0MCwiZml0IjoiY292ZXIifX19?v=1748968367",
)

# Web UI: takes an image URL and a probability threshold, shows the predicted
# tags as text together with a preview of the image.
demo = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Textbox(label="Image URL", value=_EXAMPLE_IMAGE_URLS[0]),
        gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold"),
    ],
    outputs=[
        gr.Textbox(label="Tags"),
        # ``type`` must be one of "numpy", "pil" or "filepath"; "url" is not
        # an accepted value. The callback returns the URL string (or None),
        # which the component is expected to display as given -- verify
        # against the installed Gradio version.
        gr.Image(label="Preview", type="filepath"),
    ],
    title="Image Tagging with ViT",
    description="Paste an image URL and get predicted tags using thelabel/240903-image-tagging model.",
    examples=[[example_url, 0.5] for example_url in _EXAMPLE_IMAGE_URLS],
)

if __name__ == "__main__":
    demo.launch()