# Streamlit app: Chest X-Ray report generation.
# (Hugging Face Spaces file-viewer residue — commit hashes and line-number
# gutter — removed from this header.)
import os
import torch
import config
import streamlit as st
import spacy
spacy.cli.download("en_core_web_sm")
from utils import (
load_dataset,
get_model_instance,
load_checkpoint,
can_load_checkpoint,
normalize_text,
)
from PIL import Image
import torchvision.transforms as transforms
# Inference device. CPU is hard-coded; the app makes no GPU assumptions.
DEVICE = 'cpu'

# Preprocessing pipeline applied to every uploaded image before inference.
TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),  # Replace with your model's expected input size
    transforms.ToTensor(),
    # ImageNet channel statistics — presumably the backbone was pretrained on
    # ImageNet; TODO(review) confirm against the model definition.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def load_model():
    """
    Load the captioning model, its vocabulary, and (if available) a checkpoint.

    Returns:
        The model in eval mode, ready for inference. If no checkpoint exists,
        or loading it fails, the model is returned untrained.
    """
    st.write("Loading dataset and vocabulary...")
    dataset = load_dataset()  # Dataset is only needed here for its vocabulary.
    vocabulary = dataset.vocab  # Assuming 'vocab' is an attribute of the dataset

    st.write("Initializing the model...")
    model = get_model_instance(vocabulary)

    if can_load_checkpoint():
        st.write("Loading checkpoint...")
        checkpoint = torch.load(config.CHECKPOINT_FILE, map_location=DEVICE)
        # strict=False tolerates architecture drift: layers whose names/shapes
        # no longer match are skipped instead of raising.
        # NOTE(review): this can silently mask a genuinely incompatible
        # checkpoint — switch to strict=True once the architecture is stable.
        try:
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            st.write("Checkpoint loaded successfully.")
        except RuntimeError as e:
            st.write(f"Error loading checkpoint: {e}")
            st.write("Starting with untrained model.")
    else:
        st.write("No checkpoint found, starting with untrained model.")

    model.eval()  # Disable dropout / batch-norm updates for inference.
    st.write("Model is ready for inference.")
    return model
def preprocess_image(image_path):
    """
    Read an image from disk and convert it to a batched tensor on DEVICE.
    """
    st.write(f"Preprocessing image: {image_path}")
    rgb_image = Image.open(image_path).convert("RGB")  # model expects 3 channels
    batch = TRANSFORMS(rgb_image).unsqueeze(0)  # prepend batch dim -> (1, C, H, W)
    return batch.to(DEVICE)
def generate_report(model, image_path):
    """
    Run the model on a single image and return the generated report text.
    """
    input_tensor = preprocess_image(image_path)
    st.write("Generating report...")
    with torch.no_grad():
        # 'generate_caption' is a project-defined method on the model that
        # returns a sequence of word tokens.
        tokens = model.generate_caption(input_tensor, max_length=25)
    report = " ".join(tokens)
    st.write(f"Generated report: {report}")
    return report
# Streamlit app
def main():
    """
    Streamlit entry point: upload a chest X-ray, run the model, show the report.
    """
    st.title("Chest X-Ray Report Generator")
    st.write("Upload a Chest X-Ray image to generate a medical report.")

    # Upload image
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
        st.write("")

        # Persist the upload so downstream code can work from a file path.
        image_path = "./temp_image.png"
        with open(image_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.write("Image uploaded successfully.")

        try:
            # NOTE(review): the model is rebuilt on every Streamlit rerun;
            # decorating load_model with @st.cache_resource would avoid
            # repeated dataset/checkpoint loading.
            model = load_model()
            report = generate_report(model, image_path)
            st.write("### Generated Report:")
            st.write(report)
        finally:
            # Remove the temp file even if model loading or inference fails
            # (the original leaked it on any exception above).
            os.remove(image_path)


if __name__ == "__main__":
    main()