# Page-header residue from the hosting site ("Spaces: Sleeping"); not part of the program.
# Standard Python imports | |
import os | |
import re | |
import json | |
from typing import List, Dict, Any | |
# Data processing and visualization | |
from PIL import Image | |
from tqdm import tqdm | |
from tqdm.notebook import tqdm | |
# Deep Learning & ML | |
import torch | |
from transformers import ( | |
AutoProcessor, | |
AutoModelForVision2Seq, | |
AutoTokenizer, | |
AutoModelForCausalLM, | |
TextStreamer, | |
Idefics3ForConditionalGeneration, | |
BitsAndBytesConfig | |
) | |
from unsloth import FastVisionModel | |
# Dataset handling | |
from datasets import load_from_disk | |
# API & Authentication | |
from huggingface_hub import login | |
# UI & Environment | |
import gradio as gr | |
from dotenv import load_dotenv | |
# Selectable models: UI display name -> Hugging Face Hub repository id.
# Insertion order is significant: the first entry is the dropdown default.
MODELS: Dict[str, str] = dict(
    [
        (
            "Blood Cell Classifier with Llama-3.2",
            "laurru01/Llama-3.2-11B-Vision-Instruct-ft-PeripherallBloodCells",
        ),
        (
            "Blood Cell Classifier with Qwen2-VL",
            "laurru01/Qwen2-VL-2B-Instruct-ft-bloodcells-big",
        ),
        (
            "Blood Cell Classifier with SmolVLM",
            "laurru01/SmolVLM-Instruct-ft-PeripherallBloodCells",
        ),
    ]
)

# Loaded model components keyed by display name; filled by initialize_models().
loaded_models: Dict[str, Dict[str, Any]] = dict()
def initialize_models():
    """Preload every model listed in ``MODELS`` into ``loaded_models``.

    Entries whose display name contains ``"SmolVLM"`` are built from the
    ``HuggingFaceTB/SmolVLM-Instruct`` base checkpoint in 4-bit, with the
    fine-tuned adapter loaded on top. All other entries are loaded through
    ``FastVisionModel`` and switched to inference mode.

    Loading is best-effort: if one model fails, the error is printed, that
    model is left out of ``loaded_models``, and the remaining models are
    still attempted.
    """
    smolvlm_base = "HuggingFaceTB/SmolVLM-Instruct"

    print("Initializing models...")
    for model_name, model_path in MODELS.items():
        print(f"Loading {model_name}...")
        try:
            if "SmolVLM" in model_name:
                # SmolVLM-specific loading.
                # 4-bit quantization is requested through ``quantization_config``:
                # passing ``load_in_4bit`` directly to ``from_pretrained`` is the
                # deprecated path in transformers.
                base_model = Idefics3ForConditionalGeneration.from_pretrained(
                    smolvlm_base,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                    max_memory={0: "12GB"},
                )
                base_model.load_adapter(model_path)
                processor = AutoProcessor.from_pretrained(smolvlm_base)
                loaded_models[model_name] = {
                    "model": base_model,
                    "processor": processor,
                    "type": "smolvlm",
                }
            else:
                # Original loading path for Llama and Qwen (unchanged).
                model, tokenizer = FastVisionModel.from_pretrained(
                    model_name=model_path,
                    load_in_4bit=True,
                    use_gradient_checkpointing="unsloth",
                )
                FastVisionModel.for_inference(model)
                processor = AutoProcessor.from_pretrained(model_path)
                loaded_models[model_name] = {
                    "model": model,
                    "tokenizer": tokenizer,
                    "processor": processor,
                    "type": "standard",
                }
            print(f"Successfully loaded {model_name}")
        except Exception as e:
            # Deliberate best-effort: report and continue with the next model.
            print(f"Error loading {model_name}: {str(e)}")
    print("Model initialization complete")
def extract_cell_type(text):
    """Return the blood cell type named in a generated description.

    The text is searched case-insensitively for the known cell type names.
    When several names occur, the one mentioned earliest in the text is
    returned, so the result depends on the description and not on the order
    of the internal list.

    Args:
        text: Generated description; ``None`` or an empty string is allowed.

    Returns:
        The capitalized cell type (e.g. ``"Lymphocyte"``), or
        ``"Unidentified Cell Type"`` when no known name is present.
    """
    unidentified = "Unidentified Cell Type"
    cell_types = ('neutrophil', 'lymphocyte', 'monocyte', 'eosinophil', 'basophil')

    if not text:
        return unidentified

    text_lower = text.lower()
    # (position, name) for every name that occurs; min() picks the earliest.
    mentions = [
        (position, cell_type)
        for cell_type in cell_types
        if (position := text_lower.find(cell_type)) != -1
    ]
    if not mentions:
        return unidentified
    return min(mentions)[1].capitalize()
