hlydecker commited on
Commit
075db70
·
verified ·
1 Parent(s): 4666a7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -10
app.py CHANGED
@@ -25,16 +25,46 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
  processor = ViTImageProcessor.from_pretrained("ViT_LCZs_v3",local_files_only=True)
26
  model = ViTForImageClassification.from_pretrained("ViT_LCZs_v3",local_files_only=True).to(device)
27
 
28
- def predict(image):
29
- inputs = processor(images=image, return_tensors="pt").to(device)
30
- outputs = model(**inputs)
 
 
 
 
 
 
 
 
 
 
31
  logits = outputs.logits
32
- predicted_class_prob = F.softmax(logits, dim=-1).detach().cpu().numpy().max()
33
- predicted_class_idx = logits.argmax(-1).item()
34
- label = model.config.id2label[predicted_class_idx].split(",")[0]
35
- time.sleep(2)
36
- return {label: float(predicted_class_prob)}
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
 
 
 
 
 
 
 
 
38
 
39
- examples = [['data/closed_highrise.png'], ['data/open_lowrise.png'],['data/dense_trees.png'],['data/large_lowrise.png']]
40
- gr.Interface(predict, gr.Image(type="pil"), "label", examples=examples).launch()
 
25
  processor = ViTImageProcessor.from_pretrained("ViT_LCZs_v3",local_files_only=True)
26
  model = ViTForImageClassification.from_pretrained("ViT_LCZs_v3",local_files_only=True).to(device)
27
 
28
+ import os, glob
29
+
30
+ examples_dir = './samples'
31
+ example_files = glob.glob(os.path.join(examples_dir, '*.jpg'))
32
+
33
+ def classify_image(image):
34
+
35
+ with torch.no_grad():
36
+ model.eval()
37
+
38
+ inputs = feature_extractor(images=image, return_tensors="pt")
39
+ outputs = model(**inputs)
40
+
41
  logits = outputs.logits
42
+ prob = torch.nn.functional.softmax(logits, dim=1)
43
+
44
+ top10_prob, top10_indices = torch.topk(prob, 10)
45
+ top10_confidences = {}
46
+ for i in range(10):
47
+ top10_confidences[model.config.id2label[int(top10_indices[0][i])]] = float(top10_prob[0][i])
48
+
49
+ return top10_confidences #confidences
50
+
51
+
52
+ with gr.Blocks(title="ViT LCZ Classification - ClassCat",
53
+ css=".gradio-container {background:white;}"
54
+ ) as demo:
55
+ gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">LCZ Classification with ViT</div>""")
56
+
57
+ with gr.Row():
58
+ input_image = gr.Image(type="pil", image_mode="RGB", shape=(224, 224))
59
+ output_label=gr.Label(label="Probabilities", num_top_classes=3)
60
 
61
+ send_btn = gr.Button("Infer")
62
+ send_btn.click(fn=classify_image, inputs=input_image, outputs=output_label)
63
+
64
+ with gr.Row():
65
+ gr.Examples(['data/closed_highrise.png'], label='Sample images : cat', inputs=input_image)
66
+ gr.Examples(['data/open_lowrise.png'], label='cheetah', inputs=input_image)
67
+ gr.Examples(['data/dense_trees.png'], label='hotdog', inputs=input_image)
68
+ gr.Examples(['data/large_lowrise.png'], label='lion', inputs=input_image)
69
 
70
+ demo.launch(debug=True)