File size: 1,679 Bytes
f57db59
efb2a0e
 
f57db59
efb2a0e
 
 
 
f57db59
efb2a0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f57db59
 
efb2a0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import tempfile
from io import StringIO

import matplotlib.pyplot as plt
import requests
import streamlit as st
import torch
from skimage.transform import resize
from torchvision.models.inception import inception_v3

@st.cache(allow_output_mutation=True)
def load_stuff():
    """Build the pretrained Inception v3 classifier and fetch the ImageNet label names.

    Cached with ``st.cache`` so the weights and labels are loaded once and reused
    across Streamlit reruns. ``allow_output_mutation=True`` turns off Streamlit's
    hash-based mutation check on the returned objects; that check is not useful for
    (and can be expensive or fail on) a torch model.

    Returns:
        (model, labels): the model in inference mode, and a dict mapping class
        index -> label string as listed in the downloaded JSON file.

    Raises:
        requests.HTTPError: if the label file request returns an error status.
        requests.Timeout: if the label server does not answer within the timeout.
    """
    model = inception_v3(
        pretrained=True,       # load existing ImageNet weights
        transform_input=True,  # model re-maps ImageNet mean/std-normalized input internally
    )
    # Return only the main logits, not the auxiliary classifier's output.
    model.aux_logits = False
    model.eval()  # same as train(False): inference mode

    labels_url = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'
    # Bounded wait: without a timeout a stalled server would hang the app forever.
    response = requests.get(labels_url, timeout=10)
    response.raise_for_status()  # fail clearly instead of trying to JSON-decode an error page
    labels = dict(enumerate(response.json()))
    return model, labels


model, labels = load_stuff()


def transform_input(img):
    """Turn one HWC image array into a normalized NCHW float tensor for the model.

    Args:
        img: array with 299*299*3 elements laid out as (height, width, channel).
            Channels are assumed to be RGB with values in [0, 1] (TODO confirm
            for every image source).

    Returns:
        torch.float32 tensor of shape (1, 3, 299, 299), normalized per channel
        with the ImageNet mean/std.

    The model is created with ``transform_input=True``. In that mode torchvision
    expects ImageNet mean/std-normalized input and converts it internally to the
    [-1, 1] scaling the pretrained weights use. Raw [0, 1] values would be squashed
    to roughly [-0.03, 0.43] and degrade predictions.
    """
    batch = torch.as_tensor(
        img.reshape([1, 299, 299, 3]).transpose([0, 3, 1, 2]),  # NHWC -> NCHW
        dtype=torch.float32,
    )
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    return (batch - mean) / std


def predict(img, top_k=10):
    """Classify ``img`` and return a Markdown summary of the most likely classes.

    Args:
        img: image array accepted by ``transform_input``.
        top_k: number of classes to list (default 10, matching the header text).

    Returns:
        A header line followed by one ``"<prob> :\\t<label>"`` entry per class,
        ordered from most to least probable. Entries are separated by blank lines
        so Markdown renders each one as its own paragraph.
    """
    batch = transform_input(img)
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch), dim=-1)
    probs = probs.numpy().ravel()
    # argsort is ascending: reverse it and keep the top_k highest-probability indices.
    top_ix = probs.argsort()[::-1][:top_k]
    parts = ['top-%d classes are: \n\n [prob : class label]\n\n' % top_k]
    parts.extend(
        '%.4f :\t%s\n\n' % (probs[ix], labels[ix].split(',')[0]) for ix in top_ix
    )
    return ''.join(parts)



st.markdown("### Hello dude!")

uploaded_file = st.file_uploader("Choose a file")
if uploaded_file is not None:
    # To read file as bytes:
    bytes_data = uploaded_file.getvalue()

    # Decode via a private temporary directory rather than a fixed 'tmp' file in
    # the working directory: concurrent sessions would otherwise overwrite each
    # other's upload, and the file would be left behind. The file name has no
    # extension so imread does not assume a format from the suffix.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, 'upload')
        with open(tmp_path, 'wb') as f:
            f.write(bytes_data)
        raw = plt.imread(tmp_path)

    # Greyscale decodes as (H, W) and grey+alpha as (H, W, 2); replicate the
    # luminance channel so the model always receives three channels.
    if raw.ndim == 2:
        raw = raw[..., None]
    if raw.shape[-1] < 3:
        raw = raw[..., :1].repeat(3, axis=-1)
    # Drop any alpha channel, then resize to Inception's 299x299 input size.
    img = resize(raw[..., :3], (299, 299))

    top_classes = predict(img)
    st.markdown(top_classes)
    st.image(bytes_data)