File size: 1,736 Bytes
08a87c6
 
 
94a8d69
08a87c6
 
 
 
 
 
af94a82
 
08a87c6
 
94a8d69
 
 
08a87c6
6d848c4
08a87c6
 
 
 
 
 
 
 
 
 
 
 
 
 
0f3f608
08a87c6
c79b019
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import gradio as gr
import torch
import clip
from PIL import Image, ImageEnhance

# Run on GPU when CUDA is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the pretrained ViT-B/32 CLIP model together with its matching
# image preprocessing pipeline (resize/center-crop/normalize) on `device`.
model, preprocess = clip.load("ViT-B/32", device=device)


def predict(image):
    """Zero-shot classify an artwork image into art-origin categories with CLIP.

    Args:
        image: PIL.Image provided by the Gradio image input component.

    Returns:
        dict mapping each origin label (str) to its softmax probability
        (float), the shape expected by a Gradio "label" output.
    """
    labels = "Early American Art,19th and 20th–Century Art,Contemporary Art,Modern Folk,African American Art,Latino Art,Mesoamerican,Egyptian,British Art,Celtic Art,German Art,Medieval European,Gothic,Native American,African Art,Asia pacific Art,Oceanía,Classical,Byzantine,Medieval,Gothic,Renaissance,Baroque,Rococo,Neoclassical,Modernism,Postmodern ,Irish,German,French,Italian,Spanish,Portuguese,Greek,Chinese,Japanese,Korean,Thai,Australian,Middle Eastern,Mesopotamian,Prehistoric,Mexican,Popart,Scottish,Netherlands"
    # labels = "Japanese, Chinese, Roman, Greek, Etruscan, Scandinavian, Celtic, Medieval, Victorian, Neoclassic, Romanticism, Art Nouveau, Art deco"
    # Strip stray whitespace (e.g. "Postmodern ") and drop duplicates
    # ("Gothic" appears twice) while preserving order. Duplicates would
    # otherwise split probability mass across two identical prompts and
    # silently overwrite each other in the returned dict.
    labels = list(dict.fromkeys(lbl.strip() for lbl in labels.split(',')))

    # Desaturate then grayscale the input — presumably to make the model
    # focus on composition/style rather than color; CLIP's preprocess
    # pipeline converts the image back to RGB internally.
    converter = ImageEnhance.Color(image)
    image = converter.enhance(0.5)
    image = image.convert("L")
    image = preprocess(image).unsqueeze(0).to(device)
    text = clip.tokenize([f"a character of origin: {c}" for c in labels]).to(device)

    # inference_mode disables autograd bookkeeping for faster inference.
    with torch.inference_mode():
        logits_per_image, logits_per_text = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    return {k: float(v) for k, v in zip(labels, probs[0])}

# probs = predict(Image.open("../CLIP/CLIP.png"), "cat, dog, ball")
# print(probs)


# Build and launch the Gradio UI: one PIL image in, a label/probability
# panel out. NOTE(review): the original used `gr.inputs.Image`, the
# Gradio 2.x namespace that was deprecated in 3.x and removed in 4.x,
# while `theme="gradio/monochrome"` requires Gradio >= 3 — use the
# current `gr.Image` component for consistency.
gr.Interface(fn=predict,
             inputs=[
                 gr.Image(label="Image to classify.", type="pil")],
             theme="gradio/monochrome",
             outputs="label",
             description="Character Image classification").launch()