Thomas J. Trebat commited on
Commit
92e317c
·
1 Parent(s): 25bf539

created classes

Browse files
Files changed (1) hide show
  1. app.py +68 -23
app.py CHANGED
@@ -7,26 +7,71 @@ from timm.data import resolve_data_config
7
  from timm.data.transforms_factory import create_transform
8
 
9
 
10
- model = timm.create_model(
11
- 'hf-hub:nateraw/resnet50-oxford-iiit-pet',
12
- pretrained=True
13
- )
14
- model.eval()
15
- transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
16
- labels = model.pretrained_cfg['label_names']
17
- st.title("Pet Image Classification App")
18
- uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
19
- if uploaded_image is not None:
20
- st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
21
- st.subheader("Classification Results:")
22
- image_data = uploaded_image.read()
23
- image = Image.open(io.BytesIO(image_data))
24
- output = model(transform(image).unsqueeze(0))
25
- probabilities = torch.nn.functional.softmax(output[0], dim=0)
26
- values, indices = torch.topk(probabilities, 5)
27
- predictions = [
28
- {'label': labels[i], 'score': v.item()}
29
- for i, v in zip(indices, values)
30
- ]
31
- for prediction in predictions:
32
- st.write(f"- {prediction['label']}: {prediction['score']:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from timm.data.transforms_factory import create_transform
8
 
9
 
10
+ class ImageClassifier(object):
11
+ def __init__(self, model_name):
12
+ self.model = timm.create_model(
13
+ model_name,
14
+ pretrained=True
15
+ ).eval()
16
+
17
+ def get_top_5_predictions(self, image):
18
+ values, indices = torch.topk(self.get_output_probabilities(image), 5)
19
+ labels = self.get_labels()
20
+ return [
21
+ {'label': labels[i], 'score': v.item()}
22
+ for i, v in zip(indices, values)
23
+ ]
24
+
25
+ def get_output_probabilities(self, image):
26
+ output = self.classify_image(image)
27
+ return torch.nn.functional.softmax(output[0], dim=0)
28
+
29
+ def classify_image(self, image):
30
+ transform = self.create_image_transform()
31
+ return self.model(transform(image).unsqueeze(0))
32
+
33
+ def create_image_transform(self):
34
+ return create_transform(**resolve_data_config(
35
+ self.model.pretrained_cfg, model=self.model))
36
+
37
+ def get_labels(self):
38
+ return self.model.pretrained_cfg['label_names']
39
+
40
+ class ImageClassificationApp(object):
41
+ def __init__(self, title, model_name):
42
+ self.title = title
43
+ self.classifier = ImageClassifier(model_name)
44
+
45
+ def render(self):
46
+ st.title(self.title)
47
+ uploaded_image = self.get_uploaded_image()
48
+ if uploaded_image is not None:
49
+ self.show_image_and_results(uploaded_image)
50
+
51
+ def get_uploaded_image(self):
52
+ return st.file_uploader('Choose an image...', type=['jpg', 'png', 'jpeg'])
53
+
54
+ def show_image_and_results(self, uploaded_image):
55
+ self.show_uploaded_image(uploaded_image)
56
+ self.show_classification_results(self.get_image(uploaded_image.read()))
57
+
58
+ def show_uploaded_image(self, uploaded_image):
59
+ st.image(uploaded_image, caption='Uploaded Image', use_column_width=True)
60
+
61
+ def show_classification_results(self, image):
62
+ st.subheader('Classification Results:')
63
+ self.write_top_5_predictions(image)
64
+
65
+ def write_top_5_predictions(self, image):
66
+ for prediction in self.classifier.get_top_5_predictions(image):
67
+ st.write(f"- {prediction['label']}: {prediction['score']:.4f}")
68
+
69
+ def get_image(self, image_data):
70
+ return Image.open(io.BytesIO(image_data))
71
+
72
+
73
+ if __name__ == '__main__':
74
+ ImageClassificationApp(
75
+ 'Pet Image Classification App',
76
+ 'hf-hub:nateraw/resnet50-oxford-iiit-pet'
77
+ ).render()