Spaces:
Sleeping
Sleeping
File size: 5,532 Bytes
471d95f fe0c1e0 471d95f fe0c1e0 471d95f fe0c1e0 471d95f fe0c1e0 471d95f fe0c1e0 471d95f fe0c1e0 471d95f fe0c1e0 471d95f fe0c1e0 471d95f fe0c1e0 471d95f fe0c1e0 471d95f fe0c1e0 471d95f fe0c1e0 471d95f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
import torchvision.transforms as transforms
import timm
# Location of the checkpoint file that load_model() downloads.
CHECKPOINT_URL = "https://huggingface.co/ReefNet/beit_global/resolve/main/checkpoint-60.pth"

# Class labels, one alphabetical group per initial letter.  The position of a
# name in this list is the class index it stands for, so the order must not
# change without retraining/re-exporting the checkpoint.
all_classes = [
    "Acanthastrea", "Acropora", "Agaricia", "Alveopora", "Astrea", "Astreopora",
    "Caulastraea", "Coeloseris", "Colpophyllia", "Coscinaraea", "Ctenactis",
    "Cycloseris", "Cyphastrea",
    "Dendrogyra", "Dichocoenia", "Diploastrea", "Diploria", "Dipsastraea",
    "Echinophyllia", "Echinopora", "Euphyllia", "Eusmilia",
    "Favia", "Favites", "Fungia",
    "Galaxea", "Gardineroseris", "Goniastrea", "Goniopora",
    "Halomitra", "Herpolitha", "Hydnophora",
    "Isophyllia", "Isopora",
    "Leptastrea", "Leptoria", "Leptoseris", "Lithophyllon", "Lobactis",
    "Lobophyllia",
    "Madracis", "Meandrina", "Merulina", "Montastraea", "Montipora", "Mussa",
    "Mussismilia", "Mycedium",
    "Orbicella", "Oulastrea", "Oulophyllia", "Oxypora",
    "Pachyseris", "Pavona", "Pectinia", "Physogyra", "Platygyra", "Plerogyra",
    "Plesiastrea", "Pocillopora", "Podabacia", "Porites", "Psammocora",
    "Pseudodiploria",
    "Sandalolitha", "Scolymia", "Seriatopora", "Siderastrea", "Stephanocoenia",
    "Stylocoeniella", "Stylophora",
    "Tubastraea", "Turbinaria",
]
def load_model(model_name):
    """Build the classifier named *model_name* and load its checkpoint.

    Only ``'beit'`` is implemented.  The network is created through ``timm``
    with one output per entry of ``all_classes``, the weights are downloaded
    from ``CHECKPOINT_URL``, and the model is returned in eval mode (on CUDA
    when a GPU is available).

    Args:
        model_name: Identifier of the architecture to load.

    Returns:
        The initialised model, ready for inference.

    Raises:
        ValueError: If *model_name* is anything other than ``'beit'``.
    """
    print(f"Loading {model_name} model...")
    if model_name != 'beit':
        raise ValueError(f"Model {model_name} not implemented!")

    model = timm.create_model(
        'beitv2_large_patch16_224.in1k_ft_in22k_in1k',
        pretrained=False,  # weights come from the checkpoint below
        num_classes=len(all_classes),
        drop_path_rate=0.1,
        use_rel_pos_bias=True,
        use_abs_pos_emb=True,
    )

    # NOTE: the downloaded file is deserialised with torch.load (pickle), so
    # CHECKPOINT_URL must only ever point at a trusted source.
    checkpoint = torch.hub.load_state_dict_from_url(CHECKPOINT_URL, map_location="cpu")
    # Accept both {"model": state_dict, ...} wrappers and bare state dicts.
    state_dict = checkpoint.get('model', checkpoint)

    # Entries named "relative_position_index" are dropped before loading
    # (presumably because the model rebuilds them itself -- confirm).
    skipped_marker = "relative_position_index"
    filtered_state_dict = {
        key: value for key, value in state_dict.items() if skipped_marker not in key
    }

    # strict=False tolerates mismatches, so report them: otherwise a wrong
    # checkpoint would silently leave layers randomly initialised.
    load_result = model.load_state_dict(filtered_state_dict, strict=False)
    missing = [key for key in load_result.missing_keys if skipped_marker not in key]
    if missing:
        print(f"Warning: {len(missing)} parameter(s) missing from checkpoint: {missing}")
    if load_result.unexpected_keys:
        print(
            f"Warning: {len(load_result.unexpected_keys)} unexpected checkpoint "
            f"key(s) ignored: {list(load_result.unexpected_keys)}"
        )

    model.eval()
    # Move model to CUDA if available
    if torch.cuda.is_available():
        model.cuda()
    return model
