adithiyyha committed on
Commit
e9e0f3e
·
verified ·
1 Parent(s): 0faaa54

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +37 -25
inference.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import torch
3
  import config
 
4
  from utils import (
5
  load_dataset,
6
  get_model_instance,
@@ -14,75 +15,86 @@ import torchvision.transforms as transforms
14
  # Define device
15
  DEVICE = 'cpu'
16
 
17
- # Define image transformations (adjust based on training setup)
18
  TRANSFORMS = transforms.Compose([
19
  transforms.Resize((224, 224)), # Replace with your model's expected input size
20
  transforms.ToTensor(),
21
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
22
  ])
23
 
24
-
25
def load_model():
    """
    Build the captioning model and restore weights from a checkpoint when one exists.

    Returns the model in eval mode, ready for inference on DEVICE.
    """
    print("Loading dataset and vocabulary...")
    # The dataset object carries the vocabulary the model was trained with.
    vocabulary = load_dataset().vocab  # assumes the dataset exposes a 'vocab' attribute — TODO confirm

    print("Initializing the model...")
    model = get_model_instance(vocabulary)

    # Restore trained weights if a checkpoint file is available.
    if can_load_checkpoint():
        print("Loading checkpoint...")
        load_checkpoint(model)
    else:
        print("No checkpoint found, starting with untrained model.")

    model.eval()  # inference mode: disables dropout / batch-norm updates
    print("Model is ready for inference.")
    return model
45
 
46
-
47
def preprocess_image(image_path):
    """
    Load the image at *image_path* and convert it into a normalized,
    batched tensor on DEVICE, ready to feed to the model.
    """
    print(f"Preprocessing image: {image_path}")
    rgb = Image.open(image_path).convert("RGB")  # force 3 channels regardless of source format
    batch = TRANSFORMS(rgb).unsqueeze(0)         # prepend batch dimension: (1, C, H, W)
    return batch.to(DEVICE)
55
 
56
-
57
def generate_report(model, image_path):
    """
    Run *model* on the image at *image_path* and return the generated
    report as a single space-joined string.
    """
    batch = preprocess_image(image_path)

    print("Generating report...")
    # No gradients needed for inference; saves memory and time.
    with torch.no_grad():
        tokens = model.generate_caption(batch, max_length=25)  # assumes model exposes generate_caption — TODO confirm
    report = " ".join(tokens)

    print(f"Generated report: {report}")
    return report
71
 
 
 
 
 
72
 
73
if __name__ == "__main__":
    # Checkpoint location comes from config; ensure CHECKPOINT_FILE is set there.
    CHECKPOINT_PATH = config.CHECKPOINT_FILE

    # Image to run inference on — replace with your own path.
    IMAGE_PATH = "./dataset/images/CXR1178_IM-0121-1001.png"

    model = load_model()

    # Guard clause: bail out with a message when the image is missing.
    if not os.path.exists(IMAGE_PATH):
        print(f"Image not found at path: {IMAGE_PATH}")
    else:
        report = generate_report(model, IMAGE_PATH)
        print("Final Report:", report)
 
 
 
 
 
 
 
 
1
  import os
2
  import torch
3
  import config
4
+ import streamlit as st
5
  from utils import (
6
  load_dataset,
7
  get_model_instance,
 
15
  # Define device
16
  DEVICE = 'cpu'
17
 
18
+ # Define image transformations
19
  TRANSFORMS = transforms.Compose([
20
  transforms.Resize((224, 224)), # Replace with your model's expected input size
21
  transforms.ToTensor(),
22
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
23
  ])
24
 
 
25
def load_model():
    """
    Build the captioning model and restore weights from a checkpoint when one
    exists, reporting progress to the Streamlit UI.

    Returns the model in eval mode, ready for inference on DEVICE.
    """
    st.write("Loading dataset and vocabulary...")
    # The dataset object carries the vocabulary the model was trained with.
    vocabulary = load_dataset().vocab  # assumes the dataset exposes a 'vocab' attribute — TODO confirm

    st.write("Initializing the model...")
    model = get_model_instance(vocabulary)

    # Restore trained weights if a checkpoint file is available.
    if can_load_checkpoint():
        st.write("Loading checkpoint...")
        load_checkpoint(model)
    else:
        st.write("No checkpoint found, starting with untrained model.")

    model.eval()  # inference mode: disables dropout / batch-norm updates
    st.write("Model is ready for inference.")
    return model
45
 
 
46
def preprocess_image(image_path):
    """
    Load the image at *image_path* and convert it into a normalized,
    batched tensor on DEVICE, ready to feed to the model.
    """
    st.write(f"Preprocessing image: {image_path}")
    rgb = Image.open(image_path).convert("RGB")  # force 3 channels regardless of source format
    batch = TRANSFORMS(rgb).unsqueeze(0)         # prepend batch dimension: (1, C, H, W)
    return batch.to(DEVICE)
54
 
 
55
def generate_report(model, image_path):
    """
    Run *model* on the image at *image_path* and return the generated
    report as a single space-joined string.
    """
    batch = preprocess_image(image_path)

    st.write("Generating report...")
    # No gradients needed for inference; saves memory and time.
    with torch.no_grad():
        tokens = model.generate_caption(batch, max_length=25)  # assumes model exposes generate_caption — TODO confirm
    report = " ".join(tokens)

    st.write(f"Generated report: {report}")
    return report
69
 
70
# Streamlit app
def main():
    """
    Streamlit entry point: accept an uploaded chest X-ray image, run the
    captioning model on it, and display the generated report.
    """
    st.title("Chest X-Ray Report Generator")
    st.write("Upload a Chest X-Ray image to generate a medical report.")

    # Upload image
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
        st.write("")

        # Persist the upload to disk so the path-based pipeline can open it.
        # NOTE(review): a fixed filename is not safe for concurrent sessions —
        # consider tempfile.NamedTemporaryFile if multiple users are expected.
        image_path = "./temp_image.png"
        with open(image_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.write("Image uploaded successfully.")

        try:
            # Load the model
            model = load_model()

            # Generate report
            report = generate_report(model, image_path)
            st.write("### Generated Report:")
            st.write(report)
        finally:
            # Fix: clean up the temporary file even when model loading or
            # inference raises; previously a failure left ./temp_image.png behind.
            os.remove(image_path)


if __name__ == "__main__":
    main()