sablab committed on
Commit
da72fb0
·
verified ·
1 Parent(s): fe14c93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -12
app.py CHANGED
@@ -17,15 +17,15 @@ def format_prompt(message, history):
17
  def generate(
18
  prompt, history, temperature=0.9, max_new_tokens=16000, top_p=0.95, repetition_penalty=1.0,
19
  ):
20
- generate_kwargs = dict(
21
- temperature=0.9,
22
- max_new_tokens=16000,
23
- top_p=0.9,
24
- repetition_penalty=1.0,
 
25
  do_sample=True,
26
  seed=42,
27
  )
28
-
29
  formatted_prompt = format_prompt(prompt, history)
30
 
31
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
@@ -44,10 +44,59 @@ css = """
44
  }
45
  """
46
 
47
- with gr.Blocks(css=css) as demo:
48
- gr.HTML("<h1><center>Mistral 7B Instruct<h1><center>")
49
- gr.ChatInterface(
50
- generate
51
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- demo.queue().launch(debug=True)
 
 
 
 
 
 
 
17
  def generate(
18
  prompt, history, temperature=0.9, max_new_tokens=16000, top_p=0.95, repetition_penalty=1.0,
19
  ):
20
+
21
+ generate_kwargs = dict(
22
+ temperature=temperature,
23
+ max_new_tokens=max_new_tokens,
24
+ top_p=top_p,
25
+ repetition_penalty=repetition_penalty,
26
  do_sample=True,
27
  seed=42,
28
  )
 
29
  formatted_prompt = format_prompt(prompt, history)
30
 
31
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
 
44
  }
45
  """
46
 
47
# Extra UI controls shown under the chat box. gr.ChatInterface passes their
# values positionally after (message, history) into the fn callback, so the
# order here must match generate()'s extra parameters.
# NOTE(review): generate(prompt, history, temperature, max_new_tokens, top_p,
# repetition_penalty) has no system-prompt parameter, yet the System Prompt
# textbox is first in this list — its value would be passed as `temperature`.
# Confirm against generate()'s full signature and either add a system_prompt
# parameter there or drop the textbox.
additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=4192,
        minimum=4192,
        maximum=33536,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# Fix: the committed diff appended five stray extra `)` lines (diff lines
# 89-93) after the last Slider before the closing `]`, leaving the
# parentheses unbalanced — a SyntaxError on import. Each component is closed
# exactly once above and the list is closed with a single `]`.
95
 
96
# Chat display widget: panel layout, copy button and like/dislike enabled,
# no label and no share button.
chat_panel = gr.Chatbot(
    show_label=False,
    show_share_button=False,
    show_copy_button=True,
    likeable=True,
    layout="panel",
)

# Wire the streaming generate() callback into a ready-made chat UI, expose
# the extra sampling controls, and start the web server with the API docs
# page enabled. concurrency_limit caps simultaneous generate() calls.
gr.ChatInterface(
    fn=generate,
    chatbot=chat_panel,
    additional_inputs=additional_inputs,
    title="Mistral 7B Instruct",
    concurrency_limit=20,
).launch(show_api=True)