brendenc commited on
Commit
cdcfb27
·
1 Parent(s): 151ebdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -8,11 +8,17 @@ import matplotlib.pyplot as plt
8
  extractor = AutoFeatureExtractor.from_pretrained("brendenc/my-segmentation-model")
9
  model = AutoModelForImageClassification.from_pretrained("brendenc/my-segmentation-model")
10
 
 
 
 
 
 
11
  def classify(im):
12
  inputs = extractor(images=im, return_tensors="pt")
13
  outputs = model(**inputs)
14
  logits = outputs.logits
15
  classes = logits[0].detach().numpy().argmax(axis=0)
 
16
  colors = np.array([[128,0,0], [128,128,0], [0, 0, 128], [128,0,128], [0, 0, 0]])
17
  return colors[classes]
18
 
 
8
  extractor = AutoFeatureExtractor.from_pretrained("brendenc/my-segmentation-model")
9
  model = AutoModelForImageClassification.from_pretrained("brendenc/my-segmentation-model")
10
 
11
+ collapse_categories = {**{i: 0 for i in range(1, 8)},
12
+ **{i: 1 for i in range(8, 10)},
13
+ **{i: 2 for i in range(10, 18)},
14
+ **{i: 3 for i in range(18, 28)}}
15
+
16
  def classify(im):
17
  inputs = extractor(images=im, return_tensors="pt")
18
  outputs = model(**inputs)
19
  logits = outputs.logits
20
  classes = logits[0].detach().numpy().argmax(axis=0)
21
+ classes = np.vectorize(lambda x: collapse_categories.get(x, 4))(classes)
22
  colors = np.array([[128,0,0], [128,128,0], [0, 0, 128], [128,0,128], [0, 0, 0]])
23
  return colors[classes]
24