brendenc commited on
Commit
06ce617
·
2 Parent(s): 0698fd4 4f38b79

Updated from colab

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -8,11 +8,17 @@ import matplotlib.pyplot as plt
8
  extractor = AutoFeatureExtractor.from_pretrained("brendenc/my-segmentation-model")
9
  model = AutoModelForImageClassification.from_pretrained("brendenc/my-segmentation-model")
10
 
 
 
 
 
 
11
  def classify(im):
12
  inputs = extractor(images=im, return_tensors="pt")
13
  outputs = model(**inputs)
14
  logits = outputs.logits
15
  classes = logits[0].detach().numpy().argmax(axis=0)
 
16
  colors = np.array([[128,0,0], [128,128,0], [0, 0, 128], [128,0,128], [0, 0, 0]])
17
  return colors[classes]
18
 
@@ -20,8 +26,8 @@ example_imgs = [f"example_{i}.jpg" for i in range(3)]
20
  interface = gr.Interface(classify,
21
  inputs="image",
22
  outputs="image",
23
- examples = example_imgs,
24
  title = "Street Image Segmentation",
 
25
  description = """Below is a simple app for image segmentation. This model was trained using""")
26
 
27
  interface.launch(debug=True)
 
8
  extractor = AutoFeatureExtractor.from_pretrained("brendenc/my-segmentation-model")
9
  model = AutoModelForImageClassification.from_pretrained("brendenc/my-segmentation-model")
10
 
11
+ collapse_categories = {**{i: 0 for i in range(1, 8)},
12
+ **{i: 1 for i in range(8, 10)},
13
+ **{i: 2 for i in range(10, 18)},
14
+ **{i: 3 for i in range(18, 28)}}
15
+
16
  def classify(im):
17
  inputs = extractor(images=im, return_tensors="pt")
18
  outputs = model(**inputs)
19
  logits = outputs.logits
20
  classes = logits[0].detach().numpy().argmax(axis=0)
21
+ classes = np.vectorize(lambda x: collapse_categories.get(x, 4))(classes)
22
  colors = np.array([[128,0,0], [128,128,0], [0, 0, 128], [128,0,128], [0, 0, 0]])
23
  return colors[classes]
24
 
 
26
  interface = gr.Interface(classify,
27
  inputs="image",
28
  outputs="image",
 
29
  title = "Street Image Segmentation",
30
+ examples = example_imgs,
31
  description = """Below is a simple app for image segmentation. This model was trained using""")
32
 
33
  interface.launch(debug=True)