leonardlin committed
Commit f00ac1d • 1 Parent(s): 0e02ca5

swap models, examples, check for multigpu, example

Files changed (1):
  1. app.py +10 -78
app.py CHANGED
@@ -9,19 +9,17 @@ from threading import Thread
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 # Model
-model_name = "mistralai/Mistral-7B-Instruct-v0.1"
 model_name = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
-model_name = "/models/llm/hf/mistralai_Mistral-7B-Instruct-v0.1"
+model_name = "mistralai/Mistral-7B-Instruct-v0.1"
 
 # UI Settings
 title = "Shisa 7B"
 description = "Test out Shisa 7B in either English or Japanese."
 placeholder = "Type Here / ここに入力してください"
 examples = [
-    "Hello, how are you?",
-    "こんにちは、元気ですか？",
-    "おっす、元気？",
-    "こんにちは、いかがお過ごしですか？",
+    "What's the best ramen in Tokyo?",
+    "東京でおすすめのラーメン屋さんを教えていただけますか。",
+    "東京でおすすめのラーメン屋ってどこ？",
 ]
 
 # LLM Settings
@@ -39,7 +37,11 @@ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_
 
 def chat(message, history):
     chat_history.append({"role": "user", "content": message})
-    input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt").to('cuda')
+    input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt")
+    # for multi-gpu, find the device of the first parameter of the model
+    first_param_device = next(model.parameters()).device
+    input_ids = input_ids.to(first_param_device)
+
     generate_kwargs = dict(
         inputs=input_ids,
         streamer=streamer,
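A note on the multi-GPU fix in this hunk: when a model is loaded with `device_map="auto"`, its layers may be sharded across several GPUs, so a hardcoded `.to('cuda')` can put the prompt on the wrong device. Moving inputs to the device of the model's first parameter (typically the input embedding table, which is where `input_ids` must land) handles one GPU, many GPUs, or CPU alike. A minimal runnable sketch of the pattern; the loading code is not shown in this diff, so the `from_pretrained` arguments here are illustrative assumptions:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any causal LM works here; TinyLlama is just a small stand-in.
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # may shard layers across every visible GPU
)

chat_history = [{"role": "user", "content": "What's the best ramen in Tokyo?"}]
input_ids = tokenizer.apply_chat_template(
    chat_history, add_generation_prompt=True, return_tensors="pt"
)
# Send the prompt to wherever the first layer of the model lives,
# instead of assuming a single 'cuda' device.
input_ids = input_ids.to(next(model.parameters()).device)
output = model.generate(input_ids, max_new_tokens=200)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```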
@@ -57,13 +59,6 @@ def chat(message, history):
         partial_message += new_token # html.escape(new_token)
         yield partial_message
 
-    '''
-    # https://www.gradio.app/main/guides/creating-a-chatbot-fast#streaming-chatbots
-    for i in range(len(message)):
-        time.sleep(0.3)
-        yield message[: i+1]
-    '''
-
 
 chat_interface = gr.ChatInterface(
     chat,
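For context on the streaming loop kept above: `TextIteratorStreamer` turns `model.generate` into an iterator, with generation running in a background thread, which is what lets the Gradio handler yield a growing partial message. A self-contained sketch of the pattern, assuming `model` and `tokenizer` objects like those in this app (the function name and the 200-token cap are illustrative):

```python
from threading import Thread
from transformers import TextIteratorStreamer

def stream_chat(message, history, model, tokenizer):
    """Yield a growing partial reply, token by token."""
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(next(model.parameters()).device)

    # skip_prompt=True keeps the echoed prompt out of the stream.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(inputs=input_ids, streamer=streamer, max_new_tokens=200)
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message  # Gradio re-renders the message on each yield
```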
@@ -81,69 +76,6 @@ chat_interface = gr.ChatInterface(
 # https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219 - we use this with construction b/c Gradio barfs on autoreload otherwise
 with gr.Blocks() as demo:
     chat_interface.render()
-    gr.Markdown("You can try these greetings in English, Japanese, familiar Japanese, or formal Japanese. We limit output to 200 tokens.")
-
+    gr.Markdown("You can try asking this question in English, formal Japanese, and informal Japanese. You might need to ask it to reply informally with something like もっと友達みたいに話そうよ。あんまり堅苦しくなくて。to get informal replies. We limit output to 200 tokens.")
 
 demo.queue().launch()
-
-'''
-# Works for Text input...
-demo = gr.Interface.from_pipeline(pipe)
-'''
-
-'''
-def chat(message, history):
-    print("foo")
-    for i in range(len(message)):
-        time.sleep(0.3)
-        yield "You typed: " + message[: i+1]
-    # print('history:', history)
-    # print('message:', message)
-    # for new_next in streamer:
-    #     yield new_text
-
-
-'''
-
-
-'''
-# Docs: https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/conversational.py
-conversation = Conversation()
-conversation.add_message({"role": "system", "content": system})
-device = torch.device('cuda')
-pipe = pipeline(
-    'conversational',
-    model=model,
-    tokenizer=tokenizer,
-    streamer=streamer,
-
-)
-
-def chat(input, history):
-    conversation.add_message({"role": "user", "content": input})
-    # we do this shuffle so local shadow response doesn't get created
-    response_conversation = pipe(conversation)
-    print("foo:", response_conversation.messages[-1]["content"])
-
-    conversation.add_message(response_conversation.messages[-1])
-    print("boo:", response_conversation.messages[-1]["content"])
-    response = conversation.messages[-1]["content"]
-    response = "ping"
-    return response
-
-demo = gr.ChatInterface(
-    chat,
-    chatbot=gr.Chatbot(height=400),
-    textbox=gr.Textbox(placeholder=placeholder, container=False, scale=7),
-    title=title,
-    description=description,
-    theme="soft",
-    examples=examples,
-    cache_examples=False,
-    undo_btn="Delete Previous",
-    clear_btn="Clear",
-).launch()
-
-# For async
-# ).queue().launch()
-'''
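On the `with gr.Blocks()` construction referenced in the comment above (kept because, per that comment, Gradio barfs on autoreload with a bare `gr.ChatInterface`): rendering the prebuilt interface inside a `Blocks` context also makes room for extra components such as the usage-hint Markdown added in this commit. A minimal sketch of the construction with a stub handler:

```python
import gradio as gr

def chat(message, history):
    # Stub handler; the real app streams tokens from the model.
    yield f"You said: {message}"

chat_interface = gr.ChatInterface(
    chat,
    examples=["What's the best ramen in Tokyo?"],
    cache_examples=False,
)

# Render the prebuilt ChatInterface inside a Blocks app so extra
# components can sit alongside it.
with gr.Blocks() as demo:
    chat_interface.render()
    gr.Markdown("We limit output to 200 tokens.")

demo.queue().launch()
```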
 