adithiyyha commited on
Commit
1daa088
·
verified ·
1 Parent(s): 5d22509

Update AKSHAYRAJAA/inference.py

Browse files
Files changed (1) hide show
  1. AKSHAYRAJAA/inference.py +53 -55
AKSHAYRAJAA/inference.py CHANGED
@@ -1,42 +1,40 @@
1
  import os
2
  import torch
3
- import spacy
4
- import spacy.cli
5
- from PIL import Image
6
- import torchvision.transforms as transforms
7
  import streamlit as st
 
 
 
8
  from utils import (
9
  load_dataset,
10
  get_model_instance,
11
  load_checkpoint,
12
  can_load_checkpoint,
 
13
  )
14
-
15
- # Ensure the SpaCy model is downloaded
16
- spacy.cli.download("en_core_web_sm")
17
- nlp = spacy.load("en_core_web_sm")
18
 
19
  # Define device
20
  DEVICE = 'cpu'
21
 
22
  # Define image transformations
23
  TRANSFORMS = transforms.Compose([
24
- transforms.Resize((224, 224)), # Adjust to your model's input size
25
  transforms.ToTensor(),
26
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
27
  ])
28
 
29
-
30
  def load_model():
31
  """
32
  Loads the model with the vocabulary and checkpoint.
33
  """
34
  st.write("Loading dataset and vocabulary...")
35
- dataset = load_dataset()
36
- vocabulary = dataset.vocab
37
 
38
  st.write("Initializing the model...")
39
- model = get_model_instance(vocabulary)
40
 
41
  if can_load_checkpoint():
42
  st.write("Loading checkpoint...")
@@ -48,58 +46,58 @@ def load_model():
48
  st.write("Model is ready for inference.")
49
  return model
50
 
51
-
52
  def preprocess_image(image_path):
53
  """
54
  Preprocess the input image for the model.
55
  """
56
  st.write(f"Preprocessing image: {image_path}")
57
- image = Image.open(image_path).convert("RGB")
58
- image = TRANSFORMS(image).unsqueeze(0)
59
  return image.to(DEVICE)
60
 
61
-
62
- def generate_report(model, image):
63
  """
64
  Generates a report for a given image using the model.
65
  """
66
- st.write("Generating report...")
67
- try:
68
- with torch.no_grad():
69
- output = model.generate_caption(image, max_length=25) # Ensure `generate_caption` is implemented
70
- report = " ".join(output)
71
- st.write(f"Generated report: {report}")
72
- return report
73
- except Exception as e:
74
- st.error(f"Error during report generation: {e}")
75
- return None
76
-
77
-
78
- # Streamlit App
79
- st.title("Medical Image Report Generator")
80
- st.write("Upload an X-ray image to generate a report.")
81
-
82
- # Create temp directory
83
- os.makedirs("temp", exist_ok=True)
84
-
85
- # File uploader
86
- uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])
87
-
88
- if uploaded_file is not None:
89
- # Save uploaded file to disk
90
- image_path = os.path.join("temp", uploaded_file.name)
91
- with open(image_path, "wb") as f:
92
- f.write(uploaded_file.getbuffer())
93
-
94
- # Load the model
95
- model = load_model()
96
-
97
- # Preprocess and generate the report
98
  image = preprocess_image(image_path)
99
- report = generate_report(model, image)
100
 
101
- # Display the image and the report
102
- if report:
103
- st.image(image_path, caption="Uploaded Image", use_column_width=True)
104
- st.write("Generated Report:")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  st.write(report)
 
 
 
 
 
 
 
1
  import os
2
  import torch
3
+ import config
 
 
 
4
  import streamlit as st
5
+ import spacy
6
+ spacy.cli.download("en_core_web_sm")
7
+
8
  from utils import (
9
  load_dataset,
10
  get_model_instance,
11
  load_checkpoint,
12
  can_load_checkpoint,
13
+ normalize_text,
14
  )
15
+ from PIL import Image
16
+ import torchvision.transforms as transforms
 
 
17
 
18
  # Define device
19
  DEVICE = 'cpu'
20
 
21
  # Define image transformations
22
  TRANSFORMS = transforms.Compose([
23
+ transforms.Resize((224, 224)), # Replace with your model's expected input size
24
  transforms.ToTensor(),
25
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
26
  ])
27
 
 
28
  def load_model():
29
  """
30
  Loads the model with the vocabulary and checkpoint.
31
  """
32
  st.write("Loading dataset and vocabulary...")
33
+ dataset = load_dataset() # Load dataset to access vocabulary
34
+ vocabulary = dataset.vocab # Assuming 'vocab' is an attribute of the dataset
35
 
36
  st.write("Initializing the model...")
37
+ model = get_model_instance(vocabulary) # Initialize the model
38
 
39
  if can_load_checkpoint():
40
  st.write("Loading checkpoint...")
 
46
  st.write("Model is ready for inference.")
47
  return model
48
 
 
49
  def preprocess_image(image_path):
50
  """
51
  Preprocess the input image for the model.
52
  """
53
  st.write(f"Preprocessing image: {image_path}")
54
+ image = Image.open(image_path).convert("RGB") # Ensure RGB format
55
+ image = TRANSFORMS(image).unsqueeze(0) # Add batch dimension
56
  return image.to(DEVICE)
57
 
58
+ def generate_report(model, image_path):
 
59
  """
60
  Generates a report for a given image using the model.
61
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  image = preprocess_image(image_path)
 
63
 
64
+ st.write("Generating report...")
65
+ with torch.no_grad():
66
+ # Assuming the model has a 'generate_caption' method
67
+ output = model.generate_caption(image, max_length=25)
68
+ report = " ".join(output)
69
+
70
+ st.write(f"Generated report: {report}")
71
+ return report
72
+
73
+ # Streamlit app
74
+ def main():
75
+ st.title("Chest X-Ray Report Generator")
76
+ st.write("Upload a Chest X-Ray image to generate a medical report.")
77
+
78
+ # Upload image
79
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
80
+
81
+ if uploaded_file is not None:
82
+ st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
83
+ st.write("")
84
+
85
+ # Save the uploaded file temporarily
86
+ image_path = "./temp_image.png"
87
+ with open(image_path, "wb") as f:
88
+ f.write(uploaded_file.getbuffer())
89
+ st.write("Image uploaded successfully.")
90
+
91
+ # Load the model
92
+ model = load_model()
93
+
94
+ # Generate report
95
+ report = generate_report(model, image_path)
96
+ st.write("### Generated Report:")
97
  st.write(report)
98
+
99
+ # Clean up temporary file
100
+ os.remove(image_path)
101
+
102
+ if __name__ == "__main__":
103
+ main()