adithiyyha commited on
Commit
e90f280
·
verified ·
1 Parent(s): 8fd8412

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +31 -15
inference.py CHANGED
@@ -1,9 +1,7 @@
1
  import os
2
  import torch
3
- import config
4
  import spacy
5
- spacy.cli.download("en_core_web_sm")
6
-
7
  from utils import (
8
  load_dataset,
9
  get_model_instance,
@@ -15,6 +13,9 @@ from PIL import Image
15
  import torchvision.transforms as transforms
16
  import streamlit as st
17
 
 
 
 
18
  # Define device
19
  DEVICE = 'cpu'
20
 
@@ -26,6 +27,7 @@ TRANSFORMS = transforms.Compose([
26
  ])
27
 
28
 
 
29
  def load_model():
30
  """
31
  Loads the model with the vocabulary and checkpoint.
@@ -39,9 +41,13 @@ def load_model():
39
 
40
  if can_load_checkpoint():
41
  st.write("Loading checkpoint...")
42
- load_checkpoint(model)
 
 
 
 
43
  else:
44
- st.write("No checkpoint found, starting with untrained model.")
45
 
46
  model.eval() # Set the model to evaluation mode
47
  st.write("Model is ready for inference.")
@@ -53,9 +59,13 @@ def preprocess_image(image_path):
53
  Preprocess the input image for the model.
54
  """
55
  st.write(f"Preprocessing image: {image_path}")
56
- image = Image.open(image_path).convert("RGB")
57
- image = TRANSFORMS(image).unsqueeze(0)
58
- return image.to(DEVICE)
 
 
 
 
59
 
60
 
61
  def generate_report(model, image):
@@ -63,12 +73,15 @@ def generate_report(model, image):
63
  Generates a report for a given image using the model.
64
  """
65
  st.write("Generating report...")
66
- with torch.no_grad():
67
- output = model.generate_caption(image, max_length=25)
68
- report = " ".join(output)
69
-
70
- st.write(f"Generated report: {report}")
71
- return report
 
 
 
72
 
73
 
74
  # Streamlit App
@@ -79,6 +92,9 @@ st.write("Upload an X-ray image to generate a report.")
79
  uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])
80
 
81
  if uploaded_file is not None:
 
 
 
82
  # Save uploaded file to disk
83
  image_path = os.path.join("temp", uploaded_file.name)
84
  with open(image_path, "wb") as f:
@@ -94,4 +110,4 @@ if uploaded_file is not None:
94
  # Display the image and the report
95
  st.image(image_path, caption="Uploaded Image", use_column_width=True)
96
  st.write("Generated Report:")
97
- st.write(report)
 
1
  import os
2
  import torch
 
3
  import spacy
4
+ import config
 
5
  from utils import (
6
  load_dataset,
7
  get_model_instance,
 
13
  import torchvision.transforms as transforms
14
  import streamlit as st
15
 
16
+ # Download Spacy model (only once during runtime)
17
+ spacy.cli.download("en_core_web_sm")
18
+
19
  # Define device
20
  DEVICE = 'cpu'
21
 
 
27
  ])
28
 
29
 
30
+ @st.cache_resource
31
  def load_model():
32
  """
33
  Loads the model with the vocabulary and checkpoint.
 
41
 
42
  if can_load_checkpoint():
43
  st.write("Loading checkpoint...")
44
+ try:
45
+ load_checkpoint(model)
46
+ except RuntimeError as e:
47
+ st.error(f"Error loading checkpoint: {e}")
48
+ st.stop()
49
  else:
50
+ st.warning("No checkpoint found, starting with untrained model.")
51
 
52
  model.eval() # Set the model to evaluation mode
53
  st.write("Model is ready for inference.")
 
59
  Preprocess the input image for the model.
60
  """
61
  st.write(f"Preprocessing image: {image_path}")
62
+ try:
63
+ image = Image.open(image_path).convert("RGB")
64
+ image = TRANSFORMS(image).unsqueeze(0)
65
+ return image.to(DEVICE)
66
+ except Exception as e:
67
+ st.error(f"Error preprocessing image: {e}")
68
+ st.stop()
69
 
70
 
71
  def generate_report(model, image):
 
73
  Generates a report for a given image using the model.
74
  """
75
  st.write("Generating report...")
76
+ try:
77
+ with torch.no_grad():
78
+ output = model.generate_caption(image, max_length=25)
79
+ report = " ".join(output)
80
+ st.write(f"Generated report: {report}")
81
+ return report
82
+ except Exception as e:
83
+ st.error(f"Error generating report: {e}")
84
+ st.stop()
85
 
86
 
87
  # Streamlit App
 
92
  uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])
93
 
94
  if uploaded_file is not None:
95
+ # Ensure 'temp' directory exists
96
+ os.makedirs("temp", exist_ok=True)
97
+
98
  # Save uploaded file to disk
99
  image_path = os.path.join("temp", uploaded_file.name)
100
  with open(image_path, "wb") as f:
 
110
  # Display the image and the report
111
  st.image(image_path, caption="Uploaded Image", use_column_width=True)
112
  st.write("Generated Report:")
113
+ st.write(report)