Commit 1c84354 (verified)
Artples committed · 1 parent: 5bd9cae

Update app.py

Files changed (1):
  app.py: +36 -83
app.py CHANGED
@@ -5,7 +5,7 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -13,26 +13,21 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 DESCRIPTION = """\
 # L-MChat
-
 This Space demonstrates [L-MChat](https://huggingface.co/collections/Artples/l-mchat-663265a8351231c428318a8f) by L-AI.
-
 """
 
-
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU! This demo does not work on CPU.</p>"
 
-
-if torch.cuda.is_available():
-    model_id = "Artples/L-MChat-7b"
-    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.use_default_system_prompt = False
-
+model_options = {
+    "Fast-Model": "Artples/L-MChat-Small",
+    "Quality-Model": "Artples/L-MChat-7b"
+}
 
 @spaces.GPU(enable_queue=True, duration=90)
 def generate(
     message: str,
+    model_choice: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
     max_new_tokens: int = 1024,
@@ -41,6 +36,11 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
+    model_id = model_options[model_choice]
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.use_default_system_prompt = False
+
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -48,87 +48,40 @@
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
+    input_ids = tokenizer(conversation, return_tensors="pt", padding=True, truncation=True)
+    if input_ids['input_ids'].shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids['input_ids'] = input_ids['input_ids'][:, -MAX_INPUT_TOKEN_LENGTH:]
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
+    outputs = model.generate(
+        **input_ids,
+        max_length=input_ids['input_ids'].shape[1] + max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
+        num_return_sequences=1,
+        repetition_penalty=repetition_penalty
    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
 
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    yield generated_text
 
-
-chat_interface = gr.ChatInterface(
-    theme='ehristoforu/RE_Theme',
+chat_interface = gr.Interface(
    fn=generate,
-    additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6),
-        gr.Slider(
-            label="Max new tokens",
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        ),
-        gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.6,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
-    ],
-    stop_btn=None,
-    examples=[
-        ["Hello there! How are you doing?"],
-        ["Can you explain briefly to me what is the Python programming language?"],
-        ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    inputs=[
+        gr.Textbox(lines=2, placeholder="Type your message here..."),
+        gr.Dropdown(label="Choose Model", choices=list(model_options.keys())),
+        gr.State(label="Chat History", default=[]),
+        gr.Textbox(label="System Prompt", lines=6, placeholder="Enter system prompt if any..."),
+        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.1),
+        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
+    outputs=[gr.Textbox(label="Response")],
+    theme="default",
+    description=DESCRIPTION
 )
 
-with gr.Blocks(css="style.css") as demo:
-    gr.Markdown(DESCRIPTION)
-    chat_interface.render()
-
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    chat_interface.launch()
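
For reference only (not part of the commit), a minimal standalone sketch of the non-streaming generation path this change moves to: load one of the checkpoints listed in model_options, format a conversation, and produce a single reply. Unlike the committed code it uses tokenizer.apply_chat_template, since calling the tokenizer directly on a list of message dicts does not apply the model's chat template. The checkpoint names come from the diff above; the example messages and sampling values are assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint from the model_options mapping in the updated app.py.
model_id = "Artples/L-MChat-7b"  # or "Artples/L-MChat-Small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# Hypothetical conversation; in the Space these come from the UI inputs.
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello there! How are you doing?"},
]

# Format the messages with the model's chat template and build input ids.
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# One-shot generation with sampling settings similar to the Space defaults.
output_ids = model.generate(
    input_ids,
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.2,
)

# Decode only the newly generated tokens, not the prompt.
reply = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
print(reply)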