File size: 865 Bytes
1c962af
 
215a9c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c962af
215a9c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import gradio as gr
import gradio.components as grc
import torch
from torchvision import transforms
from transformers import ViTForImageClassification

# Hugging Face Hub repo id of the view-angle classification model.
model_path = "Inf009/view-angle"

# Load the classifier once at import time and put it in inference mode.
# nn.Module.eval() returns the module itself, so the call can be chained.
model = ViTForImageClassification.from_pretrained(model_path).eval()

# Preprocessing applied to every incoming image before inference:
#   1. resize to exactly 224x224;
#   2. center-crop to 224 -- a no-op after the exact resize, kept so the
#      pipeline matches the original definition;
#   3. convert the PIL image to a CHW float tensor (scaled to [0, 1] for
#      8-bit images).
# NOTE(review): there is no Normalize step -- presumably this matches how
# the model was trained; confirm against the training pipeline.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Human-readable labels, indexed by the model's output class index.
# English glosses: "45度俯视" = 45-degree high-angle view, "俯视" = top-down
# view, "正视" = front (eye-level) view.
# NOTE(review): the order is assumed to match the model's label indices --
# confirm against model.config.id2label.
VIEW_ANGLE_TAGS = ("45度俯视", "俯视", "正视")


def predict_view_angle(image):
    """Classify the camera view angle of an image.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading an image.

    Returns:
        The label from ``VIEW_ANGLE_TAGS`` with the highest model score, or
        an empty string when no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty image input; bail out instead of
        # crashing inside the transform pipeline.
        return ""
    # Uploads may be RGBA (PNG with alpha), grayscale ("L") or palette ("P").
    # ToTensor keeps the source channel count, while the ViT model presumably
    # expects 3 channels (its config's num_channels default), so normalize the
    # mode first.
    pixels = val_transforms(image.convert("RGB")).unsqueeze(0)  # add batch dim
    # Inference only: skip autograd graph construction entirely.
    with torch.no_grad():
        logits = model(pixels).logits.squeeze(0)
    # sigmoid is monotonic, so the highest logit is also the highest sigmoid
    # score; argmax replaces sorting the whole score list.
    return VIEW_ANGLE_TAGS[int(logits.argmax())]

    
# Minimal Gradio UI: a PIL image goes in and the predicted view-angle label
# comes out as text. launch() starts the web server for the interface.
_image_input = grc.Image(type="pil")
_label_output = grc.Textbox()
app = gr.Interface(fn=predict_view_angle, inputs=_image_input, outputs=_label_output)
app.launch()