"""Streamlit demo: classify an uploaded image with a pretrained Inception v3."""
import os
import tempfile
from io import StringIO  # NOTE(review): currently unused; kept to avoid breaking anything relying on it

import matplotlib.pyplot as plt
import requests
import streamlit as st
import torch
from skimage.transform import resize
from torchvision.models.inception import inception_v3

# Spatial size Inception v3 is fed with in this app (H == W).
_INPUT_SIZE = 299
# Per-channel ImageNet statistics that torchvision's pretrained models expect
# inputs to be normalised with; `transform_input=True` then maps such input
# into the range Inception was trained on.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# `st.cache` is deprecated; `st.cache_resource` (newer Streamlit) is meant for
# unhashable, long-lived objects such as a torch model. On older Streamlit fall
# back to `st.cache` with mutation checks disabled so it does not try to hash
# the returned model on every rerun.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_stuff():
    """Load the pretrained Inception v3 model and the ImageNet class labels.

    Returns:
        (model, labels): the model in eval mode, and a dict mapping class
        index -> label string.

    Raises:
        requests.HTTPError / requests.RequestException if the labels cannot
        be downloaded.
    """
    model = inception_v3(pretrained=True,  # load existing weights
                         transform_input=True,  # rescale normalised input to the range used in training
                         )
    model.aux_logits = False  # don't predict intermediate (auxiliary) logits
    model.eval()

    LABELS_URL = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'
    # A timeout keeps a stalled download from hanging the app forever; a bad
    # HTTP status should fail loudly rather than produce a broken label map.
    response = requests.get(LABELS_URL, timeout=30)
    response.raise_for_status()
    labels = dict(enumerate(response.json()))
    return model, labels


model, labels = load_stuff()


def transform_input(img):
    """Convert an (H, W, 3) float image in [0, 1] into a normalised model batch.

    The image is reshaped to (1, 299, 299, 3), moved to channels-first
    (1, 3, 299, 299), cast to float32 and normalised with the ImageNet
    mean/std, which is what the pretrained torchvision weights expect.
    """
    batch = torch.as_tensor(
        img.reshape([1, _INPUT_SIZE, _INPUT_SIZE, 3]).transpose([0, 3, 1, 2]),
        dtype=torch.float32,
    )
    mean = torch.tensor(_IMAGENET_MEAN, dtype=torch.float32).view(1, 3, 1, 1)
    std = torch.tensor(_IMAGENET_STD, dtype=torch.float32).view(1, 3, 1, 1)
    return (batch - mean) / std


def predict(img, top_k=10):
    """Return a markdown string listing the `top_k` most probable classes.

    Args:
        img: (299, 299, 3) float image, as produced by `_load_image`.
        top_k: number of classes to list (default 10).
    """
    batch = transform_input(img)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=-1).numpy().ravel()
    # argsort is ascending; reverse it and take exactly top_k indices.
    top_ix = probs.argsort()[::-1][:top_k]

    parts = ['top-%d classes are: \n\n [prob : class label]\n\n' % top_k]
    parts.extend('%.4f :\t%s' % (probs[i], labels[i].split(',')[0]) + '\n\n'
                 for i in top_ix)
    return ''.join(parts)


def _load_image(data):
    """Decode raw image bytes into a (299, 299, 3) float RGB array.

    The bytes go to a private temporary directory rather than a fixed
    './tmp' path, so concurrent sessions cannot overwrite each other's
    uploads. The file deliberately has no extension, so `plt.imread` uses
    generic format detection as before.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'upload')
        with open(path, 'wb') as f:
            f.write(data)
        raw = plt.imread(path)
    # skimage's resize returns floats (integer input is rescaled to [0, 1]).
    img = resize(raw, (_INPUT_SIZE, _INPUT_SIZE))
    if img.ndim == 2:
        # Grayscale: replicate the single channel so the model sees RGB.
        img = img[..., None].repeat(3, axis=-1)
    # Drop an alpha channel if present.
    return img[..., :3]


st.markdown("### Hello dude!")
uploaded_file = st.file_uploader("Choose a file")
if uploaded_file is not None:
    # To read file as bytes:
    bytes_data = uploaded_file.getvalue()
    img = _load_image(bytes_data)
    top_classes = predict(img)
    st.markdown(top_classes)
    st.image(bytes_data)