File size: 7,238 Bytes
563f98d 7d58261 c5f4497 6e02423 95b0407 6e02423 c5f4497 c49cbd7 c5f4497 563f98d e0c81f0 563f98d 7d58261 7beb62a 2b76edc 040bdbd 23f3832 040bdbd 6dd0a79 59348e4 6dd0a79 48030b2 734cb58 ad4c8af 9492880 1bbc851 572b329 2b76edc 63db7c6 f617eac 63db7c6 f617eac 63db7c6 2b76edc 55dc744 2b76edc dd1124f 48030b2 572b329 7d58261 3fb1f8e fc81c90 29ba44d 7d58261 cb86cd4 605b0aa 7d58261 cb86cd4 3fb1f8e 7d58261 3fb1f8e 2eee6e2 b15ce46 7d58261 c5f4497 aeba1bd 32a8e2c 688c0f1 c5f4497 c739636 6bf6d32 c739636 5c80ad3 040bdbd c739636 6bf6d32 ad4c8af 252e8ea 59348e4 040bdbd 23f3832 fb7a950 bc3802f 59348e4 ad4c8af 2cb6eab ad4c8af c5f4497 604742b c5f4497 14eb553 60eaa44 ad4c8af c4d9f0b e80c4ee a89d883 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
import gradio as gr
import torch
import numpy as np
import torch.nn.functional as F
import PIL
import random
from threading import Thread
from transformers import AutoModel, AutoProcessor
from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download, InferenceClient
from briarmbg import BriaRMBG
from PIL import Image
from typing import Tuple
# Pick the compute device once; the models below and the inference helpers
# (which read the module-level ``device``) all use it.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Background-removal network (BRIA RMBG-1.4), used by ``process``.
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
# Loading straight onto ``device`` works regardless of which device the
# checkpoint was saved from, replacing the duplicated CUDA/CPU branches.
# NOTE(review): torch.load unpickles the downloaded file; consider passing
# ``weights_only=True`` (torch >= 1.13) since only a state dict is expected.
net.load_state_dict(torch.load(model_path, map_location=device))
net = net.to(device)
net.eval()

# Image-captioning model and its processor, used by ``getImageDescription``.
model = AutoModel.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True)
class StopOnTokens(StoppingCriteria):
    """Stop generation once the most recent token is one of ``stop_ids``.

    The default id ``151645`` is presumably the Qwen ``<|im_end|>`` token of
    the uform-gen2-qwen tokenizer — TODO confirm against ``processor.tokenizer``.
    """

    def __init__(self, stop_ids=(151645,)):
        super().__init__()
        # frozenset gives an O(1) membership test on every generation step.
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the first sequence of the batch is inspected, matching the
        # single-prompt usage in this app.
        return int(input_ids[0, -1]) in self.stop_ids
