Amit Gazal commited on
Commit
6a046fb
·
1 Parent(s): 7bce5c7
Files changed (2) hide show
  1. app.py +65 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+ from torchvision import transforms
6
+ from transformers import AutoModelForImageSegmentation
7
+
8
+ model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
9
+ torch.set_float32_matmul_precision(['high', 'highest'][0])
10
+ model.eval()
11
+
12
+ def remove_background(input_image, holiday, message):
13
+ image_size = (1024, 1024)
14
+ # Transform the input image
15
+ transform_image = transforms.Compose([
16
+ transforms.Resize(image_size),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19
+ ])
20
+
21
+ # Process the image
22
+ input_tensor = transform_image(input_image).unsqueeze(0)
23
+ if torch.cuda.is_available():
24
+ input_tensor = input_tensor.to('cuda')
25
+
26
+ # Generate prediction
27
+ with torch.no_grad():
28
+ preds = model(input_tensor)[-1].sigmoid().cpu()
29
+ pred = preds[0].squeeze()
30
+ pred_pil = transforms.ToPILImage()(pred)
31
+ mask = pred_pil.resize(input_image.size)
32
+
33
+ # Create image without background
34
+ result_image = input_image.copy()
35
+ result_image.putalpha(mask)
36
+
37
+ # Create image with only background
38
+ only_background_image = input_image.copy()
39
+ inverted_mask = Image.eval(mask, lambda x: 255 - x) # Invert the mask
40
+ only_background_image.putalpha(inverted_mask)
41
+
42
+ first_output_image = result_image
43
+ second_output_image = only_background_image
44
+ third_output_image = result_image
45
+
46
+ return first_output_image, second_output_image, third_output_image
47
+
48
+ # Replace the demo interface
49
+ demo = gr.Interface(
50
+ fn=remove_background,
51
+ inputs=[
52
+ gr.Image(type="pil"),
53
+ gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
54
+ gr.Text(label="Optional Message", placeholder="Enter your holiday message here...")
55
+ ],
56
+ outputs=[
57
+ gr.Image(type="pil", label="First Output"),
58
+ gr.Image(type="pil", label="Second Output"),
59
+ gr.Image(type="pil", label="Third Output")
60
+ ],
61
+ title="Background Removal Tool",
62
+ description="Upload an image to remove its background"
63
+ )
64
+
65
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ pillow
5
+ kornia
6
+ transformers
7
+ timm