adithiyyha commited on
Commit
6c28e99
·
verified ·
1 Parent(s): aeae044

Update AKSHAYRAJAA/inference.py

Browse files
Files changed (1) hide show
  1. AKSHAYRAJAA/inference.py +35 -29
AKSHAYRAJAA/inference.py CHANGED
@@ -10,11 +10,12 @@ from utils import (
10
  )
11
  from PIL import Image
12
  import torchvision.transforms as transforms
 
13
 
14
  # Define device
15
  DEVICE = 'cpu'
16
 
17
- # Define image transformations (adjust based on training setup)
18
  TRANSFORMS = transforms.Compose([
19
  transforms.Resize((224, 224)), # Replace with your model's expected input size
20
  transforms.ToTensor(),
@@ -26,21 +27,21 @@ def load_model():
26
  """
27
  Loads the model with the vocabulary and checkpoint.
28
  """
29
- print("Loading dataset and vocabulary...")
30
- dataset = load_dataset() # Load dataset to access vocabulary
31
- vocabulary = dataset.vocab # Assuming 'vocab' is an attribute of the dataset
32
 
33
- print("Initializing the model...")
34
- model = get_model_instance(vocabulary) # Initialize the model
35
 
36
  if can_load_checkpoint():
37
- print("Loading checkpoint...")
38
  load_checkpoint(model)
39
  else:
40
- print("No checkpoint found, starting with untrained model.")
41
 
42
  model.eval() # Set the model to evaluation mode
43
- print("Model is ready for inference.")
44
  return model
45
 
46
 
@@ -48,41 +49,46 @@ def preprocess_image(image_path):
48
  """
49
  Preprocess the input image for the model.
50
  """
51
- print(f"Preprocessing image: {image_path}")
52
- image = Image.open(image_path).convert("RGB") # Ensure RGB format
53
- image = TRANSFORMS(image).unsqueeze(0) # Add batch dimension
54
  return image.to(DEVICE)
55
 
56
 
57
- def generate_report(model, image_path):
58
  """
59
  Generates a report for a given image using the model.
60
  """
61
- image = preprocess_image(image_path)
62
-
63
- print("Generating report...")
64
  with torch.no_grad():
65
- # Assuming the model has a 'generate_caption' method
66
  output = model.generate_caption(image, max_length=25)
67
  report = " ".join(output)
68
 
69
- print(f"Generated report: {report}")
70
  return report
71
 
72
 
73
- if __name__ == "__main__":
74
- # Path to the checkpoint file
75
- CHECKPOINT_PATH = config.CHECKPOINT_FILE # Ensure config.CHECKPOINT_FILE is correctly set
 
 
 
76
 
77
- # Path to the input image
78
- IMAGE_PATH = "D:\AKSHAYRAJAA\dataset\images\CXR387_IM-1962-1001.png" # Replace with your image path
 
 
 
79
 
80
  # Load the model
81
  model = load_model()
82
 
83
- # Ensure the image exists before inference
84
- if os.path.exists(IMAGE_PATH):
85
- report = generate_report(model, IMAGE_PATH)
86
- print("Final Report:", report)
87
- else:
88
- print(f"Image not found at path: {IMAGE_PATH}")
 
 
 
10
  )
11
  from PIL import Image
12
  import torchvision.transforms as transforms
13
+ import streamlit as st
14
 
15
  # Define device
16
  DEVICE = 'cpu'
17
 
18
+ # Define image transformations
19
  TRANSFORMS = transforms.Compose([
20
  transforms.Resize((224, 224)), # Replace with your model's expected input size
21
  transforms.ToTensor(),
 
def load_model():
    """
    Build the captioning model, restore weights from a checkpoint when one
    exists, and return the model in evaluation mode.

    Returns:
        The model instance, on CPU, ready for inference.
    """
    st.write("Loading dataset and vocabulary...")
    # The dataset is loaded only to obtain its vocabulary for the decoder.
    vocab = load_dataset().vocab

    st.write("Initializing the model...")
    net = get_model_instance(vocab)

    if can_load_checkpoint():
        st.write("Loading checkpoint...")
        load_checkpoint(net)
    else:
        st.write("No checkpoint found, starting with untrained model.")

    net.eval()  # inference mode: disables dropout / batch-norm updates
    st.write("Model is ready for inference.")
    return net
46
 
47
 
 
def preprocess_image(image_path):
    """
    Load the image at *image_path* and turn it into a model-ready tensor.

    The image is forced to RGB, run through the module-level TRANSFORMS
    pipeline, given a leading batch dimension, and moved to DEVICE.
    """
    st.write(f"Preprocessing image: {image_path}")
    rgb = Image.open(image_path).convert("RGB")
    batch = TRANSFORMS(rgb).unsqueeze(0)  # shape: (1, C, H, W)
    return batch.to(DEVICE)
56
 
57
 
58
def generate_report(model, image):
    """
    Run caption generation on a preprocessed image tensor and return the
    resulting text report.

    Args:
        model: a model exposing ``generate_caption(image, max_length=...)``
               that yields a sequence of word tokens.
        image: a batched image tensor as produced by ``preprocess_image``.

    Returns:
        The generated tokens joined into a single space-separated string.
    """
    st.write("Generating report...")
    with torch.no_grad():  # no gradients needed at inference time
        tokens = model.generate_caption(image, max_length=25)
    report = " ".join(tokens)
    st.write(f"Generated report: {report}")
    return report
69
 
70
 
71
# Streamlit App: upload an X-ray image and display the generated report.
st.title("Medical Image Report Generator")
st.write("Upload an X-ray image to generate a report.")

# File uploader
uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # Persist the upload to disk so PIL can open it by path.
    # Bug fix: the temp directory was never created, so the very first
    # upload raised FileNotFoundError on open(); create it up front.
    os.makedirs("temp", exist_ok=True)
    image_path = os.path.join("temp", uploaded_file.name)
    with open(image_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Load the model
    # NOTE(review): this reloads the model on every upload; consider
    # @st.cache_resource on load_model() if startup cost matters.
    model = load_model()

    # Preprocess and generate the report
    image = preprocess_image(image_path)
    report = generate_report(model, image)

    # Display the image and the report
    st.image(image_path, caption="Uploaded Image", use_column_width=True)
    st.write("Generated Report:")
    st.write(report)