|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
import PIL |
|
import random |
|
from threading import Thread |
|
from transformers import AutoModel, AutoProcessor |
|
from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList |
|
from torchvision.transforms.functional import normalize |
|
from huggingface_hub import hf_hub_download, InferenceClient |
|
from briarmbg import BriaRMBG |
|
from PIL import Image |
|
from typing import Tuple |
|
|
|
|
|
# --- Background-removal model (BRIA RMBG 1.4) setup ---
# Instantiate the RMBG network and load pretrained weights from the HF Hub.
net=BriaRMBG()

# Download (or reuse the local cache of) the checkpoint file.
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net=net.cuda()
else:
    # No GPU: force the checkpoint tensors onto the CPU.
    net.load_state_dict(torch.load(model_path,map_location="cpu"))
# Inference only — disables dropout / batch-norm running-stat updates.
net.eval()
|
|
|
|
|
# --- Image-captioning model (UForm-Gen2 Qwen-500m) setup ---
# Device string used for the captioning model (the RMBG net above manages
# its own device placement separately).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# trust_remote_code is required: UForm-Gen2 ships custom modeling code.
model = AutoModel.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True)
|
|
|
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the end-of-turn token (id 151645,
    Qwen's <|im_end|>) is emitted as the most recent token."""

    _STOP_IDS = (151645,)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_token = input_ids[0][-1]
        return any(last_token == stop_id for stop_id in self._STOP_IDS)
|
|
|
def format_prompt(message, history):
    """Build a Gemma-style chat prompt from history plus a new message.

    Parameters
    ----------
    message : str
        The new user message to append as the final turn.
    history : list[tuple[str, str]] | None
        Prior (user, assistant) exchanges; falsy means no history.

    Returns
    -------
    str
        Concatenated prompt using Gemma's <start_of_turn> markers, ending
        with an open model turn for the LLM to complete.
    """
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
            prompt += f"<start_of_turn>model{bot_response}"
    # Bug fix: the final turn marker was "<start_of_turn>modelo" (typo) —
    # Gemma's assistant turn opens with "<start_of_turn>model".
    prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
    return prompt
|
|
|
def getProductTitle(history, image):
    """Caption `image`, then ask the LLM for product title options.

    Yields the updated chat history once (Gradio generator convention).
    """
    product_description = getImageDescription(image)
    prompt = ("Our product description is as follows: " + product_description
              + ". Please write a product title options for it.")
    # Bug fix: the original called interactWithModel(history, system_prompt,
    # prompt) — `system_prompt` is undefined at this scope (NameError) and
    # interactWithModel() only accepts (history, prompt); it supplies its
    # own system prompt internally.
    yield interactWithModel(history, prompt)
|
|
|
def getProductDescription(history):
    """Ask the LLM for an SEO-friendly description of the product already
    discussed in `history`.

    Yields the updated chat history once (Gradio generator convention).
    """
    prompt = "Please also write an SEO friendly description for it describing its value to its users."
    # Bug fix: the original called interactWithModel(history, system_prompt,
    # prompt) — `system_prompt` is undefined here (NameError) and
    # interactWithModel() only accepts (history, prompt).
    yield interactWithModel(history, prompt)
|
|
|
|
|
def interactWithModel(history, prompt):
    """Send `prompt` (with the chat history) to Gemma-7B-IT through the HF
    Inference API and return the history extended with the full streamed
    reply as a new (prompt, response) pair.
    """
    system_prompt = "You're a helpful e-commerce marketing assitant working on art products."
    client = InferenceClient("google/gemma-7b-it")
    history = history or []
    generation_options = {
        "temperature": 0.67,
        "max_new_tokens": 1024,
        "top_p": 0.9,
        "repetition_penalty": 1,
        "do_sample": True,
        # Fresh seed per call so retries produce different samples.
        "seed": random.randint(1, 1111111111111111),
    }
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(
        formatted_prompt, **generation_options,
        stream=True, details=True, return_full_text=False,
    )
    # Accumulate streamed token texts into the final response.
    pieces = []
    for chunk in stream:
        pieces.append(chunk.token.text)
    history.append((prompt, "".join(pieces)))
    return history
|
|
|
@torch.no_grad()
def getImageDescription(image):
    """Generate an e-commerce product description for `image` with the
    UForm-Gen2 captioning model, streaming tokens from a worker thread.

    Returns the full generated description as a string.
    """
    message = "Generate an ecommerce product description for the image"
    gr.Info('Starting...' + message)
    stop = StopOnTokens()
    messages = [{"role": "system", "content": "You are a helpful assistant."}]

    # Prepend the <image> placeholder on the first user turn so the model
    # knows where the image latents belong.
    # NOTE(review): `messages` was created one line above with exactly one
    # element, so this condition is always true here.
    if len(messages) == 1:
        message = f" <image>{message}"

    messages.append({"role": "user", "content": message})

    # Tokenize the chat into input ids (shape assumed [1, seq_len]).
    model_inputs = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    )

    # Preprocess the image; unsqueeze adds the batch dimension.
    image = (
        processor.feature_extractor(image)
        .unsqueeze(0)
    )

    # Mask covers the text tokens plus the image latents, minus one —
    # presumably accounting for the <image> placeholder token that the
    # latents replace (TODO confirm against UForm-Gen2's modeling code).
    attention_mask = torch.ones(
        1, model_inputs.shape[1] + processor.num_image_latents - 1
    )

    model_inputs = {
        "input_ids": model_inputs,
        "images": image,
        "attention_mask": attention_mask
    }

    # Move every input tensor onto the captioning model's device.
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    # generate() runs on a background thread; the streamer yields decoded
    # tokens to this thread as they are produced (30s per-token timeout).
    streamer = TextIteratorStreamer(processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Drain the streamer until generation finishes (or StopOnTokens fires).
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token

    gr.Info('Got:' + partial_response)
    return partial_response
|
|
|
def resize_image(image):
    """Return `image` converted to RGB and bilinearly resized to the
    fixed 1024x1024 input resolution the RMBG model expects."""
    rgb = image.convert('RGB')
    return rgb.resize((1024, 1024), Image.BILINEAR)
|
|
|
|
|
def process(image):
    """Remove the background from a PIL image with the RMBG net.

    Returns an RGBA image: the original pixels pasted onto a transparent
    canvas, using the predicted foreground mask as the alpha channel.
    """
    orig_image = image
    w,h = orig_im_size = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    # HWC uint8 -> NCHW float in [0, 1], then shifted by mean 0.5 / std 1.0
    # (the normalization the RMBG-1.4 checkpoint was trained with).
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
    im_tensor = torch.unsqueeze(im_tensor,0)
    im_tensor = torch.divide(im_tensor,255.0)
    im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
    if torch.cuda.is_available():
        im_tensor=im_tensor.cuda()

    result=net(im_tensor)

    # Upsample the mask back to the original (h, w) size.
    # NOTE(review): assumes net() returns nested outputs with the primary
    # mask at result[0][0] — verify against BriaRMBG's forward().
    result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
    # Min-max normalize the mask into [0, 1].
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result-mi)/(ma-mi)

    # Convert to a single-channel uint8 PIL mask.
    im_array = (result*255).cpu().data.numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # Composite the original over transparency using the predicted alpha.
    new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
    new_im.paste(orig_image, mask=pil_im)

    return new_im
