Yiming-M committed on
Commit
e85ffa0
·
1 Parent(s): ef5adb7

2025-07-31 21:55 🚀

Browse files
Files changed (1) hide show
  1. app.py +107 -32
app.py CHANGED
@@ -25,22 +25,46 @@ loaded_model = None
25
  current_model_config = {"variant": None, "dataset": None, "metric": None}
26
 
27
  pretrained_models = [
28
- "ZIP-B @ ShanghaiTech A", "ZIP-B @ ShanghaiTech B", "ZIP-B @ UCF-QNRF", "ZIP-B @ NWPU-Crowd",
29
- "ZIP-S @ ShanghaiTech A", "ZIP-S @ ShanghaiTech B", "ZIP-S @ UCF-QNRF",
30
- "ZIP-T @ ShanghaiTech A", "ZIP-T @ ShanghaiTech B", "ZIP-T @ UCF-QNRF",
31
- "ZIP-N @ ShanghaiTech A", "ZIP-N @ ShanghaiTech B", "ZIP-N @ UCF-QNRF",
32
- "ZIP-P @ ShanghaiTech A", "ZIP-P @ ShanghaiTech B", "ZIP-P @ UCF-QNRF"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  ]
34
 
35
  # -----------------------------
36
  # Model management functions
37
  # -----------------------------
38
- def update_model_if_needed(variant_dataset: str, metric: str):
39
  """
40
  Load a new model only if the configuration has changed.
41
  """
42
  global loaded_model, current_model_config
43
- variant, dataset = variant_dataset.split(" @ ")
 
 
 
 
 
 
 
 
 
44
 
45
  if dataset == "ShanghaiTech A":
46
  dataset_name = "sha"
@@ -50,6 +74,8 @@ def update_model_if_needed(variant_dataset: str, metric: str):
50
  dataset_name = "qnrf"
51
  elif dataset == "NWPU-Crowd":
52
  dataset_name = "nwpu"
 
 
53
 
