File size: 1,406 Bytes
c74bd0f
303db69
c74bd0f
 
 
 
 
e929201
c74bd0f
 
cdcfb27
 
 
 
 
c74bd0f
945aef0
c74bd0f
 
945aef0
cdcfb27
c74bd0f
 
 
 
 
 
 
 
4f38b79
c74bd0f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import gradio as gr
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
import matplotlib.pyplot as plt


# Hub checkpoint shared by the preprocessor and the model so the two always
# load from the same repository.
_CHECKPOINT = "brendenc/my-segmentation-model"
# NOTE(review): the checkpoint name says "segmentation" and classify() treats
# the logits as a per-position label map (argmax over axis 0 of logits[0]),
# yet an image-*classification* auto class is used below. Confirm the
# checkpoint resolves to a head with spatial logits; an
# AutoModelForSemanticSegmentation load may be what was intended.
extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)

# Collapse the model's fine-grained label ids into four coarse groups:
# ids 1-7 -> 0, 8-9 -> 1, 10-17 -> 2, 18-27 -> 3. Ids outside 1-27 are not
# keys here; classify() supplies the fallback group for those.
collapse_categories = {
    label: group
    for group, (first, stop) in enumerate(((1, 8), (8, 10), (10, 18), (18, 28)))
    for label in range(first, stop)
}
                            
# RGB colour per coarse group id: rows 0-3 are the groups produced by
# collapse_categories, row 4 (black) is for labels it does not list.
# Stored as uint8 so the result is a plain 8-bit image. An int64 array would
# go through the image output's dtype conversion, which may rescale it by the
# int64 range (NOTE(review): in some Gradio versions this turns 0-255 values
# into 0, i.e. an all-black image).
_PALETTE = np.array(
    [[128, 0, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 0, 0]],
    dtype=np.uint8,
)
# Group id used for labels that collapse_categories does not map.
_UNMAPPED_GROUP = 4


def classify(im):
    """Run the model on *im* and return a colour-coded label map.

    Each position's highest-scoring label is collapsed to a coarse group via
    ``collapse_categories`` (unlisted labels fall into group 4) and painted
    with that group's colour from ``_PALETTE``.

    Args:
        im: Image from the Gradio ``"image"`` input, passed unchanged to the
            feature extractor; ``None`` when nothing was submitted.

    Returns:
        ``None`` if *im* is ``None``. Otherwise a ``uint8`` array of shape
        ``logits[0].shape[1:] + (3,)``. NOTE(review): for a semantic
        segmentation head this is presumably the model's logit resolution,
        which may be smaller than the input image. Confirm, and upsample if
        the output should match the input size.
    """
    if im is None:
        return None
    inputs = extractor(images=im, return_tensors="pt")
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Argmax over axis 0 of the first batch item (assumed to be the label axis).
    labels = logits[0].cpu().numpy().argmax(axis=0)
    # Lookup table label id -> group id, sized to cover every id that occurs,
    # so the mapping is one vectorised indexing op instead of a Python call
    # per pixel.
    size = max(max(collapse_categories, default=0), int(labels.max())) + 1
    lut = np.full(size, _UNMAPPED_GROUP, dtype=np.intp)
    for label, group in collapse_categories.items():
        lut[label] = group
    return _PALETTE[lut[labels]]

# Example inputs shown under the interface: example_0.jpg .. example_2.jpg,
# resolved relative to the working directory the app is started from.
example_imgs = [f"example_{i}.jpg" for i in range(3)]

# TODO: name the dataset the model was trained on. The original description
# ended mid-sentence ("This model was trained using").
interface = gr.Interface(
    classify,
    inputs="image",
    outputs="image",
    title="Street Image Segmentation",
    examples=example_imgs,
    description=(
        "Below is a simple app for image segmentation. Upload a street image "
        "(or pick an example) to see each pixel coloured by its predicted "
        "category group."
    ),
)

# Start the server only when run as a script (e.g. `python app.py`), so the
# module can be imported without launching a blocking server.
if __name__ == "__main__":
    interface.launch(debug=True)