"""Streamlit app that generates a text report for an uploaded X-ray image."""

import os

import torch
import spacy
import config
from utils import (
    load_dataset,
    get_model_instance,
    load_checkpoint,
    can_load_checkpoint,
    normalize_text,
)
from PIL import Image
import torchvision.transforms as transforms
import streamlit as st

# Name of the spaCy pipeline package required at runtime.
SPACY_MODEL = "en_core_web_sm"


@st.cache_resource
def _ensure_spacy_model(name=SPACY_MODEL):
    """Download the spaCy pipeline *name* unless it is already installed.

    Streamlit re-executes this script from top to bottom on every user
    interaction, so an unconditional module-level download would run on
    each rerun. Caching this helper and checking for an installed package
    first keeps the download to at most once.

    Args:
        name: spaCy pipeline package name.

    Returns:
        The package name that was checked/installed.
    """
    if not spacy.util.is_package(name):
        spacy.cli.download(name)
    return name


# Download Spacy model (only once during runtime)
_ensure_spacy_model()

# Define device
DEVICE = 'cpu'

# Define image transformations
TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),  # Replace with your model's expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


@st.cache_resource
def load_model():
    """Build the model from the dataset vocabulary and load its checkpoint.

    Progress is reported through Streamlit messages. When no checkpoint is
    available a warning is shown and the freshly initialised model is used.
    If loading the checkpoint raises ``RuntimeError`` the error is displayed
    and the script run is halted via ``st.stop()``.

    Returns:
        The model instance, switched to evaluation mode.
    """
    st.write("Loading dataset and vocabulary...")
    dataset = load_dataset()
    vocabulary = dataset.vocab

    st.write("Initializing the model...")
    model = get_model_instance(vocabulary)

    if can_load_checkpoint():
        st.write("Loading checkpoint...")
        try:
            load_checkpoint(model)
        except RuntimeError as e:
            st.error(f"Error loading checkpoint: {e}")
            st.stop()
    else:
        st.warning("No checkpoint found, starting with untrained model.")

    model.eval()  # Set the model to evaluation mode
    st.write("Model is ready for inference.")
    return model


def preprocess_image(image_path):
    """Load an image and turn it into a model-ready tensor.

    The image is opened with PIL, converted to RGB, passed through
    ``TRANSFORMS`` (resize to 224x224, tensor conversion, per-channel
    normalisation) and given a leading batch dimension of size 1.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The transformed image tensor, moved to ``DEVICE``. On any failure
        the error is shown in the UI and the script run is halted via
        ``st.stop()`` instead of returning.
    """
    st.write(f"Preprocessing image: {image_path}")
    try:
        image = Image.open(image_path).convert("RGB")
        image = TRANSFORMS(image).unsqueeze(0)
        return image.to(DEVICE)
    except Exception as e:
        # UI boundary: surface the problem to the user and end this run.
        st.error(f"Error preprocessing image: {e}")
        st.stop()


def generate_report(model, image):
    """Generate a report for a preprocessed image.

    Calls ``model.generate_caption(image, max_length=25)`` with gradient
    tracking disabled and joins the returned tokens with single spaces.

    Args:
        model: Object exposing ``generate_caption(image, max_length=...)``.
        image: Input passed straight to ``generate_caption``.

    Returns:
        The generated report as one string. On any failure the error is
        shown in the UI and the script run is halted via ``st.stop()``
        instead of returning.
    """
    st.write("Generating report...")
    try:
        with torch.no_grad():
            output = model.generate_caption(image, max_length=25)
        report = " ".join(output)
        st.write(f"Generated report: {report}")
        return report
    except Exception as e:
        # UI boundary: surface the problem to the user and end this run.
        st.error(f"Error generating report: {e}")
        st.stop()


# Streamlit App
st.title("Medical Image Report Generator")
st.write("Upload an X-ray image to generate a report.")

# File uploader
uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # Ensure 'temp' directory exists
    os.makedirs("temp", exist_ok=True)

    # Save uploaded file to disk. The file name comes from the client, so
    # keep only its final path component to make sure the write cannot
    # land outside the 'temp' directory.
    safe_name = os.path.basename(uploaded_file.name) or "upload"
    image_path = os.path.join("temp", safe_name)
    with open(image_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Load the model
    model = load_model()

    # Preprocess and generate the report
    image = preprocess_image(image_path)
    report = generate_report(model, image)

    # Display the image and the report
    st.image(image_path, caption="Uploaded Image", use_column_width=True)
    st.write("Generated Report:")
    st.write(report)