def generate_description_standard(model, tokenizer, image):
    """Generate a cell description using the standard models (Llama and Qwen).

    Args:
        model: Model object exposing ``generate``.
        tokenizer: Object used to apply the chat template, encode the image
            together with the prompt, and decode the generated ids.
        image: Image to describe.

    Returns:
        The generated answer with surrounding whitespace stripped. Only the
        tokens produced after the prompt are decoded, so the prompt text and
        chat-role markers are not part of the result.
    """
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "As a hematologist, carefully identify the type of blood cell in this image and describe its key characteristics."}
        ]}]
    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    inputs = tokenizer(image, input_text, add_special_tokens=False, return_tensors="pt").to("cuda")
    # Streams the generated text to stdout while generation runs.
    text_streamer = TextStreamer(tokenizer, skip_prompt=True)
    output = model.generate(
        **inputs,
        streamer=text_streamer,
        max_new_tokens=1024,
        use_cache=True,
        temperature=1.5,
        min_p=0.1
    )
    # Drop the prompt by token count rather than searching the decoded text
    # for a role marker: str.find() returns -1 when the marker is absent, and
    # slicing from -1 would keep only the last character of the answer.
    prompt_length = len(inputs["input_ids"][0])
    response_ids = output[0][prompt_length:]
    return tokenizer.decode(response_ids, skip_special_tokens=True).strip()
def generate_description_smolvlm(model, processor, image):
    """Generate a cell description using the SmolVLM model with memory-efficient settings.

    The image is converted to RGB if needed and downscaled to fit within
    192x192 pixels on a private copy, so the caller's image is never modified.

    Args:
        model: Model object exposing ``generate``.
        processor: Processor used for the chat template, input encoding and
            decoding; its ``tokenizer.pad_token_id`` is passed to ``generate``.
        image: PIL image to describe.

    Returns:
        The decoded answer, or a fixed error message when the answer contains
        fewer than five distinct whitespace-separated words.

    Raises:
        Exception: Any error raised while encoding or generating propagates
            unchanged; the CUDA cache is emptied first.
    """
    # ``thumbnail`` resizes in place, and analyze_cell() hands the same image
    # object back to the UI afterwards, so always work on a separate object
    # (``convert`` already returns a new image).
    image = image.convert("RGB") if image.mode != "RGB" else image.copy()
    # Resize to a smaller size to reduce memory use.
    max_size = 192
    image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    sample = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "As a hematologist, carefully identify the type of blood cell in this image and describe its key characteristics."}
        ]
    }]
    text_input = processor.apply_chat_template(
        sample,
        add_generation_prompt=True
    )
    try:
        torch.cuda.empty_cache()
        # torch.autocast(device_type="cuda") replaces the deprecated
        # torch.cuda.amp.autocast() spelling.
        with torch.autocast(device_type="cuda"):
            model_inputs = processor(
                text=text_input,
                images=[[image]],
                return_tensors="pt",
            ).to("cuda")
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.5,
                no_repeat_ngram_size=3,
                num_beams=2,
                length_penalty=1.0,
                early_stopping=True,
                use_cache=True,
                pad_token_id=processor.tokenizer.pad_token_id,
            )
        # Keep only the tokens generated after the prompt.
        response_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
        output_text = processor.decode(
            response_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        ).strip()
        # Very few distinct words is treated as a degenerate, repetitive answer.
        if len(set(output_text.split())) < 5:
            output_text = "Error: Generated response was too repetitive. Please try again."
        # Drop tensor references before the cache is emptied below.
        del model_inputs, generated_ids, response_ids
        return output_text
    finally:
        # Runs on success and on failure alike.
        torch.cuda.empty_cache()
def analyze_cell(image, model_name):
    """Run the selected model on an uploaded image.

    Returns a 3-tuple ``(cell type, description, image)``. On any failure
    (invalid image, model not loaded, or an exception during generation) the
    first element carries the message, the description is empty and the image
    is ``None``.
    """
    if not isinstance(image, Image.Image):
        return "Invalid image format. Please upload a valid image.", "", None

    try:
        if model_name not in loaded_models:
            return f"Model {model_name} not loaded.", "", None

        components = loaded_models[model_name]
        # Pick the generator and the helper object it needs for this model type.
        if components["type"] == "smolvlm":
            generate, helper_key = generate_description_smolvlm, "processor"
        else:
            generate, helper_key = generate_description_standard, "tokenizer"

        description = generate(components["model"], components[helper_key], image)
        return extract_cell_type(description), description, image
    except Exception as exc:
        return f"Error occurred: {str(exc)}", "", None
