OriLib commited on
Commit
50bce24
1 Parent(s): 62f92ff

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -4
README.md CHANGED
@@ -114,10 +114,10 @@ import torch
114
  from torchvision import transforms
115
  from transformers import AutoModelForImageSegmentation
116
 
117
- birefnet = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
118
  torch.set_float32_matmul_precision(['high', 'highest'][0])
119
- birefnet.to('cuda')
120
- birefnet.eval()
121
 
122
  # Data settings
123
  image_size = (1024, 1024)
@@ -132,7 +132,7 @@ input_images = transform_image(image).unsqueeze(0).to('cuda')
132
 
133
  # Prediction
134
  with torch.no_grad():
135
- preds = birefnet(input_images)[-1].sigmoid().cpu()
136
  pred = preds[0].squeeze()
137
  pred_pil = transforms.ToPILImage()(pred)
138
  mask = pred_pil.resize(image.size)
 
114
  from torchvision import transforms
115
  from transformers import AutoModelForImageSegmentation
116
 
117
+ model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
118
  torch.set_float32_matmul_precision(['high', 'highest'][0])
119
+ model.to('cuda')
120
+ model.eval()
121
 
122
  # Data settings
123
  image_size = (1024, 1024)
 
132
 
133
  # Prediction
134
  with torch.no_grad():
135
+ preds = model(input_images)[-1].sigmoid().cpu()
136
  pred = preds[0].squeeze()
137
  pred_pil = transforms.ToPILImage()(pred)
138
  mask = pred_pil.resize(image.size)