"""Gradio chat demo for FinLLaVA.

Parses command-line options, loads the model once at import time via
``load_pretrained_model``, and serves a multimodal ``gr.ChatInterface`` whose
replies are streamed from a background generation thread through a
``TextIteratorStreamer``.
"""
import time
from threading import Thread

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer, TextStreamer
import spaces
import argparse

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava_llama3.conversation import conv_templates, SeparatorStyle
from llava_llama3.utils import disable_torch_init
from llava_llama3.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from llava_llama3.serve.cli import chat_llava

import requests
from io import BytesIO
import base64
import os
import glob
import pandas as pd
from tqdm import tqdm
import json

# Directory containing this script. Gradio's temporary-file directory
# (GRADIO_TEMP_DIR) is pointed at it.
root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)
os.environ['GRADIO_TEMP_DIR'] = root_path

parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.01)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()

# The demo loads in 8-bit by default (presumably to fit the host GPU's memory --
# TODO confirm). Only apply that default when 4-bit loading was not explicitly
# requested; forcing it unconditionally made --load-4bit ineffective.
if not args.load_4bit:
    args.load_8bit = True

# Load model (tokenizer, model, image processor, and context length) once at startup.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path, None, 'llava_llama3', args.load_8bit, args.load_4bit, device=args.device)


def _find_image_file(message, history):
    """Return the image path to use for this turn, or None if there is none.

    Prefers the last file attached to the current message (either a dict with a
    ``"path"`` key or a plain path string). Otherwise falls back to the most
    recent history turn whose user part is a tuple, taking the tuple's first
    element as the file path.
    """
    files = message["files"]
    if files:
        last = files[-1]
        return last["path"] if isinstance(last, dict) else last
    # Walk history newest-first so the most recent upload wins.
    for hist in reversed(history):
        if isinstance(hist[0], tuple):
            return hist[0][0]
    return None


def bot_streaming(message, history):
    """Stream a FinLLaVA reply for a multimodal chat message.

    Args:
        message: Gradio multimodal message dict with a ``"text"`` entry and a
            ``"files"`` list.
        history: Prior chat turns as supplied by ``gr.ChatInterface``; used to
            find a previously uploaded image when the current message has none.

    Yields:
        The accumulated generated text so far, re-emitted after each new chunk.

    Raises:
        gr.Error: If no image is attached to the message or found in history.
    """
    print("triggered")
    print(message)

    image_file = _find_image_file(message, history)
    if image_file is None:
        # gr.Error only reaches the UI when raised; merely constructing it is a no-op.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    @spaces.GPU
    def generate():
        print('Running chat')
        try:
            return chat_llava(
                args=args,
                image_file=image_file,
                text=message['text'],
                tokenizer=tokenizer,
                model=llava_model,
                image_processor=image_processor,
                context_len=context_len,
                streamer=streamer)
        except Exception:
            # If generation fails before the streamer is finalized, the consumer
            # loop below would block forever waiting for more text; signal
            # end-of-stream so the request terminates, then propagate the error.
            streamer.end()
            raise

    # Generation runs in a background thread; this generator consumes the
    # streamer concurrently and pushes partial text to the UI.
    Thread(target=generate).start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Short pause between yields (presumably to throttle UI re-renders).
        time.sleep(0.06)
        print(buffer)
        yield buffer


chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA Demo",
        examples=[
            {"text": "What is in this picture?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
            {"text": "What is the spending on Healthcare in July? A. 450 B. 600 C. 520 D. 510", "files": ["image_107.png"]},
            {"text": "If 2012 net periodic opeb cost increased at the same pace as the pension cost, what would the estimated 2013 cost be in millions? A. 14.83333 B. 12.5 C. 15.5 D. 13.5", "files": ["image_659.png"]},
        ],
        description="",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)