prateekbh committed on
Commit c5f4497
1 Parent(s): 6bf6d32

Update app.py

Files changed (1)
  1. app.py +74 -10
app.py CHANGED
@@ -3,19 +3,29 @@ import gradio as gr
 import torch
 from transformers import AutoModel, AutoProcessor
 from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
+import numpy as np
+import torch.nn.functional as F
+from torchvision.transforms.functional import normalize
+from huggingface_hub import hf_hub_download
+from briarmbg import BriaRMBG
+import PIL
+from PIL import Image
+from typing import Tuple
+
+
+net=BriaRMBG()
+# model_path = "./model1.pth"
+model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
+if torch.cuda.is_available():
+    net.load_state_dict(torch.load(model_path))
+    net=net.cuda()
+else:
+    net.load_state_dict(torch.load(model_path,map_location="cpu"))
+net.eval()
 
 
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-
-title = """<h1 style="text-align: center;">Product description generator</h1>"""
-css = """
-div#col-container {
-    margin: 0 auto;
-    max-width: 840px;
-}
-"""
-
 model = AutoModel.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True).to(device)
 processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True)
 
@@ -83,6 +93,53 @@ def response(history, image):
         history[-1][1] = partial_response
         yield history
 
+def resize_image(image):
+    image = image.convert('RGB')
+    model_input_size = (1024, 1024)
+    image = image.resize(model_input_size, Image.BILINEAR)
+    return image
+
+
+def process(image):
+
+    # prepare input
+    orig_image = Image.fromarray(image)
+    w,h = orig_im_size = orig_image.size
+    image = resize_image(orig_image)
+    im_np = np.array(image)
+    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
+    im_tensor = torch.unsqueeze(im_tensor,0)
+    im_tensor = torch.divide(im_tensor,255.0)
+    im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
+    if torch.cuda.is_available():
+        im_tensor=im_tensor.cuda()
+
+    #inference
+    result=net(im_tensor)
+    # post process
+    result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear'), 0)
+    ma = torch.max(result)
+    mi = torch.min(result)
+    result = (result-mi)/(ma-mi)
+    # image to pil
+    im_array = (result*255).cpu().data.numpy().astype(np.uint8)
+    pil_im = Image.fromarray(np.squeeze(im_array))
+    # paste the mask on the original image
+    new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
+    new_im.paste(orig_image, mask=pil_im)
+    # new_orig_image = orig_image.convert('RGBA')
+
+    return new_im
+
+
+title = """<h1 style="text-align: center;">Product description generator</h1>"""
+css = """
+div#col-container {
+    margin: 0 auto;
+    max-width: 840px;
+}
+"""
+
 with gr.Blocks(css=css) as demo:
     gr.HTML(title)
     with gr.Row():
@@ -92,7 +149,7 @@ with gr.Blocks(css=css) as demo:
             chat = gr.Chatbot(show_label=False)
             submit = gr.Button(value="Upload", variant="primary")
         with gr.Column():
-            output = gr.Image(type="pil")
+            output = gr.Image(type="pil", sources="none")
 
     response_handler = (
         response,
@@ -100,6 +157,12 @@ with gr.Blocks(css=css) as demo:
         [chat]
     )
 
+    background_remover_handler = (
+        process,
+        [image],
+        [output]
+    )
+
     # postresponse_handler = (
    # lambda: (gr.Button(visible=False), gr.Button(visible=True)),
    # None,
@@ -107,6 +170,7 @@ with gr.Blocks(css=css) as demo:
     # )
 
     event = submit.click(*response_handler)
+    event2 = submit.click(*background_remover_handler)
     # event.then(*postresponse_handler)
 
 demo.launch()
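
Note on what the added code does: process() takes the uploaded photo as a NumPy array, runs a normalized 1024x1024 copy through BriaRMBG (RMBG-1.4), rescales the predicted saliency map back to the original resolution, and uses it as an alpha mask so the product is pasted onto a transparent background. Below is a minimal standalone sketch of that flow outside Gradio, assuming briarmbg.py from the briaai/RMBG-1.4 repo is importable; the input and output file names are illustrative, not part of this commit.

# Standalone sketch of the background-removal flow added in this commit
# (illustrative file names; briarmbg.py comes from the briaai/RMBG-1.4 repo).
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
from PIL import Image
from briarmbg import BriaRMBG

net = BriaRMBG()
net.load_state_dict(torch.load(hf_hub_download("briaai/RMBG-1.4", "model.pth"), map_location="cpu"))
net.eval()

orig = Image.open("product_photo.jpg").convert("RGB")   # hypothetical input image
w, h = orig.size

# Same preprocessing as process(): resize to 1024x1024, scale to [0, 1], normalize.
x = torch.tensor(np.array(orig.resize((1024, 1024), Image.BILINEAR)), dtype=torch.float32)
x = normalize(x.permute(2, 0, 1).unsqueeze(0) / 255.0, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

with torch.no_grad():
    pred = net(x)[0][0]                                  # highest-resolution mask output

# Resize the mask to the original resolution and rescale it to [0, 1].
mask = F.interpolate(pred, size=(h, w), mode="bilinear").squeeze(0)
mask = (mask - mask.min()) / (mask.max() - mask.min())
alpha = Image.fromarray((mask * 255).byte().cpu().numpy().squeeze())

# Use the mask as an alpha channel: the background becomes transparent.
cutout = Image.new("RGBA", alpha.size, (0, 0, 0, 0))
cutout.paste(orig, mask=alpha)
cutout.save("product_cutout.png")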
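
On the UI side, the commit registers a second listener on the same submit.click event, so one click both streams the description into the chatbot and renders the background-removed cutout into the new output image (the image input component referenced by [image] is defined earlier in app.py, outside this diff). A minimal sketch of that two-listeners-per-click pattern, with placeholder component and function names:

# Two handlers bound to one button click, as app.py does with response() and
# process(). All names below are placeholders, not the ones used in the Space.
import gradio as gr
from PIL import Image

def describe(img):
    return "a placeholder product description"   # stands in for the streaming response()

def cut_out(img):
    return Image.fromarray(img)                   # stands in for process()

with gr.Blocks() as demo:
    inp = gr.Image(type="numpy")
    txt = gr.Textbox()
    out = gr.Image(type="pil")
    btn = gr.Button("Upload")
    btn.click(describe, [inp], [txt])   # first listener
    btn.click(cut_out, [inp], [out])    # second listener; both fire on the same click

demo.launch()

Registering two listeners keeps the caption stream and the segmentation independent, so the slower generation loop does not block the cutout from appearing.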