# (removed non-code scrape residue: Hugging Face Spaces page header,
#  commit-hash gutter, and line-number gutter — not part of the program)
import os
import torch
import spacy
import config
from utils import (
load_dataset,
get_model_instance,
load_checkpoint,
can_load_checkpoint,
normalize_text,
)
from PIL import Image
import torchvision.transforms as transforms
import streamlit as st
# Ensure the Spacy English model is available. Streamlit re-executes this
# script on every user interaction, so an unconditional download would run
# repeatedly; probe with spacy.load() and download only when missing.
try:
    spacy.load("en_core_web_sm")
except OSError:  # spacy raises OSError when the model package is not installed
    spacy.cli.download("en_core_web_sm")

# Inference device; input tensors are moved here before the forward pass.
DEVICE = 'cpu'

# Image preprocessing pipeline: resize to the model's expected input size,
# convert to a CHW float tensor, and normalize with ImageNet statistics.
TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),  # Replace with your model's expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
@st.cache_resource
def load_model():
    """
    Load the captioning model with its vocabulary and optional checkpoint.

    Cached by Streamlit (``st.cache_resource``) so the dataset/model are
    built once per server process rather than on every script rerun.

    Returns:
        The model, moved to ``DEVICE`` and set to eval mode.
    """
    st.write("Loading dataset and vocabulary...")
    dataset = load_dataset()
    vocabulary = dataset.vocab

    st.write("Initializing the model...")
    model = get_model_instance(vocabulary)

    if can_load_checkpoint():
        st.write("Loading checkpoint...")
        try:
            load_checkpoint(model)
        except RuntimeError as e:
            # Surface checkpoint/weight mismatches in the UI and halt the run.
            st.error(f"Error loading checkpoint: {e}")
            st.stop()
    else:
        st.warning("No checkpoint found, starting with untrained model.")

    # Bug fix: input images are moved to DEVICE (see preprocess_image) but the
    # model never was — a device mismatch the moment DEVICE is not 'cpu'.
    model = model.to(DEVICE)
    model.eval()  # Set the model to evaluation mode
    st.write("Model is ready for inference.")
    return model
def preprocess_image(image_path):
    """
    Read the image at *image_path* and return a normalized, batched
    tensor on ``DEVICE``, ready for the model's forward pass.
    """
    st.write(f"Preprocessing image: {image_path}")
    try:
        rgb_image = Image.open(image_path).convert("RGB")
        tensor = TRANSFORMS(rgb_image)
        return tensor.unsqueeze(0).to(DEVICE)
    except Exception as e:
        # Any failure (unreadable file, bad format, transform error) is
        # reported in the UI and stops this Streamlit run.
        st.error(f"Error preprocessing image: {e}")
        st.stop()
def generate_report(model, image):
    """
    Run the model on a preprocessed image tensor and return the generated
    report as a single space-joined string.
    """
    st.write("Generating report...")
    try:
        # Inference only — no autograd bookkeeping needed.
        with torch.no_grad():
            tokens = model.generate_caption(image, max_length=25)
        report = " ".join(tokens)
        st.write(f"Generated report: {report}")
        return report
    except Exception as e:
        # Report the failure in the UI and halt this Streamlit run.
        st.error(f"Error generating report: {e}")
        st.stop()
# --- Streamlit app entry point ---
st.title("Medical Image Report Generator")
st.write("Upload an X-ray image to generate a report.")

# Accept a single image upload from the user.
uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # Persist the upload under a local 'temp' directory so it can be
    # reopened by path for preprocessing and display.
    os.makedirs("temp", exist_ok=True)
    saved_path = os.path.join("temp", uploaded_file.name)
    with open(saved_path, "wb") as out_file:
        out_file.write(uploaded_file.getbuffer())

    # Cached model load, then preprocess + inference.
    model = load_model()
    input_tensor = preprocess_image(saved_path)
    report = generate_report(model, input_tensor)

    # Show the uploaded image alongside the generated text.
    st.image(saved_path, caption="Uploaded Image", use_column_width=True)
    st.write("Generated Report:")
    st.write(report)