File size: 4,271 Bytes
aeae044
 
1daa088
5d22509
1daa088
 
 
aeae044
 
 
 
 
1daa088
aeae044
1daa088
 
aeae044
 
 
 
6c28e99
aeae044
1daa088
aeae044
 
 
 
05d6807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aeae044
 
 
 
58af0b8
1daa088
 
aeae044
58af0b8
1daa088
aeae044
 
58af0b8
 
 
 
 
 
 
 
 
 
 
 
 
aeae044
58af0b8
aeae044
 
58af0b8
aeae044
 
05d6807
aeae044
 
 
 
6c28e99
1daa088
 
aeae044
 
1daa088
aeae044
 
 
6c28e99
 
1daa088
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d22509
1daa088
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import torch
import config
import streamlit as st
import spacy

# Download the spaCy English model only when it is not already installed;
# the original unconditional download re-fetched the model on every app start.
try:
    spacy.load("en_core_web_sm")
except OSError:
    spacy.cli.download("en_core_web_sm")

from utils import (
    load_dataset,
    get_model_instance,
    load_checkpoint,
    can_load_checkpoint,
    normalize_text,
)
from PIL import Image
import torchvision.transforms as transforms

# Inference device; CPU keeps the app runnable on machines without a GPU.
DEVICE = 'cpu'

# Image preprocessing pipeline: resize, convert to tensor, then normalize
# with ImageNet mean/std statistics.
TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),  # assumes the model expects 224x224 input — TODO confirm
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def load_model():
    """
    Load the captioning model, restoring weights from a checkpoint if available.

    Returns:
        The model in eval mode, ready for inference. If no checkpoint exists
        or loading fails with a shape mismatch, the untrained model is returned.
    """
    st.write("Loading dataset and vocabulary...")
    dataset = load_dataset()  # Load dataset to access vocabulary
    vocabulary = dataset.vocab  # Assuming 'vocab' is an attribute of the dataset

    st.write("Initializing the model...")
    model = get_model_instance(vocabulary)  # Initialize the model

    if can_load_checkpoint():
        st.write("Loading checkpoint...")
        # map_location ensures GPU-trained checkpoints load on CPU.
        checkpoint = torch.load(config.CHECKPOINT_FILE, map_location=DEVICE)

        # strict=False tolerates missing/unexpected keys so a checkpoint from a
        # slightly different architecture can still be partially restored.
        try:
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            st.write("Checkpoint loaded successfully.")
        except RuntimeError as e:
            # Shape mismatches on matching keys still raise; fall back to the
            # untrained model rather than crashing the app.
            st.write(f"Error loading checkpoint: {e}")
            st.write("Starting with untrained model.")
    else:
        st.write("No checkpoint found, starting with untrained model.")

    model.eval()  # Set the model to evaluation mode
    st.write("Model is ready for inference.")
    return model


def preprocess_image(image_path):
    """
    Load an image from disk and convert it into a normalized model input batch.
    """
    st.write(f"Preprocessing image: {image_path}")
    rgb_image = Image.open(image_path).convert("RGB")  # force 3-channel RGB
    tensor = TRANSFORMS(rgb_image)
    batch = tensor.unsqueeze(0)  # prepend batch dimension -> (1, C, H, W)
    return batch.to(DEVICE)

def generate_report(model, image_path):
    """
    Run the model on a single image and return the generated text report.
    """
    batch = preprocess_image(image_path)

    st.write("Generating report...")
    with torch.no_grad():  # inference only: skip gradient bookkeeping
        # The model is expected to expose a 'generate_caption' method that
        # yields a token sequence — presumably a list of strings; verify
        # against the model class.
        tokens = model.generate_caption(batch, max_length=25)
        report = " ".join(tokens)

    st.write(f"Generated report: {report}")
    return report

# Streamlit app
def main():
    """
    Streamlit entry point: accept an uploaded chest X-ray image and display
    a model-generated medical report.
    """
    st.title("Chest X-Ray Report Generator")
    st.write("Upload a Chest X-Ray image to generate a medical report.")

    # Upload image
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
        st.write("")

        # Save the uploaded file temporarily so it can be opened by path
        image_path = "./temp_image.png"
        with open(image_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.write("Image uploaded successfully.")

        try:
            # Load the model
            model = load_model()

            # Generate report
            report = generate_report(model, image_path)
            st.write("### Generated Report:")
            st.write(report)
        finally:
            # Always remove the temporary file, even if loading or
            # inference raises (the original leaked it on failure).
            os.remove(image_path)

if __name__ == "__main__":
    main()