mnh
hello
94a8d69
raw
history blame
1.2 kB
import gradio as gr
import torch
import clip
from PIL import Image, ImageEnhance
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP ViT-B/32 model together with its matching image transform,
# placing the model on the device selected above.
model, preprocess = clip.load("ViT-B/32", device=device)
def predict(image):
    """Score an image against a fixed list of cultural/artistic origins.

    The image is desaturated, converted to grayscale, passed through the
    module-level CLIP ``preprocess`` transform, and compared by the
    module-level CLIP ``model`` with one text prompt per label.

    Args:
        image: A PIL image (it is handed to ``ImageEnhance.Color`` and
            ``Image.convert``).

    Returns:
        A dict mapping each label to its softmax probability (a Python
        ``float``) for this image.
    """
    raw_labels = "Japanese, Chinese, Roman, Greek, Etruscan, Scandinavian, Celtic, Medieval, Victorian, Neoclassic, Romanticism, Art Nouveau, Art deco"
    # Strip every entry: splitting on "," alone leaves a leading space on all
    # labels but the first, which leaked into both the prompts (double space)
    # and the keys of the returned dict.
    labels = [name.strip() for name in raw_labels.split(",")]

    # Halve the colour saturation, then collapse to single-channel grayscale.
    # NOTE(review): the grayscale conversion discards colour anyway, so the
    # saturation step probably has little effect -- confirm before removing.
    image = ImageEnhance.Color(image).enhance(0.5)
    image = image.convert("L")

    # Add a batch dimension and move the tensor to the model's device.
    image = preprocess(image).unsqueeze(0).to(device)
    text = clip.tokenize([f"a character of origin {c}" for c in labels]).to(device)

    with torch.inference_mode():
        # Only the image-to-text logits are needed here.
        logits_per_image, _ = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # probs holds a single row because exactly one image was passed in.
    return {k: float(v) for k, v in zip(labels, probs[0])}
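# Example of calling the classifier directly, without the web UI.
# predict() takes only the image; the candidate labels are fixed inside it.
# probs = predict(Image.open("../CLIP/CLIP.png"))
# print(probs)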
# Web UI: a single PIL image input, routed through predict(), with the
# returned label -> probability mapping rendered by Gradio's "label" output.
image_input = gr.inputs.Image(label="Image to classify.", type="pil")

demo = gr.Interface(
    fn=predict,
    inputs=[image_input],
    theme="grass",
    outputs="label",
    description="Character Image classification",
)

# Start serving the interface.
demo.launch()