ysharma HF Staff committed on
Commit
d2f200e
·
1 Parent(s): f3a3675

updated layout and added theme

Browse files
Files changed (1) hide show
  1. app.py +18 -21
app.py CHANGED
@@ -12,13 +12,6 @@ tok = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v
12
  m = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
13
  m = m.to('cuda:0')
14
 
15
- start_message = """<|SYSTEM|># RedPajamaAssistant
16
- - RedPajamaAssistant is A helpful and harmless Open Source AI Language Model developed by Stability and CarperAI.
17
- - RedPajamaAssistant is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
18
- - RedPajamaAssistant is more than just an information source, RedPajamaAssistant is also able to write poetry, short stories, and make jokes.
19
- - RedPajamaAssistant will refuse to participate in anything that could harm a human."""
20
-
21
-
22
  class StopOnTokens(StoppingCriteria):
23
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
24
  #stop_ids = [[29, 13961, 31], [29, 12042, 31], 1, 0]
@@ -36,7 +29,7 @@ def user(message, history):
36
 
37
 
38
 
39
- def chat(curr_system_message, history):
40
  # Initialize a StopOnTokens object
41
  stop = StopOnTokens()
42
 
@@ -53,32 +46,32 @@ def chat(curr_system_message, history):
53
  streamer=streamer,
54
  max_new_tokens=1024,
55
  do_sample=True,
56
- top_p=0.95,
57
- top_k=1000,
58
- temperature=1.0,
59
  num_beams=1,
60
  stopping_criteria=StoppingCriteriaList([stop])
61
  )
62
  t = Thread(target=m.generate, kwargs=generate_kwargs)
63
  t.start()
64
 
65
- # print(history)
66
  # Initialize an empty string to store the generated text
67
  partial_text = ""
68
  for new_text in streamer:
69
- print(new_text)
70
  if new_text != '<':
71
  partial_text += new_text
72
  history[-1][1] = partial_text.split('<bot>:')[-1]
73
- # Yield an empty string to cleanup the message textbox and the updated conversation history
74
  yield history
75
  return partial_text
76
 
77
 
78
  title = """<h1 align="center">🔥RedPajama-INCITE-Chat-3B-v1</h1><br><h2 align="center">🏃‍♂️💨Streaming with Transformers & Gradio💪</h2>"""
79
- description = """<h3 align="center">This is a RedPajama Chat model fine-tuned using data from Dolly 2.0 and Open Assistant over the RedPajama-INCITE-Base-3B-v1 base model.</h3>"""
 
80
 
81
- with gr.Blocks() as demo:
82
  gr.HTML(title)
83
  gr.HTML('''<center><a href="https://huggingface.co/spaces/ysharma/RedPajama-Chat-3B?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
84
  chatbot = gr.Chatbot().style(height=500)
@@ -91,17 +84,21 @@ with gr.Blocks() as demo:
91
  submit = gr.Button("Submit")
92
  stop = gr.Button("Stop")
93
  clear = gr.Button("Clear")
94
- system_msg = gr.Textbox(
95
- start_message, label="System Message", interactive=False, visible=False)
 
 
 
 
96
 
97
  submit_event = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
98
- fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True)
99
  submit_click_event = submit.click(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
100
- fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True)
101
  stop.click(fn=None, inputs=None, outputs=None, cancels=[
102
  submit_event, submit_click_event], queue=False)
103
  clear.click(lambda: None, None, [chatbot], queue=False)
104
  gr.HTML(description)
105
 
106
  demo.queue(max_size=32, concurrency_count=2)
107
- demo.launch(debug=True)
 
12
  m = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
13
  m = m.to('cuda:0')
14
 
 
 
 
 
 
 
 
15
  class StopOnTokens(StoppingCriteria):
16
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
17
  #stop_ids = [[29, 13961, 31], [29, 12042, 31], 1, 0]
 
29
 
30
 
31
 
32
+ def chat(history, top_p, top_k, temperature):
33
  # Initialize a StopOnTokens object
34
  stop = StopOnTokens()
35
 
 
46
  streamer=streamer,
47
  max_new_tokens=1024,
48
  do_sample=True,
49
+ top_p=top_p, #0.95,
50
+ top_k=top_k, #1000,
51
+ temperature=temperature, #1.0,
52
  num_beams=1,
53
  stopping_criteria=StoppingCriteriaList([stop])
54
  )
55
  t = Thread(target=m.generate, kwargs=generate_kwargs)
56
  t.start()
57
 
 
58
  # Initialize an empty string to store the generated text
59
  partial_text = ""
60
  for new_text in streamer:
61
+ #print(new_text)
62
  if new_text != '<':
63
  partial_text += new_text
64
  history[-1][1] = partial_text.split('<bot>:')[-1]
65
+ # Yield an empty string to clean up the message textbox and the updated conversation history
66
  yield history
67
  return partial_text
68
 
69
 
70
  title = """<h1 align="center">🔥RedPajama-INCITE-Chat-3B-v1</h1><br><h2 align="center">🏃‍♂️💨Streaming with Transformers & Gradio💪</h2>"""
71
+ description = """<br><br><h3 align="center">This is a RedPajama Chat model fine-tuned using data from Dolly 2.0 and Open Assistant over the RedPajama-INCITE-Base-3B-v1 base model.</h3>"""
72
+ theme = gr.themes.Soft(primary_hue="red", secondary_hue= "red", neutral_hue="red",)
73
 
74
+ with gr.Blocks(theme=theme) as demo:
75
  gr.HTML(title)
76
  gr.HTML('''<center><a href="https://huggingface.co/spaces/ysharma/RedPajama-Chat-3B?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
77
  chatbot = gr.Chatbot().style(height=500)
 
84
  submit = gr.Button("Submit")
85
  stop = gr.Button("Stop")
86
  clear = gr.Button("Clear")
87
+
88
+ #Advanced options - top_p, temperature, top_k
89
+ with gr.Accordion("Advanced Options:", open=False):
90
+ top_p = gr.Slider( minimum=-0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p",)
91
+ top_k = gr.Slider(minimum=0.0, maximum=1000, value=1000, step=1, interactive=True, label="Top-k", )
92
+ temperature = gr.Slider( minimum=-0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
93
 
94
  submit_event = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
95
+ fn=chat, inputs=[chatbot, top_p, top_k, temperature], outputs=[chatbot], queue=True)
96
  submit_click_event = submit.click(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
97
+ fn=chat, inputs=[chatbot, top_p, top_k, temperature], outputs=[chatbot], queue=True)
98
  stop.click(fn=None, inputs=None, outputs=None, cancels=[
99
  submit_event, submit_click_event], queue=False)
100
  clear.click(lambda: None, None, [chatbot], queue=False)
101
  gr.HTML(description)
102
 
103
  demo.queue(max_size=32, concurrency_count=2)
104
+ demo.launch(debug=True)