from typing import Optional

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
class VisionLanguageModel(nn.Module):
    """Late-fusion classifier over a ViT image encoder and a BERT text encoder.

    The pooled output of each encoder is concatenated along dim 1 and passed
    through a single linear layer that produces the class logits.

    Args:
        vision_model_name: Checkpoint name or path handed to
            ``ViTModel.from_pretrained``.
        language_model_name: Checkpoint name or path handed to
            ``BertModel.from_pretrained``.
        num_classes: Number of output logits produced by the classifier head.

    The defaults reproduce the original hard-coded configuration, so
    ``VisionLanguageModel()`` builds the same modules under the same
    attribute names (and therefore the same state-dict keys) as before.
    """

    def __init__(
        self,
        vision_model_name: str = 'google/vit-base-patch16-224-in21k',
        language_model_name: str = 'bert-base-uncased',
        num_classes: int = 2,
    ) -> None:
        super().__init__()
        self.vision_model = ViTModel.from_pretrained(vision_model_name)
        self.language_model = BertModel.from_pretrained(language_model_name)

        # The classifier consumes both pooled vectors side by side, so its
        # input width is the sum of the two encoders' hidden sizes.
        fused_size = (
            self.vision_model.config.hidden_size
            + self.language_model.config.hidden_size
        )
        self.classifier = nn.Linear(fused_size, num_classes)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Return the classifier logits for a batch of image/text pairs.

        Args:
            input_ids: Token ids forwarded to the language model.
            attention_mask: Attention mask forwarded to the language model.
            pixel_values: Image tensor forwarded to the vision model.

        Returns:
            The output of the linear head applied to the concatenated pooled
            features; its last dimension has ``num_classes`` entries.
        """
        vision_pooled_output = self.vision_model(
            pixel_values=pixel_values,
        ).pooler_output
        language_pooled_output = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).pooler_output

        # Vision features first, then language features, joined along dim 1;
        # the order must match the one the classifier weights were trained on.
        combined_features = torch.cat(
            (vision_pooled_output, language_pooled_output),
            dim=1,
        )
        return self.classifier(combined_features)
# Build the network, then overwrite its parameters with the fine-tuned
# checkpoint. The checkpoint is mapped onto the CPU and read with
# weights_only=True.
model = VisionLanguageModel()
model.load_state_dict(
    torch.load(
        'best_model.pth',
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
# Switch to evaluation mode; this script only runs inference.
model.eval()

# Text and image preprocessors, loaded from the same checkpoint names the
# model's encoders use by default.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
def predict(image: Optional["Image.Image"], text_input: Optional[str]) -> str:
    """Classify one skin-lesion image plus its clinical notes.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        text_input: Free-text clinical information; ``None`` is treated as
            an empty string.

    Returns:
        ``"Malignant"`` when the highest logit is at index 1, otherwise
        ``"Benign"``.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # An empty image input reaches the callback as None. Fail with a
        # readable message instead of an opaque error from the processor.
        raise gr.Error("Please upload a skin lesion image before submitting.")

    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

    # Padding to a fixed 256 tokens is kept as-is.
    # NOTE(review): presumably this matches the training setup; confirm.
    encoding = tokenizer(
        text_input or "",
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pixel_values=pixel_values,
        )

    prediction = logits.argmax(dim=1)
    return "Malignant" if prediction.item() == 1 else "Benign"
# Input widgets, in the positional order predict() expects: image, then text.
_lesion_image_input = gr.Image(type="pil", label="Upload Skin Lesion Image")
_clinical_info_input = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")

# Wire predict() to the two inputs and a plain-text output.
iface = gr.Interface(
    fn=predict,
    inputs=[_lesion_image_input, _clinical_info_input],
    outputs="text",
    title="Skin Lesion Classification Demo",
    description="This model classifies skin lesions as benign or malignant based on an image and clinical information.",
)

# NOTE(review): launch() runs at import time as well as when executed as a
# script; consider an `if __name__ == "__main__":` guard if this module is
# ever imported.
iface.launch()