import os
import torch
import config
import streamlit as st
from utils import (
    load_dataset,
    get_model_instance,
    load_checkpoint,
    can_load_checkpoint,
    normalize_text,
)
from PIL import Image
import torchvision.transforms as transforms

# Define device
DEVICE = 'cpu'  # Inference runs on CPU; switch to 'cuda' if a compatible GPU is available

# Define image transformations
TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),  # Replace with your model's expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])

def load_model():
    """
    Loads the model with the vocabulary and checkpoint.
    """
    st.write("Loading dataset and vocabulary...")
    dataset = load_dataset()  # Load dataset to access vocabulary
    vocabulary = dataset.vocab  # Assuming 'vocab' is an attribute of the dataset

    st.write("Initializing the model...")
    model = get_model_instance(vocabulary)  # Initialize the model

    if can_load_checkpoint():
        st.write("Loading checkpoint...")
        load_checkpoint(model)
    else:
        st.write("No checkpoint found, starting with untrained model.")

    model.eval()  # Set the model to evaluation mode
    st.write("Model is ready for inference.")
    return model
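
# Streamlit re-executes this script on every widget interaction, so load_model()
# is rebuilt for each uploaded image. The wrapper below is a minimal sketch that
# memoises the load with Streamlit's st.cache_resource, so the expensive setup
# runs only once per session; main() could call load_cached_model() instead of
# load_model() if that behaviour is wanted.
@st.cache_resource
def load_cached_model():
    return load_model()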

def preprocess_image(image_path):
    """
    Preprocess the input image for the model.
    """
    st.write(f"Preprocessing image: {image_path}")
    image = Image.open(image_path).convert("RGB")  # Ensure RGB format
    image = TRANSFORMS(image).unsqueeze(0)  # Add batch dimension
    return image.to(DEVICE)

def generate_report(model, image_path):
    """
    Generates a report for a given image using the model.
    """
    image = preprocess_image(image_path)

    st.write("Generating report...")
    with torch.no_grad():
        # Assuming the model has a 'generate_caption' method
        output = model.generate_caption(image, max_length=25)
        report = " ".join(output)

    st.write(f"Generated report: {report}")
    return report

# Streamlit app
def main():
    st.title("Chest X-Ray Report Generator")
    st.write("Upload a Chest X-Ray image to generate a medical report.")

    # Upload image
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
        st.write("")

        # Save the uploaded file temporarily
        image_path = "./temp_image.png"
        with open(image_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.write("Image uploaded successfully.")

        # Load the model
        model = load_model()

        # Generate report
        report = generate_report(model, image_path)
        st.write("### Generated Report:")
        st.write(report)

        # Clean up temporary file
        os.remove(image_path)

if __name__ == "__main__":
    main()
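
# To launch the app locally (assuming this file is saved as app.py and the
# config/utils modules are importable from the working directory), run:
#
#   streamlit run app.py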