import gradio as gr
import torch
import numpy as np
import torch.nn.functional as F
import PIL
import random
from threading import Thread
from transformers import AutoModel, AutoProcessor
from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download, InferenceClient
from briarmbg import BriaRMBG
from PIL import Image
from typing import Tuple

# --- Background-removal model (briaai/RMBG-1.4) -----------------------------
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
else:
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()

# --- Image-captioning model (unum-cloud/uform-gen2-qwen-500m) ---------------
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(
    "unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True
).to(device)
processor = AutoProcessor.from_pretrained(
    "unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True
)


class StopOnTokens(StoppingCriteria):
    """Stopping criterion: stop when the last generated token is a stop id."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # NOTE(review): 151645 looks like the Qwen "<|im_end|>" id used by the
        # uform-gen2-qwen tokenizer -- confirm against processor.tokenizer.
        stop_ids = [151645]
        return any(input_ids[0][-1] == stop_id for stop_id in stop_ids)


def format_prompt(message, history):
    """Render *history* plus the new *message* in Gemma's chat-turn format.

    Args:
        message: the new user message.
        history: iterable of ``(user_prompt, bot_response)`` pairs, or a falsy
            value for a fresh conversation.

    Returns:
        The prompt string, ending with an open ``model`` turn so the model
        continues as the assistant.
    """
    prompt = ""
    for user_prompt, bot_response in history or []:
        prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
        prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt


def getProductTitle(history, context, image):
    """Ask the LLM for product-title options (Gradio generator handler).

    Args:
        history: current chatbot history (list of pairs) or ``None``.
        context: short free-text description typed by the user; may be empty.
        image: PIL image of the product, or ``None`` if none was uploaded.

    Yields:
        The updated chat history once the LLM reply is complete.
    """
    context = (context or "").strip()
    prompt = f"We have a product which is a {context}." if context else "We have a product."
    if image is not None:
        # Only caption when an image exists; the captioner raises on None.
        product_description = getImageDescription(image)
        prompt += f" Product description is as follows: {product_description}."
    prompt += " Please write some product title options for it."
    yield interactWithModel(history, prompt)


def getProductDescription(history):
    """Ask the LLM, as a follow-up turn, for an SEO-friendly description.

    Yields:
        The updated chat history once the LLM reply is complete.
    """
    prompt = "Please also write an SEO friendly description for it describing its value to its users."
    yield interactWithModel(history, prompt)


def interactWithModel(history, prompt):
    """Send *prompt* to google/gemma-7b-it and append the reply to *history*.

    The system prompt is prepended to the user message because Gemma's chat
    format has no separate system role. The streamed reply is accumulated
    with special tokens (e.g. end-of-sequence) filtered out.

    Args:
        history: chat history as ``(user, bot)`` pairs, or a falsy value.
        prompt: the new user message.

    Returns:
        The history list with ``(prompt, reply)`` appended.
    """
    system_prompt = "You're a helpful e-commerce marketing assistant working on art products."
    client = InferenceClient("google/gemma-7b-it")
    rand_val = random.randint(1, 1111111111111111)
    if not history:
        history = []
    generate_kwargs = dict(
        temperature=0.67,
        max_new_tokens=1024,
        top_p=0.9,
        repetition_penalty=1,
        do_sample=True,
        seed=rand_val,
    )
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    # With details=True each chunk carries token metadata; drop special tokens
    # so markers such as <eos> do not leak into the visible reply.
    output = "".join(
        response.token.text for response in stream if not response.token.special
    )
    history.append((prompt, output))
    return history


@torch.no_grad()
def getImageDescription(image):
    """Caption *image* with UForm-Gen2 and return the generated text.

    Generation runs on a background thread; its tokens are collected from a
    ``TextIteratorStreamer`` (30 s per-token timeout) until it is exhausted.
    """
    stop = StopOnTokens()
    # The UForm-Gen2 chat format marks the image position with an "<image>"
    # placeholder in the first user turn. The "- 1" in the attention-mask
    # length below presumably accounts for that single token being replaced by
    # ``num_image_latents`` image embeddings -- confirm against the model code.
    message = " <image>Generate an ecommerce product description for the image"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": message},
    ]
    input_ids = processor.tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    pixel_values = processor.feature_extractor(image).unsqueeze(0)
    attention_mask = torch.ones(
        1, input_ids.shape[1] + processor.num_image_latents - 1
    )
    model_inputs = {
        "input_ids": input_ids,
        "images": pixel_values,
        "attention_mask": attention_mask,
    }
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    # Grad mode is thread-local, so the decorator above does not reach this
    # worker thread; transformers' generate() disables grad on its own.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    response = "".join(streamer)
    worker.join()
    return response


def resize_image(image):
    """Convert *image* to RGB and resize it to the 1024x1024 RMBG input size."""
    image = image.convert('RGB')
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image


@torch.no_grad()
def process(image):
    """Remove the background from *image* using RMBG-1.4.

    Args:
        image: PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A new RGBA image the size of the input whose alpha channel is the
        min-max-normalized predicted mask, or ``None`` for ``None`` input.
    """
    if image is None:
        return None

    # prepare input
    orig_image = image
    w, h = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # inference (no_grad: avoid building an autograd graph for a forward pass)
    result = net(im_tensor)

    # post process: upsample the first output back to the original size.
    # NOTE(review): assumes result[0][0] is a (1, 1, H, W) mask -- confirm
    # against BriaRMBG's forward().
    result = torch.squeeze(
        F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0
    )
    ma = torch.max(result)
    mi = torch.min(result)
    if ma > mi:
        result = (result - mi) / (ma - mi)
    else:
        # Constant prediction: avoid 0/0 -> NaN before the uint8 cast.
        result = torch.zeros_like(result)

    # image to pil
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # paste the original image onto a transparent canvas through the mask
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return new_im


title = """

Product description generator

"""

css = """
div#col-container {
    margin: 0 auto;
    max-width: 840px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column(elem_id="col-container"):
            image = gr.Image(type="pil")
            output = gr.Image(type="pil", interactive=False, label="Without background")
            context = gr.Textbox(label="Small description")
            submit = gr.Button(value="Upload", variant="primary")
        with gr.Column():
            chat = gr.Chatbot(show_label=False)
            user_input = gr.Textbox()
            send = gr.Button(value="Send")

    # (fn, inputs, outputs) triples unpacked into Gradio event listeners.
    title_handler = (getProductTitle, [chat, context, image], [chat])
    description_handler = (getProductDescription, [chat], [chat])
    interaction_handler = (interactWithModel, [chat, user_input], [chat])
    background_remover_handler = (process, [image], [output])

    # Upload: title first, then the description as a follow-up turn; the
    # background removal runs as an independent listener on the same click.
    submit.click(*title_handler).then(*description_handler)
    submit.click(*background_remover_handler)
    send.click(*interaction_handler)

demo.launch(share=True)