adithiyyha commited on
Commit
a8175a2
·
verified ·
1 Parent(s): 1363912

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +32 -37
inference.py CHANGED
@@ -1,10 +1,10 @@
1
  import os
2
  import torch
3
  import config
4
- import streamlit as st
5
- import spacy
6
- spacy.cli.download("en_core_web_sm")
7
 
 
 
8
  from utils import (
9
  load_dataset,
10
  get_model_instance,
@@ -14,6 +14,7 @@ from utils import (
14
  )
15
  from PIL import Image
16
  import torchvision.transforms as transforms
 
17
 
18
  # Define device
19
  DEVICE = 'cpu'
@@ -25,16 +26,17 @@ TRANSFORMS = transforms.Compose([
25
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
26
  ])
27
 
 
28
  def load_model():
29
  """
30
  Loads the model with the vocabulary and checkpoint.
31
  """
32
  st.write("Loading dataset and vocabulary...")
33
- dataset = load_dataset() # Load dataset to access vocabulary
34
- vocabulary = dataset.vocab # Assuming 'vocab' is an attribute of the dataset
35
 
36
  st.write("Initializing the model...")
37
- model = get_model_instance(vocabulary) # Initialize the model
38
 
39
  if can_load_checkpoint():
40
  st.write("Loading checkpoint...")
@@ -46,58 +48,51 @@ def load_model():
46
  st.write("Model is ready for inference.")
47
  return model
48
 
 
49
  def preprocess_image(image_path):
50
  """
51
  Preprocess the input image for the model.
52
  """
53
  st.write(f"Preprocessing image: {image_path}")
54
- image = Image.open(image_path).convert("RGB") # Ensure RGB format
55
- image = TRANSFORMS(image).unsqueeze(0) # Add batch dimension
56
  return image.to(DEVICE)
57
 
58
- def generate_report(model, image_path):
 
59
  """
60
  Generates a report for a given image using the model.
61
  """
62
- image = preprocess_image(image_path)
63
-
64
  st.write("Generating report...")
65
  with torch.no_grad():
66
- # Assuming the model has a 'generate_caption' method
67
  output = model.generate_caption(image, max_length=25)
68
  report = " ".join(output)
69
 
70
  st.write(f"Generated report: {report}")
71
  return report
72
 
73
- # Streamlit app
74
- def main():
75
- st.title("Chest X-Ray Report Generator")
76
- st.write("Upload a Chest X-Ray image to generate a medical report.")
77
 
78
- # Upload image
79
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
80
 
81
- if uploaded_file is not None:
82
- st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
83
- st.write("")
84
 
85
- # Save the uploaded file temporarily
86
- image_path = "./temp_image.png"
87
- with open(image_path, "wb") as f:
88
- f.write(uploaded_file.getbuffer())
89
- st.write("Image uploaded successfully.")
90
 
91
- # Load the model
92
- model = load_model()
93
 
94
- # Generate report
95
- report = generate_report(model, image_path)
96
- st.write("### Generated Report:")
97
- st.write(report)
98
-
99
- # Clean up temporary file
100
- os.remove(image_path)
101
 
102
- if __name__ == "__main__":
103
- main()
 
 
 
1
  import os
2
  import torch
3
  import config
4
+ !python -m spacy download en_core_web_sm
 
 
5
 
6
+ import spacy
7
+ nlp = spacy.load("en_core_web_sm")
8
  from utils import (
9
  load_dataset,
10
  get_model_instance,
 
14
  )
15
  from PIL import Image
16
  import torchvision.transforms as transforms
17
+ import streamlit as st
18
 
19
  # Define device
20
  DEVICE = 'cpu'
 
26
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
27
  ])
28
 
29
+
30
  def load_model():
31
  """
32
  Loads the model with the vocabulary and checkpoint.
33
  """
34
  st.write("Loading dataset and vocabulary...")
35
+ dataset = load_dataset()
36
+ vocabulary = dataset.vocab
37
 
38
  st.write("Initializing the model...")
39
+ model = get_model_instance(vocabulary)
40
 
41
  if can_load_checkpoint():
42
  st.write("Loading checkpoint...")
 
48
  st.write("Model is ready for inference.")
49
  return model
50
 
51
+
52
  def preprocess_image(image_path):
53
  """
54
  Preprocess the input image for the model.
55
  """
56
  st.write(f"Preprocessing image: {image_path}")
57
+ image = Image.open(image_path).convert("RGB")
58
+ image = TRANSFORMS(image).unsqueeze(0)
59
  return image.to(DEVICE)
60
 
61
+
62
+ def generate_report(model, image):
63
  """
64
  Generates a report for a given image using the model.
65
  """
 
 
66
  st.write("Generating report...")
67
  with torch.no_grad():
 
68
  output = model.generate_caption(image, max_length=25)
69
  report = " ".join(output)
70
 
71
  st.write(f"Generated report: {report}")
72
  return report
73
 
 
 
 
 
74
 
75
+ # Streamlit App
76
+ st.title("Medical Image Report Generator")
77
+ st.write("Upload an X-ray image to generate a report.")
78
 
79
+ # File uploader
80
+ uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])
 
81
 
82
+ if uploaded_file is not None:
83
+ # Save uploaded file to disk
84
+ image_path = os.path.join("temp", uploaded_file.name)
85
+ with open(image_path, "wb") as f:
86
+ f.write(uploaded_file.getbuffer())
87
 
88
+ # Load the model
89
+ model = load_model()
90
 
91
+ # Preprocess and generate the report
92
+ image = preprocess_image(image_path)
93
+ report = generate_report(model, image)
 
 
 
 
94
 
95
+ # Display the image and the report
96
+ st.image(image_path, caption="Uploaded Image", use_column_width=True)
97
+ st.write("Generated Report:")
98
+ st.write(report)