# Initialize all models before starting the interface
initialize_models()

# The Gradio interface is built once, further below, with custom CSS applied.
# An unstyled duplicate layout is intentionally not constructed here: every
# name it would bind (iface and all components) is rebound by that later
# definition before launch(), so it could never be served.
# Enhanced CSS with modern color scheme.
# Passed to gr.Blocks(css=...) below. The selectors correspond to the
# elem_classes values and HTML class attributes used by that interface
# (title, subtitle, input-image, model-dropdown, submit-button, result-box,
# output-image). NOTE(review): no element in the visible code uses
# ".container" - confirm whether it is still needed.
custom_css = """
.container {
max-width: 1000px;
margin: auto;
padding: 30px;
background: linear-gradient(135deg, #f6f9fc 0%, #ffffff 100%);
border-radius: 20px;
box-shadow: 0 10px 20px rgba(0,0,0,0.05);
}
.title {
text-align: center;
color: #2d3436;
font-size: 3em;
font-weight: 700;
margin-bottom: 20px;
text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
}
.subtitle {
text-align: center;
color: #636e72;
font-size: 1.2em;
margin-bottom: 40px;
}
.input-image {
border: 2px dashed #74b9ff;
border-radius: 15px;
padding: 20px;
transition: all 0.3s ease;
}
.input-image:hover {
border-color: #0984e3;
transform: translateY(-2px);
}
.model-dropdown {
background: #f8f9fa;
border-radius: 10px;
border: 1px solid #dfe6e9;
margin: 15px 0;
}
.submit-button {
background: linear-gradient(45deg, #0984e3, #74b9ff);
color: white;
border: none;
padding: 12px 25px;
border-radius: 10px;
font-weight: 600;
transition: all 0.3s ease;
}
.submit-button:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(9, 132, 227, 0.3);
}
.result-box {
background: white;
border-radius: 10px;
border: 1px solid #dfe6e9;
padding: 15px;
margin: 10px 0;
}
.output-image {
border-radius: 15px;
overflow: hidden;
box-shadow: 0 5px 15px rgba(0,0,0,0.1);
}
"""
# Footer markup shown under the two columns (author credit and GitHub link).
_FOOTER_HTML = """
<div style="text-align: center; margin-top: 30px; padding: 20px;">
<p style="color: #636e72;">Developed by Laura Ruiz | MSc Bioinformatics and Biostatistics</p>
<a href="https://github.com/laurru01" target="_blank"
style="color: #0984e3; text-decoration: none; font-weight: 600;">
View on GitHub
</a>
</div>
"""

# Interface: inputs on the left, results on the right, footer at the bottom.
with gr.Blocks(css=custom_css) as iface:
    gr.HTML("<h1 class='title'>Blood Cell Classifier</h1>")
    gr.HTML("<p class='subtitle'>Upload a microscopic blood cell image for instant classification and detailed analysis</p>")

    with gr.Row():
        # Left column: image upload, model choice and the trigger button.
        with gr.Column():
            input_image = gr.Image(
                label="Upload Blood Cell Image",
                type="pil",
                sources=["upload"],  # Only allow computer uploads
                elem_classes="input-image",
            )
            model_dropdown = gr.Dropdown(
                choices=list(MODELS),
                value=list(MODELS)[0],  # first registered model is the default
                label="Select Model Version",
                elem_classes="model-dropdown",
            )
            submit_btn = gr.Button(
                "Analyze Cell",
                variant="primary",
                elem_classes="submit-button",
            )

        # Right column: the three outputs of analyze_cell().
        with gr.Column():
            cell_type = gr.Textbox(
                label="Identified Cell Type",
                elem_classes="result-box",
            )
            description = gr.Textbox(
                label="Analysis Details",
                lines=8,
                elem_classes="result-box",
            )
            output_image = gr.Image(
                label="Analyzed Image",
                elem_classes="output-image",
            )

    submit_btn.click(
        fn=analyze_cell,
        inputs=[input_image, model_dropdown],
        outputs=[cell_type, description, output_image],
    )
    gr.HTML(_FOOTER_HTML)

# Launch the interface
iface.launch()