54
  if (loaded_model is None or
55
  current_model_config["variant"] != variant or
@@ -271,17 +297,25 @@ def _sliding_window_predict(
271
  # Inference function
272
  # -----------------------------
273
  @spaces.GPU(duration=120)
274
- def predict(image: Image.Image, variant_dataset: str, metric: str):
275
  """
276
  Given an input image, preprocess it, run the model to obtain a density map,
277
  compute the total crowd count, and prepare the density map for display.
278
  """
279
  global loaded_model, current_model_config
280
 
 
 
 
 
281
  # 确保模型正确加载
282
- update_model_if_needed(variant_dataset, metric)
283
 
284
- variant, dataset = variant_dataset.split(" @ ")
 
 
 
 
285
 
286
  if dataset == "ShanghaiTech A":
287
  dataset_name = "sha"
@@ -291,6 +325,8 @@ def predict(image: Image.Image, variant_dataset: str, metric: str):
291
  dataset_name = "qnrf"
292
  elif dataset == "NWPU-Crowd":
293
  dataset_name = "nwpu"
 
 
294
 
295
  if not hasattr(loaded_model, "input_size"):
296
  if dataset_name == "sha":
@@ -424,7 +460,57 @@ def predict(image: Image.Image, variant_dataset: str, metric: str):
424
  # -----------------------------
425
  # Build Gradio Interface using Blocks for a two-column layout
426
  # -----------------------------
427
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  gr.Markdown("# Crowd Counting by ZIP")
429
  gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
430
 
@@ -433,17 +519,9 @@ with gr.Blocks() as demo:
433
  # Dropdown for model variant
434
  model_dropdown = gr.Dropdown(
435
  choices=pretrained_models,
436
- value="ZIP-B @ NWPU-Crowd",
437
  label="Select a pretrained model"
438
  )
439
-
440
- # Dropdown for metric, always the same choices
441
- metric_dropdown = gr.Dropdown(
442
- choices=["mae", "rmse", "nae"],
443
- value="mae",
444
- label="Select Best Metric"
445
- )
446
-
447
  model_status = gr.Textbox(
448
  label="Model Status",
449
  value="No model loaded",
@@ -463,31 +541,28 @@ with gr.Blocks() as demo:
463
  output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
464
  output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")
465
 
466
- # 当模型或度量参数变化时,自动更新模型
467
- def on_model_change(variant_dataset, metric):
468
- return update_model_if_needed(variant_dataset, metric)
 
 
 
469
 
470
  model_dropdown.change(
471
  fn=on_model_change,
472
- inputs=[model_dropdown, metric_dropdown],
473
- outputs=[model_status]
474
- )
475
-
476
- metric_dropdown.change(
477
- fn=on_model_change,
478
- inputs=[model_dropdown, metric_dropdown],
479
  outputs=[model_status]
480
  )
481
 
482
  # 页面加载时自动加载默认模型
483
  demo.load(
484
- fn=lambda: update_model_if_needed("ZIP-B @ NWPU-Crowd", "mae"),
485
  outputs=[model_status]
486
  )
487
 
488
  submit_btn.click(
489
  fn=predict,
490
- inputs=[input_img, model_dropdown, metric_dropdown],
491
  outputs=[input_img, output_den_map, output_lambda_map, output_text, output_structural_zero_map, output_sampling_zero_map, output_complete_zero_map]
492
  )
493
 
 
25
  current_model_config = {"variant": None, "dataset": None, "metric": None}
26
 
27
  pretrained_models = [
28
+ "ZIP-B @ ShanghaiTech A @ MAE", "ZIP-B @ ShanghaiTech A @ RMSE", "ZIP-B @ ShanghaiTech A @ NAE",
29
+ "ZIP-B @ ShanghaiTech B @ MAE", "ZIP-B @ ShanghaiTech B @ RMSE", "ZIP-B @ ShanghaiTech B @ NAE",
30
+ "ZIP-B @ UCF-QNRF @ MAE", "ZIP-B @ UCF-QNRF @ RMSE", "ZIP-B @ UCF-QNRF @ NAE",
31
+ "ZIP-B @ NWPU-Crowd @ MAE", "ZIP-B @ NWPU-Crowd @ RMSE", "ZIP-B @ NWPU-Crowd @ NAE",
32
+ "โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”",
33
+ "ZIP-S @ ShanghaiTech A @ MAE", "ZIP-S @ ShanghaiTech A @ RMSE", "ZIP-S @ ShanghaiTech A @ NAE",
34
+ "ZIP-S @ ShanghaiTech B @ MAE", "ZIP-S @ ShanghaiTech B @ RMSE", "ZIP-S @ ShanghaiTech B @ NAE",
35
+ "ZIP-S @ UCF-QNRF @ MAE", "ZIP-S @ UCF-QNRF @ RMSE", "ZIP-S @ UCF-QNRF @ NAE",
36
+ "โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”",
37
+ "ZIP-T @ ShanghaiTech A @ MAE", "ZIP-T @ ShanghaiTech A @ RMSE", "ZIP-T @ ShanghaiTech A @ NAE",
38
+ "ZIP-T @ ShanghaiTech B @ MAE", "ZIP-T @ ShanghaiTech B @ RMSE", "ZIP-T @ ShanghaiTech B @ NAE",
39
+ "ZIP-T @ UCF-QNRF @ MAE", "ZIP-T @ UCF-QNRF @ RMSE", "ZIP-T @ UCF-QNRF @ NAE",
40
+ "โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”",
41
+ "ZIP-N @ ShanghaiTech A @ MAE", "ZIP-N @ ShanghaiTech A @ RMSE", "ZIP-N @ ShanghaiTech A @ NAE",
42
+ "ZIP-N @ ShanghaiTech B @ MAE", "ZIP-N @ ShanghaiTech B @ RMSE", "ZIP-N @ ShanghaiTech B @ NAE",
43
+ "ZIP-N @ UCF-QNRF @ MAE", "ZIP-N @ UCF-QNRF @ RMSE", "ZIP-N @ UCF-QNRF @ NAE",
44
+ "โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”",
45
+ "ZIP-P @ ShanghaiTech A @ MAE", "ZIP-P @ ShanghaiTech A @ RMSE", "ZIP-P @ ShanghaiTech A @ NAE",
46
+ "ZIP-P @ ShanghaiTech B @ MAE", "ZIP-P @ ShanghaiTech B @ RMSE", "ZIP-P @ ShanghaiTech B @ NAE",
47
+ "ZIP-P @ UCF-QNRF @ MAE", "ZIP-P @ UCF-QNRF @ RMSE", "ZIP-P @ UCF-QNRF @ NAE",
48
  ]
49
 
50
  # -----------------------------
51
  # Model management functions
52
  # -----------------------------
53
+ def update_model_if_needed(variant_dataset_metric: str):
54
  """
55
  Load a new model only if the configuration has changed.
56
  """
57
  global loaded_model, current_model_config
58
+
59
+ # 如果是分割线,则跳过
60
+ if "โ”โ”โ”โ”โ”โ”" in variant_dataset_metric:
61
+ return "Please select a valid model configuration"
62
+
63
+ parts = variant_dataset_metric.split(" @ ")
64
+ if len(parts) != 3:
65
+ return "Invalid model configuration format"
66
+
67
+ variant, dataset, metric = parts[0], parts[1], parts[2].lower()
68
 
69
  if dataset == "ShanghaiTech A":
70
  dataset_name = "sha"
 
74
  dataset_name = "qnrf"
75
  elif dataset == "NWPU-Crowd":
76
  dataset_name = "nwpu"
77
+ else:
78
+ return f"Unknown dataset: {dataset}"
79
 
80
  if (loaded_model is None or
81
  current_model_config["variant"] != variant or
 
297
  # Inference function
298
  # -----------------------------
299
  @spaces.GPU(duration=120)
300
+ def predict(image: Image.Image, variant_dataset_metric: str):
301
  """
302
  Given an input image, preprocess it, run the model to obtain a density map,
303
  compute the total crowd count, and prepare the density map for display.
304
  """
305
  global loaded_model, current_model_config
306
 
307
+ # 如果选择的是分割线,返回错误信息
308
+ if "โ”โ”โ”โ”โ”โ”" in variant_dataset_metric:
309
+ return image, None, None, "Please select a valid model configuration", None, None, None
310
+
311
  # 确保模型正确加载
312
+ update_model_if_needed(variant_dataset_metric)
313
 
314
+ parts = variant_dataset_metric.split(" @ ")
315
+ if len(parts) != 3:
316
+ return image, None, None, "Invalid model configuration format", None, None, None
317
+
318
+ variant, dataset, metric = parts[0], parts[1], parts[2].lower()
319
 
320
  if dataset == "ShanghaiTech A":
321
  dataset_name = "sha"
 
325
  dataset_name = "qnrf"
326
  elif dataset == "NWPU-Crowd":
327
  dataset_name = "nwpu"
328
+ else:
329
+ return image, None, None, f"Unknown dataset: {dataset}", None, None, None
330
 
331
  if not hasattr(loaded_model, "input_size"):
332
  if dataset_name == "sha":
 
460
  # -----------------------------
461
  # Build Gradio Interface using Blocks for a two-column layout
462
  # -----------------------------
463
+ css = """
464
+ /* ๅˆ†ๅ‰ฒ็บฟๆ ทๅผ - ็ฐ่‰ฒไธๅฏ้€‰ๆ‹ฉ */
465
+ .dropdown select option[value*="โ”โ”โ”โ”โ”โ”"] {
466
+ color: #999 !important;
467
+ background-color: #f0f0f0 !important;
468
+ font-style: italic !important;
469
+ text-align: center !important;
470
+ pointer-events: none !important;
471
+ cursor: not-allowed !important;
472
+ border: none !important;
473
+ }
474
+
475
+ /* Gradioไธ‹ๆ‹‰่œๅ•ไธญ็š„ๅˆ†ๅ‰ฒ็บฟๆ ทๅผ */
476
+ .gr-dropdown .choices__item[data-value*="โ”โ”โ”โ”โ”โ”"] {
477
+ color: #999 !important;
478
+ background-color: #f0f0f0 !important;
479
+ font-style: italic !important;
480
+ text-align: center !important;
481
+ pointer-events: none !important;
482
+ cursor: not-allowed !important;
483
+ user-select: none !important;
484
+ opacity: 0.6 !important;
485
+ }
486
+
487
+ /* ๆ‚ฌๅœๆ—ถไฟๆŒ็ฐ่‰ฒ */
488
+ .gr-dropdown .choices__item[data-value*="โ”โ”โ”โ”โ”โ”"]:hover {
489
+ background-color: #f0f0f0 !important;
490
+ color: #999 !important;
491
+ cursor: not-allowed !important;
492
+ }
493
+
494
+ /* ้€š็”จ็š„ๅˆ†ๅ‰ฒ็บฟๆ ทๅผ */
495
+ option:disabled {
496
+ color: #999 !important;
497
+ background-color: #f0f0f0 !important;
498
+ font-style: italic !important;
499
+ }
500
+
501
+ /* ไธบๅŒ…ๅซๅˆ†ๅ‰ฒ็บฟๅญ—็ฌฆ็š„้€‰้กนๆทปๅŠ ๆ ทๅผ */
502
+ option[value*="โ”โ”โ”โ”โ”โ”"],
503
+ select option[value*="โ”โ”โ”โ”โ”โ”"] {
504
+ color: #999 !important;
505
+ background-color: #f0f0f0 !important;
506
+ cursor: not-allowed !important;
507
+ pointer-events: none !important;
508
+ text-align: center !important;
509
+ opacity: 0.6 !important;
510
+ }
511
+ """
512
+
513
+ with gr.Blocks(css=css) as demo:
514
  gr.Markdown("# Crowd Counting by ZIP")
515
  gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
516
 
 
519
  # Dropdown for model variant
520
  model_dropdown = gr.Dropdown(
521
  choices=pretrained_models,
522
+ value="ZIP-B @ NWPU-Crowd @ MAE",
523
  label="Select a pretrained model"
524
  )
 
 
 
 
 
 
 
 
525
  model_status = gr.Textbox(
526
  label="Model Status",
527
  value="No model loaded",
 
541
  output_sampling_zero_map = gr.Image(label="Sampling Zero Map", type="pil")
542
  output_complete_zero_map = gr.Image(label="Complete Zero Map", type="pil")
543
 
544
+ # 当模型变化时,自动更新模型
545
+ def on_model_change(variant_dataset_metric):
546
+ # 如果选择的是分割线,保持当前选择不变
547
+ if "โ”โ”โ”โ”โ”โ”" in variant_dataset_metric:
548
+ return "Please select a valid model configuration"
549
+ return update_model_if_needed(variant_dataset_metric)
550
 
551
  model_dropdown.change(
552
  fn=on_model_change,
553
+ inputs=[model_dropdown],
 
 
 
 
 
 
554
  outputs=[model_status]
555
  )
556
 
557
  # 页面加载时自动加载默认模型
558
  demo.load(
559
+ fn=lambda: update_model_if_needed("ZIP-B @ NWPU-Crowd @ MAE"),
560
  outputs=[model_status]
561
  )
562
 
563
  submit_btn.click(
564
  fn=predict,
565
+ inputs=[input_img, model_dropdown],
566
  outputs=[input_img, output_den_map, output_lambda_map, output_text, output_structural_zero_map, output_sampling_zero_map, output_complete_zero_map]
567
  )
568