Amit Gazal commited on
Commit
e25fc54
·
1 Parent(s): ef1b761
Files changed (2) hide show
  1. app.py +67 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+ from torchvision import transforms
6
+ from transformers import AutoModelForImageSegmentation
7
+
8
+ model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
9
+ torch.set_float32_matmul_precision(['high', 'highest'][0])
10
+ if torch.cuda.is_available():
11
+ model = model.to('cuda')
12
+ model.eval()
13
+
14
+ def remove_background(input_image, holiday, message):
15
+ image_size = (1024, 1024)
16
+ # Transform the input image
17
+ transform_image = transforms.Compose([
18
+ transforms.Resize(image_size),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
21
+ ])
22
+
23
+ # Process the image
24
+ input_tensor = transform_image(input_image).unsqueeze(0)
25
+ if torch.cuda.is_available():
26
+ input_tensor = input_tensor.to('cuda')
27
+
28
+ # Generate prediction
29
+ with torch.no_grad():
30
+ preds = model(input_tensor)[-1].sigmoid().cpu()
31
+ pred = preds[0].squeeze()
32
+ pred_pil = transforms.ToPILImage()(pred)
33
+ mask = pred_pil.resize(input_image.size)
34
+
35
+ # Create image without background
36
+ result_image = input_image.copy()
37
+ result_image.putalpha(mask)
38
+
39
+ # Create image with only background
40
+ only_background_image = input_image.copy()
41
+ inverted_mask = Image.eval(mask, lambda x: 255 - x) # Invert the mask
42
+ only_background_image.putalpha(inverted_mask)
43
+
44
+ first_output_image = result_image
45
+ second_output_image = only_background_image
46
+ third_output_image = result_image
47
+
48
+ return first_output_image, second_output_image, third_output_image
49
+
50
+ # Replace the demo interface
51
+ demo = gr.Interface(
52
+ fn=remove_background,
53
+ inputs=[
54
+ gr.Image(type="pil"),
55
+ gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
56
+ gr.Text(label="Optional Message", placeholder="Enter your holiday message here...")
57
+ ],
58
+ outputs=[
59
+ gr.Image(type="pil", label="First Output"),
60
+ gr.Image(type="pil", label="Second Output"),
61
+ gr.Image(type="pil", label="Third Output")
62
+ ],
63
+ title="Holiday Card Generator",
64
+ description="Upload an image to generate a holiday card"
65
+ )
66
+
67
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ pillow
5
+ kornia
6
+ transformers
7
+ timm
8
+ matplotlib