Spaces:
Sleeping
Sleeping
File size: 2,925 Bytes
0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e 0faaa54 e9e0f3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import os
import torch
import config
import streamlit as st
from utils import (
load_dataset,
get_model_instance,
load_checkpoint,
can_load_checkpoint,
normalize_text,
)
from PIL import Image
import torchvision.transforms as transforms
# All inference runs on CPU in this app.
DEVICE = 'cpu'

# ImageNet normalization statistics used by most pretrained backbones.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize to the model's expected input size,
# convert to a float tensor, then normalize per channel.
TRANSFORMS = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # Replace with your model's expected input size
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
def load_model():
    """
    Build the captioning model from the dataset vocabulary and, when a
    saved checkpoint exists, restore its weights.

    Returns:
        The model, switched to evaluation mode.
    """
    st.write("Loading dataset and vocabulary...")
    dataset = load_dataset()
    vocabulary = dataset.vocab  # the dataset exposes its vocabulary as `.vocab`

    st.write("Initializing the model...")
    model = get_model_instance(vocabulary)

    if not can_load_checkpoint():
        st.write("No checkpoint found, starting with untrained model.")
    else:
        st.write("Loading checkpoint...")
        load_checkpoint(model)

    model.eval()  # inference only: disable dropout / batch-norm updates
    st.write("Model is ready for inference.")
    return model
def preprocess_image(image_path):
    """
    Load an image from disk and prepare it for the model.

    Args:
        image_path: Path to an image file on disk.

    Returns:
        A float tensor of shape (1, 3, 224, 224) on DEVICE (batch of one).
    """
    st.write(f"Preprocessing image: {image_path}")
    # Use a context manager so the underlying file handle is always closed;
    # PIL opens files lazily and otherwise keeps the descriptor open.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")  # ensure 3-channel RGB regardless of source format
        image = TRANSFORMS(rgb).unsqueeze(0)  # add batch dimension
    return image.to(DEVICE)
def generate_report(model, image_path):
    """
    Run the model on a single image and return the generated report text.

    Args:
        model: A model exposing a `generate_caption(image, max_length)` method.
        image_path: Path to the input image on disk.

    Returns:
        The generated report as a single space-joined string.
    """
    image = preprocess_image(image_path)
    st.write("Generating report...")

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        tokens = model.generate_caption(image, max_length=25)

    report = " ".join(tokens)
    st.write(f"Generated report: {report}")
    return report
# Streamlit app
def main():
    """
    Streamlit entry point: accept a chest X-ray upload, run the captioning
    model, display the generated report, and clean up the temp file.
    """
    st.title("Chest X-Ray Report Generator")
    st.write("Upload a Chest X-Ray image to generate a medical report.")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
        st.write("")

        # Persist the upload to disk because the pipeline reads from a path.
        image_path = "./temp_image.png"
        with open(image_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.write("Image uploaded successfully.")

        try:
            model = load_model()
            report = generate_report(model, image_path)
            st.write("### Generated Report:")
            st.write(report)
        finally:
            # Always remove the temp file, even if model loading or
            # inference raises — the original leaked it on error.
            if os.path.exists(image_path):
                os.remove(image_path)


if __name__ == "__main__":
    main()
|