Spaces:
Running
Running
import os | |
import gradio as gr | |
import torch | |
import torch.nn.functional as F | |
from open_clip import create_model, get_tokenizer | |
from torchvision import transforms | |
from templates import openai_imagenet_template | |
# --- Flagging -----------------------------------------------------------
# The Hugging Face token comes from the environment (None when HF_TOKEN is
# unset) and is handed to the dataset saver used as the flagging callback.
hf_token = os.environ.get("HF_TOKEN")
hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")

# --- Model identifiers --------------------------------------------------
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# --- Image preprocessing ------------------------------------------------
# Convert the input to a tensor, then normalize each channel with the
# mean/std below. No resizing happens here.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)
preprocess_img = transforms.Compose([_to_tensor, _normalize])
@torch.no_grad()
def get_txt_features(classnames, templates):
    """Build one normalized text embedding per class name.

    Each class name is rendered through every prompt template, the prompts
    are tokenized and encoded with the global ``model``, and the per-prompt
    embeddings are L2-normalized, averaged, and re-normalized.

    Runs under ``torch.no_grad()``: this is inference only, so no autograd
    graph should be built or kept alive for the encoder calls.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of callables mapping a class name to a prompt
            string.

    Returns:
        A tensor with the per-class embeddings stacked along ``dim=1``
        (presumably ``(embed_dim, num_classes)``, assuming ``encode_text``
        returns one row per prompt).
    """
    all_features = []
    for classname in classnames:
        txts = [template(classname) for template in templates]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        # Normalize every prompt embedding before averaging so each prompt
        # contributes equally, then renormalize the averaged vector.
        txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
        txt_features /= txt_features.norm()
        all_features.append(txt_features)
    # Classes go along dim=1 so image features can be right-multiplied by
    # the result.
    return torch.stack(all_features, dim=1)
@torch.no_grad()
def predict(img, classes: list[str]) -> dict[str, float]:
    """Classify ``img`` against a user-supplied list of class names.

    Args:
        img: Image accepted by ``preprocess_img``.
        classes: Candidate class names; blank or whitespace-only entries
            are dropped and the rest are stripped.

    Returns:
        Mapping from each class name to its softmax probability.
    """
    classes = [name.strip() for name in classes if name.strip()]
    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    # encode_image is given a batch, so add a leading batch axis of size 1.
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Drop only the batch axis. A bare squeeze() would also remove the class
    # axis when there is exactly one class, leaving a 0-dim tensor whose
    # .tolist() is a float and cannot be zipped with the class names.
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return dict(zip(classes, probs))
@torch.no_grad()
def hierarchical_predict(img) -> list[str]:
    """
    Predicts from the top of the tree of life down to the species.

    Only the image embedding is computed so far; the walk down the
    taxonomy is not implemented yet, so the function raises instead of
    silently returning ``None``.

    Raises:
        NotImplementedError: Always, until the traversal is written.
    """
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # TODO: score ``img_features`` against each taxonomic rank in turn.
    # A leftover breakpoint() used to sit here; it would stop the serving
    # process in the debugger on every call.
    raise NotImplementedError(
        "Hierarchical (tree-of-life) prediction is not implemented yet; "
        "please provide a list of classes."
    )
def run(img, cls_str: str) -> dict[str, float]:
    """Entry point for the Gradio interface.

    Args:
        img: Image supplied by the interface.
        cls_str: Newline-separated class names typed by the user; may be
            empty.

    Returns:
        The result of ``predict`` when at least one non-blank class name
        was given, otherwise the result of ``hierarchical_predict``.
    """
    # Parse before branching so whitespace-only input counts as "no classes"
    # instead of sending an empty class list to predict().
    classes = [
        name.strip() for name in (cls_str or "").split("\n") if name.strip()
    ]
    if classes:
        return predict(img, classes)
    return hierarchical_predict(img)
if __name__ == "__main__":
    # ``model`` and ``tokenizer`` are assigned at module level here because
    # the prediction functions read them as globals.
    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    # open_clip returns models in train mode; switch to eval mode so layers
    # such as dropout or stochastic depth (if present) behave
    # deterministically while serving.
    model.eval()
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    demo = gr.Interface(
        fn=run,
        inputs=[
            # preprocess_img does no resizing, so the image size is set here.
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...",
                lines=3,
                label="Classes",
                show_label=True,
                info="If empty, will predict from the entire tree of life.",
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
        # Flagging is manual; flagged samples go through hf_writer.
        allow_flagging="manual",
        flagging_options=["Incorrect", "Other"],
        flagging_callback=hf_writer,
    )

    demo.launch()