basit123796 committed
Commit 208be49 · verified · 1 Parent(s): 47511ce

Update app.py

Files changed (1)
  1. app.py +54 -113
app.py CHANGED
@@ -4,25 +4,27 @@ from huggingface_hub import InferenceClient
 
 ss_client = Client("https://omnibus-html-image-current-tab.hf.space/")
 
-models = [
+models=[
     "google/gemma-7b",
     "google/gemma-7b-it",
     "google/gemma-2b",
     "google/gemma-2b-it"
 ]
-clients = [
-    InferenceClient(models[0]),
-    InferenceClient(models[1]),
-    InferenceClient(models[2]),
-    InferenceClient(models[3]),
+clients=[
+    InferenceClient(models[0]),
+    InferenceClient(models[1]),
+    InferenceClient(models[2]),
+    InferenceClient(models[3]),
 ]
 
-VERBOSE = False
-
-
-def load_models():
-    return gr.update(label=models[0])
+VERBOSE=False
 
+def load_models(inp):
+    if VERBOSE==True:
+        print(type(inp))
+        print(inp)
+        print(models[inp])
+    return gr.update(label=models[inp])
 
 def format_prompt(message, history):
     prompt = ""
@@ -30,34 +32,28 @@ def format_prompt(message, history):
     for user_prompt, bot_response in history:
         prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
         prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
-    if VERBOSE:
+    if VERBOSE==True:
         print(prompt)
     prompt += message
    return prompt
 
-
-def chat_inf(prompt, history, memory, temp, tokens, top_p, rep_p, chat_mem):
-    hist_len = 0
-    client = clients[0]
+def chat_inf(prompt,history,memory,client_choice,temp,tokens,top_p,rep_p,chat_mem):
+    hist_len=0
+    client=clients[int(client_choice)-1]
     if not history:
         history = []
-        hist_len = 0
+        hist_len=0
     if not memory:
         memory = []
-        mem_len = 0
+        mem_len=0
     if memory:
-        for ea in memory[0 - chat_mem :]:
-            hist_len += len(str(ea))
-    in_len = len(prompt) + hist_len
+        for ea in memory[0-chat_mem:]:
+            hist_len+=len(str(ea))
+    in_len=len(prompt)+hist_len
 
-    if (in_len + tokens) > 8000:
-        history.append(
-            (
-                prompt,
-                "Wait, that's too many tokens, please reduce the 'Chat Memory' value, or reduce the 'Max new tokens' value",
-            )
-        )
-        yield history, memory
+    if (in_len+tokens) > 8000:
+        history.append((prompt,"Wait, that's too many tokens, please reduce the 'Chat Memory' value, or reduce the 'Max new tokens' value"))
+        yield history,memory
     else:
         generate_kwargs = dict(
             temperature=temp,
@@ -66,59 +62,33 @@ def chat_inf(prompt, history, memory, temp, tokens, top_p, rep_p, chat_mem):
             repetition_penalty=rep_p,
             do_sample=True,
         )
-        formatted_prompt = format_prompt(prompt, memory[0 - chat_mem :])
-        stream = client.text_generation(
-            formatted_prompt,
-            **generate_kwargs,
-            stream=True,
-            details=True,
-            return_full_text=True,
-        )
+        formatted_prompt = format_prompt(prompt, memory[0-chat_mem:])
+        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
         output = ""
         for response in stream:
            output += response.token.text
-            yield [(prompt, output)], memory
-        history.append((prompt, output))
-        memory.append((prompt, output))
-        yield history, memory
-
-    if VERBOSE:
-        print("\n######### HIST " + str(in_len))
-        print("\n######### TOKENS " + str(tokens))
-
-
-def get_screenshot(
-    chat: list,
-    height=5000,
-    width=600,
-    chatblock=[],
-    theme="light",
-    wait=3000,
-    header=True,
-):
+            yield [(prompt,output)],memory
+        history.append((prompt,output))
+        memory.append((prompt,output))
+        yield history,memory
+
+    if VERBOSE==True:
+        print("\n######### HIST "+str(in_len))
+        print("\n######### TOKENS "+str(tokens))
+
+def get_screenshot(chat: list,height=5000,width=600,chatblock=[],theme="light",wait=3000,header=True):
     tog = 0
     if chatblock:
         tog = 3
-    result = ss_client.predict(
-        str(chat),
-        height,
-        width,
-        chatblock,
-        header,
-        theme,
-        wait,
-        api_name="/run_script",
-    )
+    result = ss_client.predict(str(chat),height,width,chatblock,header,theme,wait,api_name="/run_script")
     out = f'https://omnibus-html-image-current-tab.hf.space/file={result[tog]}'
     return out
 
-
 def clear_fn():
-    return None, None, None, None
-
+    return None,None,None,None
 
 with gr.Blocks() as app:
-    memory = gr.State()
+    memory=gr.State()
     chat_b = gr.Chatbot(height=500)
     with gr.Group():
         with gr.Row():
@@ -127,46 +97,17 @@ with gr.Blocks() as app:
                 btn = gr.Button("Chat")
             with gr.Column(scale=1):
                 with gr.Group():
-                    temp = gr.Slider(
-                        label="Temperature",
-                        step=0.01,
-                        minimum=0.01,
-                        maximum=1.0,
-                        value=0.49,
-                    )
-                    tokens = gr.Slider(
-                        label="Max new tokens",
-                        value=1600,
-                        minimum=0,
-                        maximum=8000,
-                        step=64,
-                        interactive=True,
-                        visible=True,
-                        info="The maximum number of tokens",
-                    )
-                    top_p = gr.Slider(
-                        label="Top-P",
-                        step=0.01,
-                        minimum=0.01,
-                        maximum=1.0,
-                        value=0.49,
-                    )
-                    rep_p = gr.Slider(
-                        label="Repetition Penalty",
-                        step=0.01,
-                        minimum=0.1,
-                        maximum=2.0,
-                        value=0.99,
-                    )
-                    chat_mem = gr.Number(
-                        label="Chat Memory",
-                        info="Number of previous chats to retain",
-                        value=4,
-                    )
-
-    app.load(load_models)
-    chat_sub = inp.submit().then(
-        chat_inf, [inp, chat_b, memory, temp, tokens, top_p, rep_p, chat_mem], [chat_b, memory]
-    )
-    go = btn.click().then(
-        chat_inf,
+                    temp=gr.Slider(label="Temperature",step=0.01, minimum=0.01, maximum=1.0, value=0.49)
+                    tokens = gr.Slider(label="Max new tokens",value=1600,minimum=0,maximum=8000,step=64,interactive=True, visible=True,info="The maximum number of tokens")
+                    top_p=gr.Slider(label="Top-P",step=0.01, minimum=0.01, maximum=1.0, value=0.49)
+                    rep_p=gr.Slider(label="Repetition Penalty",step=0.01, minimum=0.1, maximum=2.0, value=0.99)
+                    chat_mem=gr.Number(label="Chat Memory", info="Number of previous chats to retain",value=4)
+
+    client_choice=gr.Dropdown(label="Models",type='index',choices=[c for c in models],value=models[0],interactive=True)
+    client_choice.change(load_models,client_choice,[chat_b])
+    app.load(load_models,client_choice,[chat_b])
+
+    chat_sub=inp.submit().run(chat_inf,[inp,chat_b,memory,client_choice,temp,tokens,top_p,rep_p,chat_mem],[chat_b,memory])
+    go=btn.click().run(chat_inf,[inp,chat_b,memory,client_choice,temp,tokens,top_p,rep_p,chat_mem],[chat_b,memory])
+
+app.queue(default_concurrency_limit=10).launch()
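The "-it" entries in models are instruction-tuned Gemma checkpoints. Their documented chat template puts each role tag on its own line and ends the prompt with an opening <start_of_turn>model tag, whereas format_prompt above concatenates the tags without newlines and appends the raw message. A minimal sketch of a formatter that follows the documented template (a hypothetical helper, not part of this commit):

def format_prompt_gemma(message: str, history: list[tuple[str, str]]) -> str:
    # Build a Gemma-style chat prompt: one <start_of_turn>...<end_of_turn> block per turn.
    prompt = ""
    for user_prompt, bot_response in history:
        prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
        prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    # Open the final user turn and cue the model to answer.
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt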
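The streaming loop in chat_inf reads response.token.text, which matches InferenceClient.text_generation when both stream=True and details=True are set: the client then yields per-token stream events rather than plain strings. A standalone sketch of that pattern (the model name and prompt are placeholders, and the gated Gemma weights assume an HF token is available in the environment):

from huggingface_hub import InferenceClient

client = InferenceClient("google/gemma-2b-it")
output = ""
# stream=True + details=True yields per-token events; each carries the new token text.
for event in client.text_generation(
    "Write one sentence about the sea.",
    max_new_tokens=64,
    stream=True,
    details=True,
):
    output += event.token.text
print(output)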
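Two wiring details in the last hunk are worth flagging. client_choice is a gr.Dropdown with type='index', which passes a zero-based position in models to chat_inf, so clients[int(client_choice)-1] would pick the previous entry (and wrap to the last client when the first model is selected). The chained .run(...) calls also replace the .then(...) chaining of the previous revision; if .run is not available in the installed Gradio version, .then is the documented way to chain a handler onto submit/click. A stripped-down sketch assuming zero-based indexing and .then(), with a stub handler standing in for the real inference call:

import gradio as gr
from huggingface_hub import InferenceClient

models = ["google/gemma-7b", "google/gemma-7b-it", "google/gemma-2b", "google/gemma-2b-it"]
clients = [InferenceClient(m) for m in models]

def chat_stub(prompt, history, client_choice):
    # type='index' already delivers a zero-based index, so no -1 adjustment is needed.
    client = clients[int(client_choice)]  # selected client (the stub does not call it)
    history = (history or []) + [(prompt, f"(would stream from {models[int(client_choice)]})")]
    return history

with gr.Blocks() as demo:
    chat_b = gr.Chatbot(height=500)
    inp = gr.Textbox(label="Prompt")
    btn = gr.Button("Chat")
    client_choice = gr.Dropdown(label="Models", type="index", choices=models, value=models[0])
    # Chain the handler with .then(), as the previous revision did.
    inp.submit().then(chat_stub, [inp, chat_b, client_choice], [chat_b])
    btn.click().then(chat_stub, [inp, chat_b, client_choice], [chat_b])

demo.launch()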