# Input pipeline applied to every image before it reaches the network:
# resize to 224x224, convert to a tensor, then normalise each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Load the model once at import time so every prediction reuses it.
selected_model_name = 'beit'
model = load_model(selected_model_name)
def predict_label(image):
    """Predict the label for the given image.

    Args:
        image: A ``PIL.Image.Image`` or a ``numpy.ndarray`` holding the
            picture to classify.

    Returns:
        The entry of ``all_classes`` with the highest model score.

    Raises:
        TypeError: If *image* is neither a PIL image nor a numpy array
            (this includes ``None``).
    """
    # Ensure the image is a PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise TypeError(f"Unexpected type {type(image)}, expected PIL.Image or numpy.ndarray.")

    # The preprocessing pipeline normalises with three-channel statistics, so
    # RGBA / greyscale / palette images must be converted first or the
    # normalisation step fails on the channel count.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add the batch dimension the model expects.
    input_tensor = preprocess(image).unsqueeze(0)
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    with torch.no_grad():
        outputs = model(input_tensor)

    predicted_class = torch.argmax(outputs, dim=1).item()
    return all_classes[predicted_class]
def draw_rectangle(image, x, y, size=224):
    """Return a copy of *image* with a red square outline drawn on it.

    The square's top-left corner is at (*x*, *y*) and its sides are *size*
    pixels long.  The input image itself is left unmodified.
    """
    annotated = image.copy()
    right = x + size
    bottom = y + size
    canvas = ImageDraw.Draw(annotated)
    canvas.rectangle([x, y, right, bottom], outline="red", width=3)
    return annotated
def crop_image(image, x, y, size=224):
    """Crop a square region of interest out of *image*.

    The requested origin is clamped so the patch stays inside the image.
    When the image is smaller than *size* along an axis, the patch covers
    the whole extent of that axis (and is therefore smaller than *size*).

    Args:
        image: Anything ``numpy.array`` can turn into an image array
            (for example a PIL image).
        x: Requested left edge of the patch, in pixels.
        y: Requested top edge of the patch, in pixels.
        size: Side length of the patch, in pixels.

    Returns:
        The cropped patch as a PIL image.
    """
    pixels = np.array(image)
    # Only the first two axes matter, so 2-D (greyscale) arrays work too.
    height, width = pixels.shape[:2]

    # Floor the upper bound at 0: with an image smaller than `size` the old
    # clamp produced a negative origin, and a negative slice start silently
    # selects a strip from the far edge.  int() guards against float sliders.
    left = max(0, min(int(x), width - size))
    top = max(0, min(int(y), height - size))

    patch = pixels[top:top + size, left:left + size]
    return Image.fromarray(patch)
# Gradio UI
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been lost; the buttons are assumed to sit directly in
# the Blocks container, below the two-column row -- confirm against the app.
PATCH_SIZE = 224  # side length of the crop window, in pixels

with gr.Blocks() as demo:
    gr.Markdown("## Coral Classification with BeIT Model")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(0, 1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(0, 1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")

    # Interactions
    def update_selection(image, x, y):
        """Return the image with the crop window outlined, and the crop itself."""
        if image is None:
            # Crop was clicked before anything was uploaded.
            raise gr.Error("Please upload an image before cropping.")
        # Clamp once here so the outline marks exactly the region that gets
        # cropped, even if a slider still holds a value from a larger image.
        width, height = image.size
        x = max(0, min(int(x), width - PATCH_SIZE))
        y = max(0, min(int(y), height - PATCH_SIZE))
        overlay_image = draw_rectangle(image, x, y, PATCH_SIZE)
        cropped = crop_image(image, x, y, PATCH_SIZE)
        return overlay_image, cropped

    def predict_from_cropped(cropped):
        """Classify the currently cropped patch."""
        if cropped is None:
            # Predict was clicked before a patch was cropped.
            raise gr.Error("Please crop a patch before predicting.")
        return predict_label(cropped)

    crop_button = gr.Button("Crop")
    crop_button.click(
        fn=update_selection,
        inputs=[image_input, x_slider, y_slider],
        outputs=[interactive_image, cropped_image],
    )

    predict_button = gr.Button("Predict")
    predict_button.click(fn=predict_from_cropped, inputs=cropped_image, outputs=label_output)

    def update_sliders(image):
        """Limit the sliders so the crop window fits inside the uploaded image."""
        if image is None:
            return gr.update(), gr.update()
        width, height = image.size
        # Never go below 0: images smaller than the window would otherwise
        # give the sliders a maximum below their minimum.
        return (
            gr.update(maximum=max(width - PATCH_SIZE, 0)),
            gr.update(maximum=max(height - PATCH_SIZE, 0)),
        )

    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])

demo.launch(server_name="0.0.0.0", server_port=7860)
|