Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import spacy | |
import config | |
from utils import ( | |
load_dataset, | |
get_model_instance, | |
load_checkpoint, | |
can_load_checkpoint, | |
normalize_text, | |
) | |
from PIL import Image | |
import torchvision.transforms as transforms | |
import streamlit as st | |
# Make sure the spaCy English pipeline is installed. Streamlit re-executes this
# whole script on every user interaction, so check first and only download the
# model when it is missing, instead of re-downloading it on every rerun.
SPACY_MODEL = "en_core_web_sm"
if not spacy.util.is_package(SPACY_MODEL):
    spacy.cli.download(SPACY_MODEL)

# Inference device; preprocess_image moves input tensors here.
DEVICE = 'cpu'

# Image transformations applied before feeding an image to the model:
# resize, convert to a (C, H, W) float tensor, then normalize per channel
# (the values below are the commonly used ImageNet mean/std statistics).
TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),  # Replace with your model's expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def load_model():
    """
    Build the model and prepare it for inference.

    Loads the dataset to obtain its vocabulary, instantiates the model from
    that vocabulary, restores weights from a checkpoint when
    ``can_load_checkpoint()`` reports one is available, moves the model to
    ``DEVICE`` and switches it to evaluation mode. Progress messages are
    written to the Streamlit page.

    Returns:
        The model object returned by ``get_model_instance``, ready for
        inference.

    If loading an existing checkpoint raises ``RuntimeError``, the error is
    shown in the page and the script run is halted with ``st.stop()``.
    """
    st.write("Loading dataset and vocabulary...")
    dataset = load_dataset()
    vocabulary = dataset.vocab
    st.write("Initializing the model...")
    model = get_model_instance(vocabulary)
    if can_load_checkpoint():
        st.write("Loading checkpoint...")
        try:
            load_checkpoint(model)
        except RuntimeError as e:
            st.error(f"Error loading checkpoint: {e}")
            st.stop()
    else:
        st.warning("No checkpoint found, starting with untrained model.")
    # preprocess_image places input tensors on DEVICE; put the model there as
    # well so inputs and weights always live on the same device.
    # NOTE(review): assumes the model is a torch.nn.Module (it exposes .eval()).
    model.to(DEVICE)
    model.eval()  # Set the model to evaluation mode
    st.write("Model is ready for inference.")
    return model
def preprocess_image(image_path):
    """
    Load an image and turn it into a single-image batch tensor for the model.

    The image is converted to RGB, passed through ``TRANSFORMS`` (resize to
    224x224, tensor conversion, normalization) and given a leading batch
    dimension of size 1, then moved to ``DEVICE``.

    Args:
        image_path: Location of the image (anything accepted by
            ``PIL.Image.open``; the app passes a filesystem path).

    Returns:
        A tensor of shape ``(1, 3, 224, 224)`` on ``DEVICE``.

    On any failure the error is shown in the Streamlit page and the script
    run is halted with ``st.stop()``.
    """
    st.write(f"Preprocessing image: {image_path}")
    try:
        # Open via a context manager so the underlying file handle is closed.
        # convert() returns a new, fully loaded image, so it stays usable
        # after the source file is closed.
        with Image.open(image_path) as source_image:
            rgb_image = source_image.convert("RGB")
        batch = TRANSFORMS(rgb_image).unsqueeze(0)
        return batch.to(DEVICE)
    except Exception as e:
        st.error(f"Error preprocessing image: {e}")
        st.stop()
def generate_report(model, image, *, max_length=25):
    """
    Generate a textual report for a preprocessed image.

    Args:
        model: Object exposing ``generate_caption(image, max_length=...)``,
            e.g. the model returned by ``load_model``.
        image: Preprocessed image batch, e.g. from ``preprocess_image``.
        max_length: Forwarded unchanged as ``max_length`` to
            ``model.generate_caption``. Defaults to 25, the value that was
            previously hard-coded.

    Returns:
        The items returned by ``generate_caption`` (presumably word tokens),
        joined with single spaces.

    On any failure the error is shown in the Streamlit page and the script
    run is halted with ``st.stop()``.
    """
    st.write("Generating report...")
    try:
        # Inference only: no gradient tracking needed.
        with torch.no_grad():
            output = model.generate_caption(image, max_length=max_length)
        report = " ".join(output)
        st.write(f"Generated report: {report}")
        return report
    except Exception as e:
        st.error(f"Error generating report: {e}")
        st.stop()
# Streamlit App
st.title("Medical Image Report Generator")
st.write("Upload an X-ray image to generate a report.")

# File uploader
uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # Ensure 'temp' directory exists
    os.makedirs("temp", exist_ok=True)

    # The filename is supplied by the client and must not be trusted: strip
    # any directory components so the write stays inside 'temp' (otherwise
    # "../../app.py" or an absolute "/etc/x" would escape it). Fall back to a
    # fixed name if nothing is left after stripping.
    safe_name = os.path.basename(uploaded_file.name) or "upload"
    image_path = os.path.join("temp", safe_name)

    # Save uploaded file to disk
    with open(image_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Load the model
    model = load_model()

    # Preprocess and generate the report
    image = preprocess_image(image_path)
    report = generate_report(model, image)

    # Display the image and the report
    st.image(image_path, caption="Uploaded Image", use_column_width=True)
    st.write("Generated Report:")
    st.write(report)