# app.py — uploaded by hlydecker; commit aeeae70 ("Update app.py"), verified; 1.66 kB.
import transformers
import torch
import torchvision
from transformers import TrainingArguments, Trainer
from transformers import ViTImageProcessor
from transformers import ViTForImageClassification
from torch.utils.data import DataLoader
from datasets import load_dataset
from torchvision.transforms import (CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor)
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import torch.nn.functional as F
import time
import gradio as gr
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both the image processor and the classifier are read from the same
# checkpoint directory; local_files_only=True restricts loading to files
# already on disk.
MODEL_DIR = "ViT_LCZs_v2"
processor = ViTImageProcessor.from_pretrained(MODEL_DIR, local_files_only=True)
model = ViTForImageClassification.from_pretrained(MODEL_DIR, local_files_only=True)
model = model.to(device)
def predict(image):
    """Classify one image and return its top label with its probability.

    Args:
        image: The image to classify. It is handed unchanged to
            ``processor(images=...)``; presumably a PIL image, since the
            Gradio input in this file is declared with ``type="pil"``.

    Returns:
        A single-entry dict ``{label: probability}``. ``label`` is the part
        of the model's ``id2label`` entry before the first comma for the
        highest-scoring class, and ``probability`` is that class's softmax
        score as a Python ``float``.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Inference only: disable autograd so no gradient graph is built and
    # held for the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits
        probabilities = F.softmax(logits, dim=-1)

    # .item() requires exactly one prediction, i.e. a single input image.
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class_prob = float(probabilities.max())

    # Keep only the text before the first comma of the configured label.
    label = model.config.id2label[predicted_class_idx].split(",")[0]

    # NOTE(review): fixed 2-second pause kept from the original; its purpose
    # is not evident from this file — confirm whether it is still wanted.
    time.sleep(2)
    return {label: predicted_class_prob}
# Sample images offered in the UI; each inner list holds the value for the
# interface's single image input.
examples = [
    ['data/closed_highrise.png'],
    ['data/open_lowrise.png'],
    ['data/dense_trees.png'],
    ['data/large_lowrise.png'],
]

# Build the image -> label interface around predict() and start serving it.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
    examples=examples,
)
demo.launch()