Ivy1997 committed on
Commit de6fe0d · verified · 1 Parent(s): 67a4289

Update app.py

Files changed (1)
  1. app.py +58 -25
app.py CHANGED
@@ -1,10 +1,24 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import process_images, tokenizer_image_token
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from llava.conversation import conv_templates
+from PIL import Image
+import copy
+import torch
+import warnings
+import requests
 
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("AI-Safeguard/Ivy-VL-llava")
+warnings.filterwarnings("ignore")
+
+pretrained = "AI-Safeguard/Ivy-VL-llava"
+model_name = "llava_qwen"
+device = "cuda"
+device_map = "auto"
+
+# Load model, tokenizer, and image processor
+tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map)
+model.eval()
 
 def respond(
     message,
@@ -23,32 +37,51 @@ def respond(
         if val[1]:
             messages.append({"role": "assistant", "content": val[1]})
 
-    messages.append({"role": "user", "content": message})
+    if image:
+        # Process image if provided
+        image_tensor = process_images([image], image_processor, model.config)
+        image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]
+
+        conv_template = "qwen_1_5"
+        question = DEFAULT_IMAGE_TOKEN + "\n" + message
+        conv = copy.deepcopy(conv_templates[conv_template])
+        conv.append_message(conv.roles[0], question)
+        conv.append_message(conv.roles[1], None)
+        prompt_question = conv.get_prompt()
+
+        input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
+        image_sizes = [image.size]
 
-    payload = {
-        "messages": messages,
-        "max_tokens": max_tokens,
-        "temperature": temperature,
-        "top_p": top_p,
-    }
+        cont = model.generate(
+            input_ids,
+            images=image_tensor,
+            image_sizes=image_sizes,
+            do_sample=False,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+        )
 
-    if image is not None:
-        payload["image"] = image
+        response = tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
+    else:
+        messages.append({"role": "user", "content": message})
 
-    response = ""
+        payload = {
+            "messages": messages,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+        }
 
-    for message in client.chat_completion(
-        payload,
-        stream=True,
-    ):
-        token = message.choices[0].delta.content
+        response = ""
+        for message in client.chat_completion(
+            payload,
+            stream=True,
+        ):
+            token = message.choices[0].delta.content
+            response += token
 
-        response += token
-        yield response
+    yield response
 
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
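Note on the text-only path: this commit removes the `InferenceClient` import and the `client = InferenceClient(...)` line, yet the new `else` branch still calls `client.chat_completion(...)`, so a message sent without an image will raise a `NameError`. A minimal sketch of routing that branch through the locally loaded model instead; `respond_text_only` is a hypothetical helper, and it assumes the tokenizer exposes the standard transformers `apply_chat_template` API (true for Qwen-family tokenizers):

def respond_text_only(messages, max_tokens):
    # Build a chat-formatted prompt from the accumulated message list
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output = model.generate(
        inputs,
        do_sample=False,  # mirrors the image branch in this commit
        max_new_tokens=max_tokens,
    )
    # Decode only the newly generated tokens, not the echoed prompt
    return tokenizer.batch_decode(
        output[:, inputs.shape[1]:], skip_special_tokens=True
    )[0]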
 
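The diff context cuts off at `additional_inputs=[`, so the component list itself is not shown. For the new `respond` to receive `image`, `max_tokens`, `temperature`, and `top_p`, the interface must supply matching components in the same order as the function's parameters after `(message, history)`. A plausible wiring, offered purely as an assumption about the unshown lines (component choices and defaults are illustrative, not taken from the commit):

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        # PIL image, as both process_images(...) and image.size in the new branch expect
        gr.Image(type="pil", label="Image (optional)"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)

Note that after this change `top_p` is only consumed by the (broken) text-only branch; the local `model.generate` call in the image branch ignores it.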