adithiyyha commited on
Commit
5d22509
·
verified ·
1 Parent(s): eb41174

Update AKSHAYRAJAA/inference.py

Browse files
Files changed (1) hide show
  1. AKSHAYRAJAA/inference.py +22 -15
AKSHAYRAJAA/inference.py CHANGED
@@ -1,18 +1,18 @@
1
  import os
2
  import torch
3
- import config
4
  import spacy
5
  import spacy.cli
 
 
 
6
  from utils import (
7
  load_dataset,
8
  get_model_instance,
9
  load_checkpoint,
10
  can_load_checkpoint,
11
- normalize_text,
12
  )
13
- from PIL import Image
14
- import torchvision.transforms as transforms
15
- import streamlit as st
16
  spacy.cli.download("en_core_web_sm")
17
  nlp = spacy.load("en_core_web_sm")
18
 
@@ -21,7 +21,7 @@ DEVICE = 'cpu'
21
 
22
  # Define image transformations
23
  TRANSFORMS = transforms.Compose([
24
- transforms.Resize((224, 224)), # Replace with your model's expected input size
25
  transforms.ToTensor(),
26
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
27
  ])
@@ -64,18 +64,24 @@ def generate_report(model, image):
64
  Generates a report for a given image using the model.
65
  """
66
  st.write("Generating report...")
67
- with torch.no_grad():
68
- output = model.generate_caption(image, max_length=25)
69
- report = " ".join(output)
70
-
71
- st.write(f"Generated report: {report}")
72
- return report
 
 
 
73
 
74
 
75
  # Streamlit App
76
  st.title("Medical Image Report Generator")
77
  st.write("Upload an X-ray image to generate a report.")
78
 
 
 
 
79
  # File uploader
80
  uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])
81
 
@@ -93,6 +99,7 @@ if uploaded_file is not None:
93
  report = generate_report(model, image)
94
 
95
  # Display the image and the report
96
- st.image(image_path, caption="Uploaded Image", use_column_width=True)
97
- st.write("Generated Report:")
98
- st.write(report)
 
 
1
  import os
2
  import torch
 
3
  import spacy
4
  import spacy.cli
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+ import streamlit as st
8
  from utils import (
9
  load_dataset,
10
  get_model_instance,
11
  load_checkpoint,
12
  can_load_checkpoint,
 
13
  )
14
+
15
+ # Ensure the SpaCy model is downloaded
 
16
  spacy.cli.download("en_core_web_sm")
17
  nlp = spacy.load("en_core_web_sm")
18
 
 
21
 
22
  # Define image transformations
23
  TRANSFORMS = transforms.Compose([
24
+ transforms.Resize((224, 224)), # Adjust to your model's input size
25
  transforms.ToTensor(),
26
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
27
  ])
 
64
  Generates a report for a given image using the model.
65
  """
66
  st.write("Generating report...")
67
+ try:
68
+ with torch.no_grad():
69
+ output = model.generate_caption(image, max_length=25) # Ensure `generate_caption` is implemented
70
+ report = " ".join(output)
71
+ st.write(f"Generated report: {report}")
72
+ return report
73
+ except Exception as e:
74
+ st.error(f"Error during report generation: {e}")
75
+ return None
76
 
77
 
78
  # Streamlit App
79
  st.title("Medical Image Report Generator")
80
  st.write("Upload an X-ray image to generate a report.")
81
 
82
+ # Create temp directory
83
+ os.makedirs("temp", exist_ok=True)
84
+
85
  # File uploader
86
  uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])
87
 
 
99
  report = generate_report(model, image)
100
 
101
  # Display the image and the report
102
+ if report:
103
+ st.image(image_path, caption="Uploaded Image", use_column_width=True)
104
+ st.write("Generated Report:")
105
+ st.write(report)