Spaces:
Runtime error
Runtime error
File size: 2,238 Bytes
d1c1a86 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from open_clip import create_model, get_tokenizer
from open_clip.training.imagenet_zeroshot_data import openai_imagenet_template
# Architecture identifier handed to both create_model() and get_tokenizer().
model_str = "ViT-B-16"

# Checkpoint handed to create_model() as its `pretrained` argument.
# NOTE(review): this is an absolute filesystem path, so the file must exist
# on whatever machine runs this script.
pretrained = (
    "/fs/ess/PAS2136/foundation_model/model/10m/2023_09_22-21_14_04-model_ViT-B-16-lr_0.0001-b_4096-j_8-p_amp/checkpoints/epoch_99.pt"
)

# Per-channel normalization statistics.
# NOTE(review): these look like the standard OpenAI CLIP mean/std - confirm
# they match what the checkpoint was trained with.
_norm_mean = (0.48145466, 0.4578275, 0.40821073)
_norm_std = (0.26862954, 0.26130258, 0.27577711)

# Image pipeline: convert to a tensor, then normalize each channel.
# No resizing happens here.
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_norm_mean, std=_norm_std),
    ]
)
@torch.no_grad()
def get_txt_features(classnames, templates):
    """Build one averaged, unit-length text embedding per class name.

    Every template is applied to a class name to produce a set of prompts.
    The prompts are tokenized and encoded with the module-level ``tokenizer``
    and ``model``, each embedding is L2-normalized, and the embeddings are
    averaged and renormalized into a single vector for that class.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of callables, each mapping a class name to a
            prompt string.

    Returns:
        The per-class vectors stacked along ``dim=1``, so each class
        occupies one column (presumably ``(embed_dim, n_classes)`` -
        confirm against the encoder's output shape).
    """
    columns = []
    for name in classnames:
        prompts = tokenizer([fill(name) for fill in templates])
        embedded = F.normalize(model.encode_text(prompts), dim=-1)
        # Averaging unit vectors does not give a unit vector, so rescale.
        centroid = embedded.mean(dim=0)
        columns.append(centroid / centroid.norm())
    return torch.stack(columns, dim=1)
@torch.no_grad()
def predict(img, cls_str: str) -> dict[str, float]:
    """Score an image against a user-supplied list of class names.

    Args:
        img: Image as delivered by the Gradio image input; it is passed
            directly to ``preprocess_img``. May be ``None`` when the user
            submits without an image.
        cls_str: Candidate class names, one per line. Blank lines and
            surrounding whitespace are ignored.

    Returns:
        Mapping from class name to its softmax probability. An empty dict
        is returned when there is no image or no usable class name.
    """
    classes = [name.strip() for name in cls_str.split("\n") if name.strip()]

    # Nothing to classify: bail out before touching the model. Without this,
    # an empty class list makes torch.stack() raise inside get_txt_features()
    # and a missing image makes the preprocessing transform raise.
    if img is None or not classes:
        return {}

    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Drop only the batch dimension added by unsqueeze(0). A bare .squeeze()
    # would also drop the class dimension when exactly one class is given,
    # turning .tolist() into a plain float and breaking the zip() below.
    # NOTE(review): logits are raw similarities with no temperature applied;
    # open_clip's examples scale by ~100 (logit_scale.exp()) before softmax.
    # Confirm whether the unscaled, flatter distribution is intended.
    logits = (img_features @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).tolist()
    return dict(zip(classes, probs))
if __name__ == "__main__":
    print("Starting.")

    # `model` and `tokenizer` are deliberately module-level names: the
    # get_txt_features() and predict() functions in this file look them up
    # as globals at call time.
    model = create_model(model_str, pretrained, output_dict=True)
    # This script only does inference. open_clip documents that models come
    # back in train mode by default, so switch to eval mode to disable any
    # training-only behaviour (dropout, stochastic depth, BatchNorm updates).
    model.eval()
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(model_str)

    # Two inputs (image + newline-separated class names) feed predict();
    # its {class: probability} dict is rendered by the Label component.
    demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...", lines=3, label="Classes", show_label=True
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
    )
    demo.launch()
|