def format_prompt(message, history):
    """Build a Gemma chat-format prompt from prior turns plus a new message.

    Args:
        message: The new user turn.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, or a
            falsy value when there is no prior conversation.

    Returns:
        The prompt string. It ends with an open ``model`` turn so the model
        generates the next reply.
    """
    turns = []
    for user_prompt, bot_response in history or ():
        turns.append(f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n")
        # Every earlier model turn must be closed; previously it was left open.
        turns.append(f"<start_of_turn>model\n{bot_response}<end_of_turn>\n")
    # The generation cue is "model" (previously misspelled "modelo").
    turns.append(f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n")
    return "".join(turns)
def getProductTitle(history, context, image):
    """Describe the uploaded image, then ask the chat model for title options.

    Args:
        history: Current chatbot history (``(user, bot)`` pairs) or empty.
        context: Short free-text description typed by the user.
        image: The uploaded PIL image (``None`` if nothing was uploaded).

    Yields:
        The updated chat history, once.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload a product image first.")
    product_description = getImageDescription(image)
    context_text = (context or "").strip()
    # The previous string concatenation had no space after "a", producing
    # prompts such as "which is apainting".
    prompt = (
        f"We have a product which is a {context_text}. "
        f"Product description is as follows: {product_description}. "
        "Please write product title options for it."
    )
    yield interactWithModel(history, prompt)
def getProductDescription(history):
    """Request an SEO-friendly description of the product discussed in
    *history* and yield the resulting chat history once."""
    seo_request = (
        "Please also write an SEO friendly description for it "
        "describing its value to its users."
    )
    updated_history = interactWithModel(history, seo_request)
    yield updated_history
def interactWithModel(history, prompt):
    """Send *prompt* to the hosted Gemma model and append its reply.

    Args:
        history: Prior chat turns as ``(user, bot)`` pairs. ``None`` or
            empty starts a new conversation. The caller's list is not
            modified.
        prompt: The new user message.

    Returns:
        A new history list with ``(prompt, reply)`` appended.
    """
    system_prompt = "You're a helpful e-commerce marketing assistant working on art products."
    client = InferenceClient("google/gemma-2b-it")
    # Work on a copy so the caller's history is never mutated in place.
    history = list(history) if history else []
    generate_kwargs = dict(
        temperature=0.67,
        max_new_tokens=1024,
        top_p=0.9,
        repetition_penalty=1,
        do_sample=True,
        # Fresh seed per call so repeated prompts yield varied answers.
        seed=random.randint(1, 1111111111111111),
    )
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    # With details=True each chunk carries ``token.special``. Special tokens
    # (e.g. end-of-sequence markers) are dropped so they don't leak into the
    # chat reply.
    output = "".join(
        response.token.text for response in stream if not response.token.special
    )
    history.append((prompt, output))
    return history
@torch.no_grad()
def getImageDescription(image):
    """Generate a text description of *image* with the uform-gen2 model.

    Generation runs in a background thread and is consumed through a
    ``TextIteratorStreamer``. The accumulated text is returned after the
    stream is exhausted.

    Args:
        image: PIL image of the product (``None`` if nothing was uploaded).

    Returns:
        The generated description, with special tokens skipped by the streamer.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload a product image first.")
    message = "Generate an ecommerce product description for the image"
    gr.Info('Starting...' + message)

    # The conversation is always system + one user turn. The old
    # ``len(messages) == 1`` check was always true, so it was removed.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f" <image>{message}"},
    ]
    input_ids = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    pixel_values = processor.feature_extractor(image).unsqueeze(0)
    # The mask covers the text tokens plus ``num_image_latents`` positions.
    # The "- 1" presumably accounts for the single <image> token the latents
    # replace — TODO confirm.
    attention_mask = torch.ones(1, input_ids.shape[1] + processor.num_image_latents - 1)
    model_inputs = {
        "input_ids": input_ids,
        "images": pixel_values,
        "attention_mask": attention_mask,
    }
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    streamer = TextIteratorStreamer(processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    # NOTE(review): torch.no_grad() is thread-local, so the decorator above
    # does not apply inside this worker thread. This relies on model.generate
    # disabling gradients itself — confirm for this remote-code model.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    partial_response = "".join(streamer)
    # Make sure the generation thread has fully exited before returning.
    worker.join()
    gr.Info('Got:' + partial_response)
    return partial_response
def resize_image(image):
    """Return an RGB version of *image* resized to the 1024x1024 model input size."""
    target_size = (1024, 1024)
    return image.convert('RGB').resize(target_size, Image.BILINEAR)
def process(image):
    """Remove the background from *image* using the BriaRMBG network ``net``.

    The image is resized to the network input size, normalized, and run
    through ``net``. The predicted mask is upsampled back to the original
    size, min-max normalized, and used as the alpha channel.

    Args:
        image: PIL image uploaded by the user (``None`` if nothing was uploaded).

    Returns:
        A new RGBA PIL image, the same size as *image*, with the mask
        applied as transparency.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload a product image first.")
    orig_image = image
    w, h = orig_image.size

    # Prepare input: HxWx3 uint8 -> 1x3xHxW float, shifted to [-0.5, 0.5].
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    im_tensor = normalize(im_tensor / 255.0, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # Follow wherever the network's weights live instead of re-checking CUDA.
    im_tensor = im_tensor.to(next(net.parameters()).device)

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        result = net(im_tensor)

    # NOTE(review): assumes result[0][0] is the primary 1x1xH'xW' mask — confirm with BriaRMBG.
    mask = F.interpolate(result[0][0], size=(h, w), mode='bilinear')
    mask = torch.squeeze(mask, 0)
    mi, ma = torch.min(mask), torch.max(mask)
    # Clamp the range so a constant mask yields zeros instead of NaNs from 0/0.
    mask = (mask - mi) / torch.clamp(ma - mi, min=1e-8)

    im_array = (mask * 255).cpu().numpy().astype(np.uint8)
    pil_mask = Image.fromarray(np.squeeze(im_array))

    # Paste the original image onto a transparent canvas using the mask as alpha.
    new_im = Image.new("RGBA", pil_mask.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_mask)
    return new_im
title = """<h1 style="text-align: center;">Product description generator</h1>"""
css = """
div#col-container {
margin: 0 auto;
max-width: 840px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Row():
        # Left column: image upload, background-removed preview, context, trigger.
        with gr.Column(elem_id="col-container"):
            image = gr.Image(type="pil")
            output = gr.Image(type="pil", interactive=False, label="Without background")
            context = gr.Textbox(label="Small description")
            submit = gr.Button(value="Upload", variant="primary")
        # Right column: the marketing chat.
        with gr.Column():
            chat = gr.Chatbot(show_label=False)
            user_input = gr.Textbox()
            send = gr.Button(value="Send")

    # Each handler is (fn, inputs, outputs) for Gradio event listeners.
    title_handler = (getProductTitle, [chat, context, image], [chat])
    description_handler = (getProductDescription, [chat], [chat])
    interaction_handler = (interactWithModel, [chat, user_input], [chat])
    background_remover_handler = (process, [image], [output])

    # Request the SEO description only after the title step succeeds.
    # With ``.then`` it also ran after a failed title request.
    submit.click(*title_handler).success(*description_handler)
    submit.click(*background_remover_handler)
    send.click(*interaction_handler)
    # Pressing Enter in the textbox sends the message as well.
    user_input.submit(*interaction_handler)

demo.launch(share=True)