ramalMr committed on
Commit
beb08e3
·
verified ·
1 Parent(s): a712433

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -17
app.py CHANGED
@@ -7,16 +7,14 @@ client = InferenceClient(
7
  )
8
 
9
def format_prompt(message, history):
    """Serialize prior (user, bot) turns plus the new message into Mistral [INST] format.

    Each past exchange becomes "[INST] user [/INST] bot</s> "; the new message
    is appended as a final open "[INST] ... [/INST]" for the model to answer.
    """
    pieces = ["<s>"]
    for past_user, past_bot in history:
        pieces.append(f"[INST] {past_user} [/INST]")
        pieces.append(f" {past_bot}</s> ")
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)
16
 
17
- def generate(
18
- prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
19
- ):
20
  temperature = float(temperature)
21
  if temperature < 1e-2:
22
  temperature = 1e-2
@@ -31,14 +29,13 @@ def generate(
31
  seed=42,
32
  )
33
 
34
- formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
35
- stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
36
- output = ""
37
 
38
- for response in stream:
39
- output += response.token.text
40
- yield output
41
- return output
42
 
43
  def process_file(file):
44
  text = file.decode("utf-8")
@@ -97,4 +94,4 @@ gr.ChatInterface(
97
  additional_inputs=additional_inputs,
98
  title="Synthetic-data-generation-aze",
99
  concurrency_limit=20,
100
- ).launch(show_api=False)
 
7
  )
8
 
9
def format_prompt(message, history):
    """Build a Mistral-instruct style prompt from the chat history and the new message.

    History is an iterable of (user, bot) pairs; each is rendered as a closed
    "[INST] user [/INST] bot</s> " turn, then the new message is added as the
    trailing open instruction.
    """
    turns = [f"[INST] {user_msg} [/INST] {bot_msg}</s> " for user_msg, bot_msg in history]
    return "<s>" + "".join(turns) + f"[INST] {message} [/INST]"
16
 
17
+ def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, file=None):
 
 
18
  temperature = float(temperature)
19
  if temperature < 1e-2:
20
  temperature = 1e-2
 
29
  seed=42,
30
  )
31
 
32
+ if file:
33
+ sentences = process_file(file)
34
+ prompt = "\n".join(sentences)
35
 
36
+ formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
37
+ response = client.text_generation(formatted_prompt, **generate_kwargs, details=True, return_full_text=True)
38
+ return response.text
 
39
 
40
  def process_file(file):
41
  text = file.decode("utf-8")
 
94
  additional_inputs=additional_inputs,
95
  title="Synthetic-data-generation-aze",
96
  concurrency_limit=20,
97
+ ).launch(show_api=False)