avanish07 commited on
Commit
3b4985d
·
1 Parent(s): 963d428

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -13
app.py CHANGED
@@ -81,23 +81,16 @@ transform = transforms.Compose([
81
 
82
  # Define the prediction function
83
  def predict_count(input_image):
84
- # Preprocess the input image
85
  image = transform(input_image).unsqueeze(0).cpu()
86
-
87
- # Perform the forward pass
88
  output = csrmodel(image)
89
-
90
- # Calculate the predicted count
91
  predicted_count = int(output.detach().cpu().sum().numpy())
 
 
 
92
 
93
- return predicted_count
94
-
95
- # Define the input and output interfaces for Gradio
96
- input_interface = gr.inputs.Image()
97
- output_interface = gr.outputs.Textbox()
98
-
99
- # Create the Gradio app
100
- grapp = gr.Interface(fn=predict_count, inputs=input_interface, outputs=output_interface)
101
 
 
102
  # Launch the app
103
  grapp.launch()
 
81
 
82
  # Define the prediction function
83
  def predict_count(input_image):
 
84
  image = transform(input_image).unsqueeze(0).cpu()
 
 
85
  output = csrmodel(image)
 
 
86
  predicted_count = int(output.detach().cpu().sum().numpy())
87
+ density_map = output.detach().cpu().numpy().reshape(output.shape[2], output.shape[3])
88
+ density_map_color = plt.cm.jet(density_map / np.max(density_map))
89
+ return predicted_count, density_map_color
90
 
91
+ output_interface = gr.outputs.Textbox(label="Predicted Count")
92
+ density_map_interface = gr.outputs.Image(label="Density Map")
 
 
 
 
 
 
93
 
94
+ grapp = gr.Interface(fn=predict_count, inputs=input_interface, outputs=[output_interface, density_map_interface])
95
  # Launch the app
96
  grapp.launch()