import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
from PIL import Image
import gradio as gr
class VisionLanguageModel(nn.Module):
    """Late-fusion classifier: a ViT image encoder and a BERT text encoder,
    joined by a linear head over their concatenated pooled embeddings."""

    def __init__(self):
        super(VisionLanguageModel, self).__init__()
        self.vision_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.language_model = BertModel.from_pretrained('bert-base-uncased')
        # Linear head over the concatenation of both pooled embeddings
        self.classifier = nn.Linear(
            self.vision_model.config.hidden_size + self.language_model.config.hidden_size,
            2  # Number of classes: benign or malignant
        )

    def forward(self, input_ids, attention_mask, pixel_values):
        # Encode the image and take its pooled representation
        vision_outputs = self.vision_model(pixel_values=pixel_values)
        vision_pooled_output = vision_outputs.pooler_output
        # Encode the clinical text the same way
        language_outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        language_pooled_output = language_outputs.pooler_output
        # Concatenate along the feature dimension and classify
        combined_features = torch.cat(
            (vision_pooled_output, language_pooled_output),
            dim=1
        )
        logits = self.classifier(combined_features)
        return logits
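# Shape sketch (both encoders are base-size models with hidden_size=768):
#   pixel_values (B, 3, 224, 224) -> vision pooler_output   (B, 768)
#   input_ids    (B, L)           -> language pooler_output (B, 768)
#   concatenated (B, 1536)        -> classifier logits      (B, 2)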
# Load the fine-tuned checkpoint on CPU; weights_only=True restricts
# torch.load to tensors/primitives, so an untrusted file cannot run code.
model = VisionLanguageModel()
model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu'), weights_only=True))
model.eval()
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
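# The tokenizer and image processor must come from the same checkpoints as the
# encoders, so inference-time preprocessing mirrors what training used.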
def predict(image, text_input):
    # Preprocess the image into normalized 224x224 pixel values
    image = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Tokenize the clinical text, padding/truncating to a fixed length
    encoding = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pixel_values=image
        )
    # Take the argmax over the two class logits
    _, prediction = torch.max(outputs, dim=1)
    return "Malignant" if prediction.item() == 1 else "Benign"
# Define the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Skin Lesion Image"),
        gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")
    ],
    outputs="text",
    title="Skin Lesion Classification Demo",
    description="This model classifies skin lesions as benign or malignant based on an image and clinical information."
)

iface.launch()