TobyYang7 committed
Commit: ee668ff
Parent(s): 6a83074

Update app.py

Files changed (1):
  1. app.py +96 -48
app.py CHANGED
@@ -1,57 +1,105 @@
- import gradio as gr
- from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, TextIteratorStreamer
  from threading import Thread
- import re
- import time
- from PIL import Image
  import torch
  import spaces

- processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

- model = LlavaNextForConditionalGeneration.from_pretrained("TheFinAI/FinLLaVA", torch_dtype=torch.float16, low_cpu_mem_usage=True)
  model.to("cuda:0")

  @spaces.GPU
  def bot_streaming(message, history):
-     print(message)
-     if message["files"]:
-         image = message["files"][-1]["path"]
-     else:
-         # if there's no image uploaded for this turn, look for images in the past turns
-         # kept inside tuples, take the last one
-         for hist in history:
-             if type(hist[0]) == tuple:
-                 image = hist[0][0]
-
-     if image is None:
-         gr.Error("You need to upload an image for LLaVA to work.")
-     prompt = f"[INST] <image>\n{message['text']} [/INST]"
-     image = Image.open(image).convert("RGB")
-     inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
-
-     streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": True})
-     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=100)
-     generated_text = ""
-
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     text_prompt = f"[INST] \n{message['text']} [/INST]"
-
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         generated_text_without_prompt = buffer[len(text_prompt):]
-         time.sleep(0.04)
-         yield generated_text_without_prompt
-
-
- demo = gr.ChatInterface(fn=bot_streaming, title="LLaVA NeXT",
-                         examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
-                                   {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
-                         description="Try [LLaVA NeXT](https://huggingface.co/docs/transformers/main/en/model_doc/llava_next) in this demo (more specifically, the [Mistral-7B variant](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
-                         stop_btn="Stop Generation", multimodal=True)
- demo.launch(debug=True)
+ import time
  from threading import Thread
+
+ import gradio as gr
  import torch
+ from PIL import Image
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
+ from transformers import TextIteratorStreamer
+
  import spaces


+ PLACEHOLDER = """
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+     <img src="https://cdn-uploads.huggingface.co/production/uploads/64ccdc322e592905f922a06e/DDIW0kbWmdOQWwy4XMhwX.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
+     <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA-Llama-3-8B</h1>
+     <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Llava-Llama-3-8b is a LLaVA model fine-tuned from Meta-Llama-3-8B-Instruct and CLIP-ViT-Large-patch14-336 with ShareGPT4V-PT and InternVL-SFT by XTuner</p>
+ </div>
+ """
+
+
+ model_id = "TheFinAI/FinLLaVA"
+
+ processor = AutoProcessor.from_pretrained(model_id)
+
+ model = LlavaForConditionalGeneration.from_pretrained(
+     model_id,
+     torch_dtype=torch.float16,
+     low_cpu_mem_usage=True,
+ )
+
  model.to("cuda:0")
+ model.generation_config.eos_token_id = 128009  # <|eot_id|>, Llama-3's end-of-turn token
+

  @spaces.GPU
  def bot_streaming(message, history):
+     print(message)
+     if message["files"]:
+         # message["files"][-1] is either a dict or a plain file-path string
+         if isinstance(message["files"][-1], dict):
+             image = message["files"][-1]["path"]
+         else:
+             image = message["files"][-1]
+     else:
+         # if there's no image uploaded for this turn, look for images in the
+         # past turns, kept inside tuples, and take the last one
+         for hist in history:
+             if isinstance(hist[0], tuple):
+                 image = hist[0][0]
+     try:
+         if image is None:
+             # raising gr.Error surfaces the message in the Gradio UI
+             raise gr.Error("You need to upload an image for LLaVA to work.")
+     except NameError:
+         # 'image' was never assigned: no image in this turn or any earlier one
+         raise gr.Error("You need to upload an image for LLaVA to work.")
+
+     # Llama-3 chat format: user turn with the <image> placeholder, then the
+     # assistant header to start the reply
+     prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+     image = Image.open(image)
+     inputs = processor(prompt, image, return_tensors="pt").to(0, torch.float16)
+
+     streamer = TextIteratorStreamer(processor, skip_special_tokens=False, skip_prompt=True)
+     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
+
+     # generate() blocks, so run it in a background thread and stream from here
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     buffer = ""
+     time.sleep(0.5)  # give the generation thread a moment to produce its first tokens
+     for new_text in streamer:
+         # strip <|eot_id|> and anything after it from the chunk
+         if "<|eot_id|>" in new_text:
+             new_text = new_text.split("<|eot_id|>")[0]
+         buffer += new_text
+         time.sleep(0.06)  # small delay to smooth the streaming cadence
+         yield buffer
+
+
+ chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1)
+ chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"],
+                                   placeholder="Enter message or upload file...", show_label=False)
+ with gr.Blocks(fill_height=True) as demo:
+     gr.ChatInterface(
+         fn=bot_streaming,
+         title="LLaVA Llama-3-8B",
+         examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
+                   {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
+         description="Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
+         stop_btn="Stop Generation",
+         multimodal=True,
+         textbox=chat_input,
+         chatbot=chatbot,
+     )
+
+ demo.queue(api_open=False)
+ demo.launch(show_api=False, share=False)
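
For reference, the token-streaming pattern the new bot_streaming builds on can be exercised in isolation. The following is a minimal sketch, assuming a small text-only stand-in model ("gpt2", chosen only so the snippet runs without a GPU) in place of FinLLaVA and its multimodal processor:

# Minimal sketch of the TextIteratorStreamer pattern used in app.py above.
# "gpt2" is a stand-in (an assumption); the real app passes the LLaVA
# processor and image+text inputs instead.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Streaming generation works by", return_tensors="pt")

# skip_prompt=True makes the streamer emit only newly generated text,
# matching the skip_prompt=True used in bot_streaming
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs in a background thread
# while the main thread consumes decoded chunks as they arrive
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=30))
thread.start()

buffer = ""
for new_text in streamer:  # iteration stops when generation finishes
    buffer += new_text
    print(buffer, flush=True)  # in the app this is `yield buffer` into Gradio

thread.join()

The app differs only on the output side: its streamer is built with skip_special_tokens=False, so the loop trims <|eot_id|> from each chunk by hand before yielding the accumulated buffer. The hand-written prompt string follows the Llama-3 chat format; if the FinLLaVA tokenizer ships a chat template (not verified here), an equivalent prompt could also be produced with:

# Sketch, assuming processor.tokenizer carries the Llama-3 chat template
messages = [{"role": "user", "content": "<image>\nWhat is on the flower?"}]
prompt = processor.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)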