File size: 2,558 Bytes
0faaa54
 
 
a8175a2
98b1d34
 
0faaa54
 
 
 
 
 
 
 
 
a8175a2
0faaa54
 
 
 
e9e0f3e
0faaa54
 
 
 
 
 
a8175a2
0faaa54
 
 
 
e9e0f3e
a8175a2
 
0faaa54
e9e0f3e
a8175a2
0faaa54
 
e9e0f3e
0faaa54
 
e9e0f3e
0faaa54
 
e9e0f3e
0faaa54
 
a8175a2
0faaa54
 
 
 
e9e0f3e
a8175a2
 
0faaa54
 
a8175a2
 
0faaa54
 
 
e9e0f3e
0faaa54
 
 
 
e9e0f3e
0faaa54
 
 
a8175a2
 
 
0faaa54
a8175a2
 
0faaa54
a8175a2
 
 
 
 
0faaa54
a8175a2
 
e9e0f3e
a8175a2
 
 
e9e0f3e
a8175a2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import torch
import config
import spacy
# Streamlit re-executes this script from the top on every user interaction,
# so only fetch the spaCy model package when it is not already installed
# instead of re-downloading it on each rerun.
if not spacy.util.is_package("en_core_web_sm"):
    spacy.cli.download("en_core_web_sm")

from utils import (
    load_dataset,
    get_model_instance,
    load_checkpoint,
    can_load_checkpoint,
    normalize_text,
)
from PIL import Image
import torchvision.transforms as transforms
import streamlit as st

# Device that input tensors are moved to before inference.
DEVICE = "cpu"

# Per-channel normalisation statistics applied after conversion to a tensor.
# NOTE(review): these values match the widely used ImageNet statistics;
# presumably the image encoder expects them -- confirm against the model.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize -> tensor -> normalise.
TRANSFORMS = transforms.Compose(
    [
        # Replace with your model's expected input size
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)


def load_model():
    """Build the model, restore weights if available, and switch to eval mode.

    The vocabulary is read from the ``vocab`` attribute of the dataset
    returned by ``load_dataset`` and passed to ``get_model_instance``.
    When ``can_load_checkpoint`` reports a checkpoint, ``load_checkpoint``
    is applied to the model; otherwise the freshly built model is used.
    Progress messages are written to the Streamlit page at each step.

    Returns:
        The model produced by ``get_model_instance``, after ``eval()``.
    """
    st.write("Loading dataset and vocabulary...")
    vocab = load_dataset().vocab

    st.write("Initializing the model...")
    net = get_model_instance(vocab)

    checkpoint_available = can_load_checkpoint()
    if not checkpoint_available:
        st.write("No checkpoint found, starting with untrained model.")
    else:
        st.write("Loading checkpoint...")
        load_checkpoint(net)

    # Inference only: put the model in evaluation mode before handing it out.
    net.eval()
    st.write("Model is ready for inference.")
    return net


def preprocess_image(image_path):
    """Load an image from disk and turn it into a single-item model batch.

    The image is opened with PIL, converted to RGB, run through the
    module-level ``TRANSFORMS`` pipeline, given a leading batch dimension
    and moved to ``DEVICE``. A progress message is written to the page.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The transformed image tensor with a batch dimension of size 1.
    """
    st.write(f"Preprocessing image: {image_path}")
    rgb_image = Image.open(image_path).convert("RGB")
    pixel_tensor = TRANSFORMS(rgb_image)
    # Add a leading batch axis for the single image.
    batch = pixel_tensor.unsqueeze(0)
    return batch.to(DEVICE)


def generate_report(model, image):
    """Run caption generation on ``image`` and return the report text.

    Calls ``model.generate_caption`` with ``max_length=25`` under
    ``torch.no_grad()`` and joins the returned items with single spaces.
    Progress and the final report are also written to the Streamlit page.

    Args:
        model: Object exposing ``generate_caption(image, max_length=...)``.
        image: Preprocessed input passed straight through to the model.

    Returns:
        The generated report as one space-separated string.
    """
    st.write("Generating report...")

    # Gradients are not needed for inference.
    with torch.no_grad():
        tokens = model.generate_caption(image, max_length=25)

    report = " ".join(tokens)
    st.write(f"Generated report: {report}")
    return report


# Streamlit App
st.title("Medical Image Report Generator")
st.write("Upload an X-ray image to generate a report.")

# File uploader
uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # Save uploaded file to disk. The target directory is created on demand
    # so the first upload does not fail with FileNotFoundError.
    upload_dir = "temp"
    os.makedirs(upload_dir, exist_ok=True)

    # The file name is supplied by the client; keep only its final component
    # so the upload cannot be written outside ``upload_dir``.
    safe_name = os.path.basename(uploaded_file.name)
    image_path = os.path.join(upload_dir, safe_name)
    with open(image_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Load the model
    model = load_model()

    # Preprocess and generate the report
    image = preprocess_image(image_path)
    report = generate_report(model, image)

    # Display the image and the report
    st.image(image_path, caption="Uploaded Image", use_column_width=True)
    st.write("Generated Report:")
    st.write(report)