xingqiang commited on
Commit
9d4731d
·
1 Parent(s): 0e523e0

Enhanced UI with dark mode, streaming progress, and updated dependencies

Browse files
Files changed (2) hide show
  1. app.py +99 -10
  2. requirements.txt +16 -10
app.py CHANGED
@@ -11,6 +11,7 @@ import psutil
11
  import plotly.express as px
12
  import plotly.graph_objects as go
13
  import pandas as pd
 
14
 
15
  from model import RadarDetectionModel
16
  from feature_extraction import (calculate_amplitude, classify_amplitude,
@@ -23,7 +24,7 @@ from utils import plot_detection
23
  from database import save_report, get_report_history
24
 
25
  # Set theme and styling
26
- THEME = gr.themes.Soft(
27
  primary_hue="blue",
28
  secondary_hue="indigo",
29
  neutral_hue="slate",
@@ -31,6 +32,21 @@ THEME = gr.themes.Soft(
31
  text_size=gr.themes.sizes.text_md,
32
  )
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  class TechnicalReportGenerator:
35
  def __init__(self):
36
  self.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -235,11 +251,38 @@ def create_feature_radar_chart(features):
235
 
236
  return fig
237
 
238
- def process_image(image, generate_tech_report=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  if image is None:
240
  raise gr.Error("Please upload an image.")
241
 
242
  # Initialize model if needed
 
243
  global model
244
  model, error = initialize_model()
245
  if error:
@@ -251,9 +294,11 @@ def process_image(image, generate_tech_report=False):
251
  image = Image.fromarray(image)
252
 
253
  # Run detection
 
254
  detection_result = model.detect(image)
255
 
256
  # Extract features
 
257
  np_image = np.array(image)
258
  amplitude = calculate_amplitude(np_image)
259
  amplitude_class = classify_amplitude(amplitude)
@@ -279,14 +324,17 @@ def process_image(image, generate_tech_report=False):
279
  }
280
 
281
  # Create visualization charts
 
282
  confidence_chart = create_confidence_chart(
283
  detection_result.get('scores', []),
284
  detection_result.get('labels', [])
285
  )
286
 
287
  feature_chart = create_feature_radar_chart(features)
 
288
 
289
  # Start performance tracking
 
290
  start_time = time.time()
291
  performance_data = {
292
  'pipeline_stats': {},
@@ -336,6 +384,7 @@ def process_image(image, generate_tech_report=False):
336
  performance_data['gpu_util'] = get_gpu_utilization()
337
 
338
  # Generate analysis report
 
339
  analysis_report = generate_report(detection_result, features)
340
 
341
  # Prepare output
@@ -357,10 +406,12 @@ def process_image(image, generate_tech_report=False):
357
  report_path = "temp_tech_report.md"
358
  with open(report_path, "w") as f:
359
  f.write(tech_report)
360
-
361
- return output_image, analysis_report, report_path, confidence_chart, feature_chart
362
 
363
- return output_image, analysis_report, None, confidence_chart, feature_chart
 
 
 
 
364
 
365
  except Exception as e:
366
  error_msg = f"Error processing image: {str(e)}"
@@ -404,9 +455,24 @@ def get_gpu_utilization():
404
  pass
405
  return 0
406
 
 
 
 
 
 
 
 
 
 
 
407
  # Create Gradio interface
408
- with gr.Blocks(theme=THEME) as iface:
409
- gr.Markdown("# Radar Image Analysis System")
 
 
 
 
 
410
  gr.Markdown("Upload a radar image to analyze defects and generate technical reports")
411
 
412
  with gr.Tabs() as tabs:
@@ -417,7 +483,9 @@ with gr.Blocks(theme=THEME) as iface:
417
  input_image = gr.Image(
418
  type="pil",
419
  label="Upload Radar Image",
420
- elem_id="input-image"
 
 
421
  )
422
  tech_report_checkbox = gr.Checkbox(
423
  label="Generate Technical Report",
@@ -460,6 +528,12 @@ with gr.Blocks(theme=THEME) as iface:
460
  label="Feature Analysis",
461
  elem_id="feature-plot"
462
  )
 
 
 
 
 
 
463
 
464
  with gr.TabItem("History", id="history"):
465
  with gr.Row():
@@ -483,6 +557,11 @@ with gr.Blocks(theme=THEME) as iface:
483
 
484
  This system uses PaliGemma, a vision-language model that combines SigLIP-So400m (image encoder) and Gemma-2B (text decoder) for joint object detection and multimodal analysis.
485
 
 
 
 
 
 
486
  ## Troubleshooting
487
 
488
  - If the analysis fails, try uploading a different image format
@@ -491,10 +570,17 @@ with gr.Blocks(theme=THEME) as iface:
491
  """)
492
 
493
  # Set up event handlers
 
 
 
 
 
 
 
494
  analyze_button.click(
495
- fn=process_image,
496
  inputs=[input_image, tech_report_checkbox],
497
- outputs=[output_image, output_report, tech_report_output, confidence_plot, feature_plot],
498
  api_name="analyze"
499
  )
500
 
@@ -512,6 +598,9 @@ with gr.Blocks(theme=THEME) as iface:
512
  if (e.key === 'a' && e.ctrlKey) {
513
  document.getElementById('analyze-btn').click();
514
  }
 
 
 
515
  });
516
  }
517
  """)
 
11
  import plotly.express as px
12
  import plotly.graph_objects as go
13
  import pandas as pd
14
+ from functools import partial
15
 
16
  from model import RadarDetectionModel
17
  from feature_extraction import (calculate_amplitude, classify_amplitude,
 
24
  from database import save_report, get_report_history
25
 
26
  # Set theme and styling
27
+ LIGHT_THEME = gr.themes.Soft(
28
  primary_hue="blue",
29
  secondary_hue="indigo",
30
  neutral_hue="slate",
 
32
  text_size=gr.themes.sizes.text_md,
33
  )
34
 
35
+ DARK_THEME = gr.themes.Soft(
36
+ primary_hue="blue",
37
+ secondary_hue="indigo",
38
+ neutral_hue="slate",
39
+ radius_size=gr.themes.sizes.radius_sm,
40
+ text_size=gr.themes.sizes.text_md,
41
+ ).set(
42
+ body_background_fill="*neutral_950",
43
+ background_fill_primary="*neutral_900",
44
+ background_fill_secondary="*neutral_800",
45
+ text_color="*neutral_200",
46
+ color_accent_soft="*primary_800",
47
+ border_color_accent_subdued="*primary_700",
48
+ )
49
+
50
  class TechnicalReportGenerator:
51
  def __init__(self):
52
  self.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
251
 
252
  return fig
253
 
254
+ def create_heatmap(image_array):
255
+ """Create a heatmap visualization of the image intensity"""
256
+ if image_array is None:
257
+ return None
258
+
259
+ # Convert to grayscale if needed
260
+ if len(image_array.shape) == 3 and image_array.shape[2] == 3:
261
+ gray_img = np.mean(image_array, axis=2)
262
+ else:
263
+ gray_img = image_array
264
+
265
+ fig = px.imshow(
266
+ gray_img,
267
+ color_continuous_scale='inferno',
268
+ title='Signal Intensity Heatmap'
269
+ )
270
+
271
+ fig.update_layout(
272
+ xaxis_title='X Position',
273
+ yaxis_title='Y Position',
274
+ template='plotly_white'
275
+ )
276
+
277
+ return fig
278
+
279
+ def process_image_streaming(image, generate_tech_report=False, progress=gr.Progress()):
280
+ """Process image with streaming progress updates"""
281
  if image is None:
282
  raise gr.Error("Please upload an image.")
283
 
284
  # Initialize model if needed
285
+ progress(0.1, desc="Initializing model...")
286
  global model
287
  model, error = initialize_model()
288
  if error:
 
294
  image = Image.fromarray(image)
295
 
296
  # Run detection
297
+ progress(0.2, desc="Running detection...")
298
  detection_result = model.detect(image)
299
 
300
  # Extract features
301
+ progress(0.3, desc="Extracting features...")
302
  np_image = np.array(image)
303
  amplitude = calculate_amplitude(np_image)
304
  amplitude_class = classify_amplitude(amplitude)
 
324
  }
325
 
326
  # Create visualization charts
327
+ progress(0.5, desc="Creating visualizations...")
328
  confidence_chart = create_confidence_chart(
329
  detection_result.get('scores', []),
330
  detection_result.get('labels', [])
331
  )
332
 
333
  feature_chart = create_feature_radar_chart(features)
334
+ heatmap = create_heatmap(np_image)
335
 
336
  # Start performance tracking
337
+ progress(0.6, desc="Analyzing performance...")
338
  start_time = time.time()
339
  performance_data = {
340
  'pipeline_stats': {},
 
384
  performance_data['gpu_util'] = get_gpu_utilization()
385
 
386
  # Generate analysis report
387
+ progress(0.8, desc="Generating reports...")
388
  analysis_report = generate_report(detection_result, features)
389
 
390
  # Prepare output
 
406
  report_path = "temp_tech_report.md"
407
  with open(report_path, "w") as f:
408
  f.write(tech_report)
 
 
409
 
410
+ progress(1.0, desc="Analysis complete!")
411
+ return output_image, analysis_report, report_path, confidence_chart, feature_chart, heatmap
412
+
413
+ progress(1.0, desc="Analysis complete!")
414
+ return output_image, analysis_report, None, confidence_chart, feature_chart, heatmap
415
 
416
  except Exception as e:
417
  error_msg = f"Error processing image: {str(e)}"
 
455
  pass
456
  return 0
457
 
458
+ def toggle_dark_mode():
459
+ """Toggle between light and dark themes"""
460
+ current_theme = getattr(toggle_dark_mode, "current_theme", "light")
461
+ if current_theme == "light":
462
+ toggle_dark_mode.current_theme = "dark"
463
+ return DARK_THEME
464
+ else:
465
+ toggle_dark_mode.current_theme = "light"
466
+ return LIGHT_THEME
467
+
468
  # Create Gradio interface
469
+ with gr.Blocks(theme=LIGHT_THEME) as iface:
470
+ theme_state = gr.State(LIGHT_THEME)
471
+
472
+ with gr.Row():
473
+ gr.Markdown("# Radar Image Analysis System")
474
+ dark_mode_btn = gr.Button("🌓 Toggle Dark Mode", scale=0)
475
+
476
  gr.Markdown("Upload a radar image to analyze defects and generate technical reports")
477
 
478
  with gr.Tabs() as tabs:
 
483
  input_image = gr.Image(
484
  type="pil",
485
  label="Upload Radar Image",
486
+ elem_id="input-image",
487
+ sources=["upload", "webcam", "clipboard"],
488
+ tool="editor"
489
  )
490
  tech_report_checkbox = gr.Checkbox(
491
  label="Generate Technical Report",
 
528
  label="Feature Analysis",
529
  elem_id="feature-plot"
530
  )
531
+
532
+ with gr.Row():
533
+ heatmap_plot = gr.Plot(
534
+ label="Signal Intensity Heatmap",
535
+ elem_id="heatmap-plot"
536
+ )
537
 
538
  with gr.TabItem("History", id="history"):
539
  with gr.Row():
 
557
 
558
  This system uses PaliGemma, a vision-language model that combines SigLIP-So400m (image encoder) and Gemma-2B (text decoder) for joint object detection and multimodal analysis.
559
 
560
+ ## Keyboard Shortcuts
561
+
562
+ - **Ctrl+A**: Trigger analysis
563
+ - **Ctrl+D**: Toggle dark mode
564
+
565
  ## Troubleshooting
566
 
567
  - If the analysis fails, try uploading a different image format
 
570
  """)
571
 
572
  # Set up event handlers
573
+ dark_mode_btn.click(
574
+ fn=toggle_dark_mode,
575
+ inputs=[],
576
+ outputs=[iface],
577
+ api_name="toggle_theme"
578
+ )
579
+
580
  analyze_button.click(
581
+ fn=process_image_streaming,
582
  inputs=[input_image, tech_report_checkbox],
583
+ outputs=[output_image, output_report, tech_report_output, confidence_plot, feature_plot, heatmap_plot],
584
  api_name="analyze"
585
  )
586
 
 
598
  if (e.key === 'a' && e.ctrlKey) {
599
  document.getElementById('analyze-btn').click();
600
  }
601
+ if (e.key === 'd' && e.ctrlKey) {
602
+ document.querySelector('button:contains("Toggle Dark Mode")').click();
603
+ }
604
  });
605
  }
606
  """)
requirements.txt CHANGED
@@ -1,16 +1,22 @@
1
  gradio>=5.18.0
2
- torch>=2.1.0
3
- transformers>=4.36.0
4
- Pillow>=10.1.0
5
- numpy>=1.26.0
6
- matplotlib>=3.8.0
7
- pandas>=2.1.0
8
- sqlalchemy>=2.0.23
9
  plotly>=5.18.0
10
  scikit-learn>=1.3.2
11
- jinja2>=3.1.2
12
- huggingface-hub>=0.19.4
13
  python-dotenv>=1.0.0
14
  markdown>=3.5.1
15
  psutil>=5.9.6
16
- tqdm>=4.66.1
 
 
 
 
 
 
 
1
  gradio>=5.18.0
2
+ torch>=2.1.2
3
+ transformers>=4.37.2
4
+ Pillow>=10.2.0
5
+ numpy>=1.26.3
6
+ matplotlib>=3.8.2
7
+ pandas>=2.1.4
8
+ sqlalchemy>=2.0.25
9
  plotly>=5.18.0
10
  scikit-learn>=1.3.2
11
+ jinja2>=3.1.3
12
+ huggingface-hub>=0.20.2
13
  python-dotenv>=1.0.0
14
  markdown>=3.5.1
15
  psutil>=5.9.6
16
+ tqdm>=4.66.1
17
+ accelerate>=0.25.0
18
+ safetensors>=0.4.1
19
+ peft>=0.7.1
20
+ optimum>=1.14.0
21
+ colorama>=0.4.6
22
+ rich>=13.7.0