# medaid-simple / app.py
# Author: mrmuminov — commit d3c0464 ("Create app.py", verified), 1.61 kB
import torch
import numpy as np
import pydicom
import gradio as gr
from torchvision import transforms
from PIL import Image
# Load the PyTorch model once at import time; every request reuses it.
# map_location="cpu" lets a checkpoint that was saved on a GPU load on a
# CPU-only host, and keeps the model on the same device as the CPU tensors
# that preprocess_dicom produces.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code,
# so only load checkpoints you trust. Since PyTorch 2.6, torch.load defaults
# to weights_only=True, which rejects whole-model pickles like this one. Either
# pass weights_only=False (trusted files only) or save/load a state_dict.
MODEL_PATH = 'your_model.pth'  # Replace with your model path
model = torch.load(MODEL_PATH, map_location='cpu')
model.eval()  # Inference mode: disables dropout, uses running batch-norm stats
def _scale_to_uint8(pixels):
    """Min-max scale an array into the 0-255 uint8 range.

    A constant array (max == min) maps to all zeros rather than dividing by
    zero, which would produce NaNs and an undefined uint8 cast.
    """
    pixels = np.asarray(pixels, dtype=np.float64)
    lo = pixels.min()
    span = pixels.max() - lo
    if span == 0:
        return np.zeros(pixels.shape, dtype=np.uint8)
    return ((pixels - lo) / span * 255).astype(np.uint8)


# Define a function to preprocess the DICOM
def preprocess_dicom(dicom_path):
    """Read a DICOM file and convert its pixel data into a model input tensor.

    Args:
        dicom_path: Path to the DICOM file, passed to ``pydicom.dcmread``.

    Returns:
        A float tensor of shape ``(1, C, 224, 224)`` with values in [0, 1].
        The leading 1 is the batch dimension. ``C`` is 1 for a 2-D grayscale
        pixel array and 3 for RGB data. TODO: confirm the channel count the
        model expects.
    """
    dicom = pydicom.dcmread(dicom_path)
    # NOTE(review): pixel_array is used as-is. Multi-frame (3-D) data and
    # MONOCHROME1 (inverted) images are not handled specially; confirm that
    # inputs are single-frame slices.
    image = Image.fromarray(_scale_to_uint8(dicom.pixel_array))

    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to model's input size
        transforms.ToTensor(),          # uint8 PIL image -> float tensor in [0, 1]
    ])
    return transform(image).unsqueeze(0)  # Add batch dimension
# Prediction function
def predict_dicom(dicom_file):
    """Run the model on an uploaded DICOM file and return the output as an image.

    Args:
        dicom_file: The value Gradio passes for the file input. This is either
            a tempfile-like object exposing ``.name`` (the Gradio 3 ``File``
            default) or a plain path string (the Gradio 4+ ``File`` default).

    Returns:
        A ``PIL.Image`` built by min-max scaling the squeezed model output to
        0-255. A constant output becomes an all-black image.

    Raises:
        ValueError: If no file was uploaded (``dicom_file`` is None).
    """
    if dicom_file is None:
        raise ValueError("No DICOM file uploaded")
    # Accept both a bare filepath string and a file object carrying .name.
    dicom_path = dicom_file if isinstance(dicom_file, str) else dicom_file.name
    input_tensor = preprocess_dicom(dicom_path)

    # Inference
    with torch.no_grad():
        output = model(input_tensor)

    # Convert output tensor to image (dummy example, replace as needed).
    # NOTE(review): this assumes the model returns an image-like map that
    # squeezes to 2-D; a classifier's logits vector will not convert.
    # .cpu() keeps the conversion working if the model runs on a GPU.
    output_image = output.detach().cpu().squeeze().numpy().astype(np.float64)
    lo, hi = output_image.min(), output_image.max()
    if hi > lo:
        output_image = (output_image - lo) / (hi - lo) * 255
    else:
        # Constant (or NaN) output: avoid 0/0 and return a black image.
        output_image = np.zeros_like(output_image)
    return Image.fromarray(output_image.astype(np.uint8))
# Create Gradio interface: one DICOM upload in, one image out.
# gr.File replaces gr.inputs.File, which was deprecated in Gradio 3 and
# removed in Gradio 4; gr.File exists in both versions. In Gradio 4+,
# gr.File hands fn a filepath string by default rather than a file object.
interface = gr.Interface(
    fn=predict_dicom,
    inputs=gr.File(label="Upload DICOM File"),
    outputs="image",
    title="DICOM Image Prediction",
)
# Launch the Gradio app
interface.launch()