"""Gradio Space that streams answers from a fine-tuned multimodal model about uploaded coin images."""
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread

MODEL_ID = "./coin_model_funtuned"

# The CSS below styles <h1> (centered, block), so the title is rendered as an <h1> heading.
TITLE = "<h1>🚀 Coin Generative Recognition</h1>"

DESCRIPTION = """

A Space for Vision/Multimodal

✨ Tips: Send messages or upload multiple IMAGES at a time.
✨ Tips: Please increase MAX LENGTH when dealing with files.
🤙 Supported Format: png, jpg, webp
🙇‍♂️ May be rebuilding from time to time.

"""

CSS = """
h1 {
    text-align: center;
    display: block;
}
img {
    max-width: 100%; /* Make sure images are not wider than their container */
    height: auto; /* Maintain aspect ratio */
    max-height: 300px; /* Limit the height of images */
}
"""

# Lower-case file extensions accepted by mode_load. The leading dot matters:
# without it a name such as "scan.xpng" would pass the check.
SUPPORTED_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

# Token ids that stop generation.
# NOTE(review): presumably the model's end-of-text / role-switch special tokens;
# confirm against the tokenizer.
EOS_TOKEN_IDS = [151329, 151336, 151338]

# Load model directly (once, at import time) and move it to GPU 0.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).to(0)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model.eval()


def merge_images(paths):
    """Concatenate the images at *paths* left-to-right into a single RGB image.

    The canvas is as wide as the sum of all widths and as tall as the tallest
    image. Every image is pasted at the top edge, so shorter images leave the
    remaining area at ``Image.new``'s default fill.

    Args:
        paths: Non-empty sequence of image file paths.

    Returns:
        A new ``PIL.Image.Image`` in RGB mode.

    Raises:
        ValueError: If *paths* is empty.
    """
    if not paths:
        raise ValueError("merge_images() needs at least one image path")

    images = []
    for path in paths:
        # The context manager closes the underlying file; convert() returns a
        # new, fully loaded image that no longer depends on that file.
        with Image.open(path) as img:
            images.append(img.convert('RGB'))

    total_width = sum(im.width for im in images)
    max_height = max(im.height for im in images)
    merged = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for im in images:
        merged.paste(im, (x_offset, 0))
        x_offset += im.width
    return merged


def mode_load(paths):
    """Validate uploaded file paths and merge them into one image.

    Args:
        paths: Uploaded file paths.

    Returns:
        A ``(choice, content)`` tuple where *choice* is always ``"image"`` and
        *content* is the merged RGB image from :func:`merge_images`.

    Raises:
        gr.Error: If any path lacks a supported image extension.
    """
    if not all(path.lower().endswith(SUPPORTED_IMAGE_EXTENSIONS) for path in paths):
        raise gr.Error("Unsupported file types. Please upload only images.")
    return "image", merge_images(paths)


@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
    """Stream the model's reply to the uploaded images and prompt text.

    Args:
        message: Gradio multimodal message dict with ``"text"`` and ``"files"``.
        history: Previous turns supplied by Gradio. They are NOT forwarded to
            the model; each request is answered from the current message only.
        temperature: Sampling temperature. ``<= 0`` switches to greedy decoding.
        max_length: Maximum total sequence length (prompt plus generated tokens).
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        penalty: Repetition penalty; must be strictly positive.

    Yields:
        The cumulative generated text after each streamed chunk.

    Raises:
        gr.Error: If no files are attached or a file is not a supported image.
    """
    files = message.get("files") or []
    if not files:
        raise gr.Error("Please upload one or more images.")

    # A single upload takes the same path: merging one image yields a copy of it.
    _, image = mode_load(files)
    conversation = [{"role": "user", "image": image, "content": message['text']}]

    # With return_dict=True this is a dict of model inputs, not just input ids.
    inputs = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    # Gradio sliders may deliver floats, but generate() expects ints for lengths/top_k
    # and a float for repetition_penalty.
    generate_kwargs = dict(
        max_length=int(max_length),
        streamer=streamer,
        repetition_penalty=float(penalty),
        eos_token_id=EOS_TOKEN_IDS,
    )
    if float(temperature) > 0:
        generate_kwargs.update(
            do_sample=True,
            top_p=float(top_p),
            top_k=int(top_k),
            temperature=float(temperature),
        )
    else:
        # transformers rejects temperature <= 0 when sampling, but the slider allows 0:
        # treat it as a request for deterministic (greedy) decoding.
        generate_kwargs["do_sample"] = False

    # torch.no_grad() is thread-local, so wrapping this call would not affect the
    # worker thread; generate() manages grad mode itself.
    thread = Thread(target=model.generate, kwargs={**inputs, **generate_kwargs})
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    # The streamer is exhausted once generation ends, so this join returns promptly.
    thread.join()


chatbot = gr.Chatbot(label="Chatbox", height=600, placeholder=DESCRIPTION)

chat_input = gr.MultimodalTextbox(
    interactive=True,
    placeholder="Enter message or upload images...",
    show_label=False,
    file_count="multiple",
)

EXAMPLES = [
    [{"text": "Give me Country,Denomination and year as json format.", "files": ["./135_back.jpg", "./135_front.jpg"]}],
    [{"text": "Give me Country,Denomination and year as json format.", "files": ["./141_back.jpg", "./141_front.jpg"]}],
]

with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
    gr.HTML(TITLE)
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=1024,
                maximum=8192,
                step=1,
                value=4096,
                label="Max Length",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=10,
                label="top_k",
                render=False,
            ),
            # transformers requires a strictly positive repetition penalty, so 0.0 is excluded.
            gr.Slider(
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
    )
    gr.Examples(EXAMPLES, [chat_input])

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)