import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
from PIL import Image
import gradio as gr
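
# Vision-language classifier: fuses ViT image features and BERT text features,
# then predicts benign vs. malignant from the concatenated pooled outputs.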
class VisionLanguageModel(nn.Module):
    def __init__(self):
        super(VisionLanguageModel, self).__init__()
        # Pretrained image and text encoders
        self.vision_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.language_model = BertModel.from_pretrained('bert-base-uncased')
        # Linear head over the concatenated pooled features
        self.classifier = nn.Linear(
            self.vision_model.config.hidden_size + self.language_model.config.hidden_size,
            2  # Number of classes: benign or malignant
        )

    def forward(self, input_ids, attention_mask, pixel_values):
        # Encode the image and take its pooled output
        vision_outputs = self.vision_model(pixel_values=pixel_values)
        vision_pooled_output = vision_outputs.pooler_output
        # Encode the clinical text
        language_outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        language_pooled_output = language_outputs.pooler_output
        # Fuse the two modalities by concatenation, then classify
        combined_features = torch.cat(
            (vision_pooled_output, language_pooled_output),
            dim=1
        )
        logits = self.classifier(combined_features)
        return logits

# Load the model checkpoint with safer loading
model = VisionLanguageModel()
model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu'), weights_only=True))
model.eval()
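
# Preprocessors for the text and image inputs, matching the pretrained backbones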
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
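
# Run one image + clinical-text pair through the model and return a readable label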
def predict(image, text_input):
    # Preprocess the image
    image = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Preprocess the text
    encoding = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    # Make a prediction
    with torch.no_grad():
        outputs = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pixel_values=image
        )
    _, prediction = torch.max(outputs, dim=1)
    return "Malignant" if prediction.item() == 1 else "Benign"
# Define the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Skin Lesion Image"),
        gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")
    ],
    outputs="text",
    title="Skin Lesion Classification Demo",
    description="This model classifies skin lesions as benign or malignant based on an image and clinical information."
)

iface.launch()