|
|
|
|
|
# Page header HTML and CSS that centers the main column at max 840px wide.
title = """<h1 style="text-align: center;">Product description generator</h1>"""
css = """
div#col-container {
    margin: 0 auto;
    max-width: 840px;
}
"""
|
|
|
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column(elem_id="col-container"):
            # Left column: image upload, background-removed preview, submit.
            image = gr.Image(type="pil")
            output = gr.Image(type="pil", interactive=False)
            submit = gr.Button(value="Upload", variant="primary")
        with gr.Column():
            # Right column: chat transcript plus free-form follow-up input.
            chat = gr.Chatbot(show_label=False)
            user_input= gr.Textbox()
            send = gr.Button(value="Send")

    # (fn, inputs, outputs) tuples, unpacked into the event bindings below.
    title_handler = (
        getProductTitle,
        [chat, image],
        [chat]
    )

    description_handler = (
        getProductDescription,
        [chat],
        [chat]
    )

    # Bug fix: Send previously bound getProductDescription, which takes a
    # single `history` argument, to the two inputs [chat, user_input] — a
    # TypeError at click time. interactWithModel(history, prompt) matches
    # these inputs and is the intended free-form chat entry point.
    interaction_handler = (
        interactWithModel,
        [chat, user_input],
        [chat]
    )

    background_remover_handler = (
        process,
        [image],
        [output]
    )

    # Upload: caption + titles first, then the description; background
    # removal is bound to the same click and runs independently.
    submit.click(*title_handler).then(*description_handler)
    submit.click(*background_remover_handler)
    send.click(*interaction_handler)

demo.launch()