"""Gradio chat demo that streams answers from a LLaVA-style model about an image.

The script parses command-line options, loads the model once at import time,
and serves a multimodal chat interface whose replies are produced by
``chat_llava`` on a background thread and streamed to the browser.
"""

import argparse
import os
import queue
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.serve.cli import chat_llava

root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)

parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()

# Load the model once; every chat request reuses these objects.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    None,
    'llava_llama3',
    args.load_8bit,
    args.load_4bit,
    device=args.device
)

# Unique marker the worker thread enqueues when generation has finished.
_END_OF_STREAM = object()


def _resolve_image_path(message, history):
    """Return the image to use for this turn, or ``None`` if there is none.

    An image attached to the current message wins; otherwise the most recent
    image found in the chat history is reused.

    Args:
        message: Mapping with an optional ``"files"`` list whose entries are
            either path strings or dicts holding a ``"path"`` key.
        history: Sequence of past turns; a turn whose first element is a
            non-empty tuple is treated as an image upload and the tuple's
            first item is taken as its path.
    """
    files = message.get("files")
    if files:
        latest = files[-1]
        # The last file may be a dictionary or a plain path string.
        return latest["path"] if isinstance(latest, dict) else latest

    # Walk backwards so the first hit is the most recently uploaded image.
    for turn in reversed(history or []):
        user_part = turn[0]
        if isinstance(user_part, tuple) and user_part:
            return user_part[0]
    return None


@spaces.GPU
def bot_streaming(message, history):
    """Stream the model's reply to ``message`` as a growing string.

    Args:
        message: Mapping with the user's ``"text"`` and optional ``"files"``.
        history: Previous chat turns, searched for an image when the current
            message carries none.

    Yields:
        The accumulated response text after each newly generated piece.

    Raises:
        gr.Error: If no image is available, or if generation fails in the
            worker thread.
    """
    print(message)

    image_path = _resolve_image_path(message, history)
    if image_path is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # The path is passed on as-is; it is not loaded into a PIL image here.
    print(f"\033[91m{image_path}, {type(image_path)}\033[0m")

    prompt = message['text']

    # Hand-off channel between the generation thread and this generator.
    chunks = queue.Queue()

    def generate_output():
        """Run chat_llava and push each produced piece onto the queue."""
        try:
            output = chat_llava(
                args=args,
                image_file=image_path,
                text=prompt,
                tokenizer=tokenizer,
                model=llava_model,
                image_processor=image_processor,
                context_len=context_len
            )
            for new_text in output:
                chunks.put(new_text)
        except Exception as exc:
            # Thread boundary: forward the failure so the consumer can report
            # it instead of the reply silently ending early.
            chunks.put(exc)
        finally:
            # Always unblock the consumer, on success and on failure.
            chunks.put(_END_OF_STREAM)

    thread = Thread(target=generate_output)
    thread.start()

    buffer = ""
    while True:
        # Blocks until the worker produces something; no polling delay.
        item = chunks.get()
        if item is _END_OF_STREAM:
            break
        if isinstance(item, Exception):
            raise gr.Error(f"Generation failed: {item}") from item
        buffer += item
        yield buffer

    thread.join()


chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA",
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
            {"text": "What is this?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)