File size: 3,076 Bytes
0faaa54
 
a8175a2
e90f280
0faaa54
 
 
 
 
 
 
 
 
a8175a2
0faaa54
e90f280
 
 
0faaa54
 
 
e9e0f3e
0faaa54
 
 
 
 
 
a8175a2
e90f280
0faaa54
 
 
 
e9e0f3e
a8175a2
 
0faaa54
e9e0f3e
a8175a2
0faaa54
 
e9e0f3e
e90f280
 
 
 
 
0faaa54
e90f280
0faaa54
 
e9e0f3e
0faaa54
 
a8175a2
0faaa54
 
 
 
e9e0f3e
e90f280
 
 
 
 
 
 
0faaa54
a8175a2
 
0faaa54
 
 
e9e0f3e
e90f280
 
 
 
 
 
 
 
 
0faaa54
 
a8175a2
 
 
0faaa54
a8175a2
 
0faaa54
a8175a2
e90f280
 
 
a8175a2
 
 
 
0faaa54
a8175a2
 
e9e0f3e
a8175a2
 
 
e9e0f3e
a8175a2
 
 
e90f280
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import torch
import spacy
import config
from utils import (
    load_dataset,
    get_model_instance,
    load_checkpoint,
    can_load_checkpoint,
    normalize_text,
)
from PIL import Image
import torchvision.transforms as transforms
import streamlit as st

# Ensure the small English spaCy pipeline is installed. Streamlit re-executes
# this whole script on every widget interaction, so calling
# ``spacy.cli.download`` unconditionally would repeat the download on every
# rerun; only download when the package is not already installed.
SPACY_MODEL = "en_core_web_sm"
if not spacy.util.is_package(SPACY_MODEL):
    spacy.cli.download(SPACY_MODEL)

# Device that preprocessed input tensors are moved to (see preprocess_image).
DEVICE = 'cpu'

# Image transformations applied before inference: resize to a fixed 224x224,
# convert to a float tensor, then normalize each channel. The mean/std values
# are the widely used ImageNet channel statistics — presumably chosen to match
# the model's backbone; confirm against the training pipeline.
TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),  # Replace with your model's expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


@st.cache_resource
def load_model():
    """
    Build the report-generation model and restore its trained weights.

    Decorated with ``st.cache_resource``, so Streamlit reuses the returned
    model object across script reruns instead of rebuilding it each time.

    Steps:
        1. Load the dataset and take its vocabulary.
        2. Instantiate the model for that vocabulary.
        3. Restore the checkpoint if one is available; otherwise warn and
           keep the freshly initialized (untrained) weights.
        4. Switch the model to evaluation mode.

    Returns:
        The model, in eval mode.

    A ``RuntimeError`` raised while loading the checkpoint is shown in the app
    and the script run is halted with ``st.stop()``.
    """
    st.write("Loading dataset and vocabulary...")
    vocab = load_dataset().vocab

    st.write("Initializing the model...")
    net = get_model_instance(vocab)

    if not can_load_checkpoint():
        st.warning("No checkpoint found, starting with untrained model.")
    else:
        st.write("Loading checkpoint...")
        try:
            load_checkpoint(net)
        except RuntimeError as err:
            st.error(f"Error loading checkpoint: {err}")
            st.stop()

    # Inference-only usage: put layers such as dropout/batch-norm in eval mode.
    net.eval()
    st.write("Model is ready for inference.")
    return net


def preprocess_image(image_path):
    """
    Load an image and turn it into a single-item batch tensor for the model.

    Args:
        image_path: Path to the image file (``PIL.Image.open`` also accepts a
            binary file-like object).

    Returns:
        The RGB image passed through ``TRANSFORMS`` with a leading batch
        dimension added by ``unsqueeze(0)`` — shape ``(1, 3, 224, 224)`` with
        the current transforms — moved to ``DEVICE``.

    Any failure is reported in the app and the script run is halted with
    ``st.stop()``.
    """
    st.write(f"Preprocessing image: {image_path}")
    try:
        # ``with`` guarantees the underlying file handle is released even when
        # decoding/conversion fails partway (e.g. a truncated upload).
        # ``convert`` returns a new, fully loaded image, so it stays usable
        # after the source image is closed.
        with Image.open(image_path) as source:
            rgb = source.convert("RGB")
        return TRANSFORMS(rgb).unsqueeze(0).to(DEVICE)
    except Exception as e:  # UI boundary: surface any failure to the user.
        st.error(f"Error preprocessing image: {e}")
        st.stop()


def generate_report(model, image, max_length=25):
    """
    Run the model's caption generator on a preprocessed image.

    Args:
        model: Object exposing ``generate_caption(image, max_length=...)``,
            such as the model returned by ``load_model``.
        image: Input tensor as produced by ``preprocess_image``.
        max_length: Value forwarded as ``max_length`` to
            ``generate_caption``. Defaults to 25, the previously hard-coded
            value, so existing callers are unaffected.

    Returns:
        The items returned by ``generate_caption`` joined with single spaces.

    Any failure during generation is reported in the app and the script run
    is halted with ``st.stop()``.
    """
    st.write("Generating report...")
    try:
        # Inference only: disable autograd tracking.
        with torch.no_grad():
            output = model.generate_caption(image, max_length=max_length)
        report = " ".join(output)
    except Exception as e:  # UI boundary: surface any failure to the user.
        st.error(f"Error generating report: {e}")
        st.stop()
        # st.stop() normally ends the run by raising; this return is only a
        # defensive fallback so ``report`` is never read unbound.
        return None
    st.write(f"Generated report: {report}")
    return report


# Streamlit App
st.title("Medical Image Report Generator")
st.write("Upload an X-ray image to generate a report.")

# File uploader
uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # Ensure 'temp' directory exists
    os.makedirs("temp", exist_ok=True)

    # The upload's name is supplied by the client, so strip any directory
    # components before building a path from it; otherwise a name such as
    # "../../x.png" could write outside of ``temp``. Fall back to a fixed
    # name when nothing usable remains.
    file_name = os.path.basename(uploaded_file.name)
    if file_name in ("", ".", ".."):
        file_name = "upload"
    image_path = os.path.join("temp", file_name)

    # Save uploaded file to disk
    with open(image_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    try:
        # Load the model
        model = load_model()

        # Preprocess and generate the report
        image = preprocess_image(image_path)
        report = generate_report(model, image)

        # Display the image and the report
        st.image(image_path, caption="Uploaded Image", use_column_width=True)
        st.write("Generated Report:")
        st.write(report)
    finally:
        # Don't let uploaded images pile up on disk. ``st.image`` reads the
        # file's bytes when it is called, so removing it afterwards is safe;
        # this also runs when ``st.stop()`` halts the run early.
        try:
            os.remove(image_path)
        except FileNotFoundError:
            pass