ammariii08 commited on
Commit
6942bb2
·
verified ·
1 Parent(s): ccfaa45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -14
app.py CHANGED
@@ -119,8 +119,8 @@ def unload_and_reload_models():
119
  new_birefnet = AutoModelForImageSegmentation.from_pretrained(
120
  "zhengpeng7/BiRefNet", trust_remote_code=True, cache_dir=CACHE_DIR
121
  )
122
- # new_birefnet.to(device)
123
- # new_birefnet.eval()
124
  # new_u2net = U2NETP(3, 1)
125
  # new_u2net.load_state_dict(torch.load(os.path.join(CACHE_DIR, "u2netp.pth"), map_location="cpu"))
126
  # new_u2net.to(device)
@@ -455,13 +455,10 @@ def predict(
455
  except Exception:
456
  raise ValueError("Invalid base64 image data")
457
 
458
- # Preprocessing: reduce brightness to 0.5 and enhance sharpness.
459
  if isinstance(image, np.ndarray):
460
  pil_image = Image.fromarray(image)
461
- # Reduce brightness.
462
- dark_image = ImageEnhance.Brightness(pil_image).enhance(0.5)
463
- # Enhance sharpness.
464
- enhanced_image = ImageEnhance.Sharpness(dark_image).enhance(1)
465
  image = np.array(enhanced_image)
466
 
467
  # ---------------------
@@ -682,13 +679,6 @@ def predict(
682
  cv2.LINE_AA
683
  )
684
 
685
- # Restore brightness for display purposes:
686
- # Since we reduced brightness by 0.5 during preprocessing,
687
- # we apply an enhancement factor of 2.0 here to bring it back.
688
- display_img = Image.fromarray(output_img)
689
- display_img = ImageEnhance.Brightness(display_img).enhance(2.0)
690
- output_img = np.array(display_img)
691
-
692
  outlines_color = cv2.cvtColor(new_outlines, cv2.COLOR_BGR2RGB)
693
  print("Total prediction time: {:.2f} seconds".format(time.time() - overall_start))
694
 
 
119
  new_birefnet = AutoModelForImageSegmentation.from_pretrained(
120
  "zhengpeng7/BiRefNet", trust_remote_code=True, cache_dir=CACHE_DIR
121
  )
122
+ new_birefnet.to(device)
123
+ new_birefnet.eval()
124
  # new_u2net = U2NETP(3, 1)
125
  # new_u2net.load_state_dict(torch.load(os.path.join(CACHE_DIR, "u2netp.pth"), map_location="cpu"))
126
  # new_u2net.to(device)
 
455
  except Exception:
456
  raise ValueError("Invalid base64 image data")
457
 
458
+ # Apply brightness and sharpness enhancement.
459
  if isinstance(image, np.ndarray):
460
  pil_image = Image.fromarray(image)
461
+ enhanced_image = ImageEnhance.Sharpness(Bright).enhance(0.5)
 
 
 
462
  image = np.array(enhanced_image)
463
 
464
  # ---------------------
 
679
  cv2.LINE_AA
680
  )
681
 
 
 
 
 
 
 
 
682
  outlines_color = cv2.cvtColor(new_outlines, cv2.COLOR_BGR2RGB)
683
  print("Total prediction time: {:.2f} seconds".format(time.time() - overall_start))
684