Omnibus committed (verified)
Commit 62a62ca · 1 Parent(s): 72b121b

Update app.py

Files changed (1)
  app.py +5 -5
app.py CHANGED
@@ -1,4 +1,4 @@
-from huggingface_hub import InferenceClient
+from huggingface_hub import InferenceClient, AsyncInferenceClient
 import gradio as gr
 import random
 
@@ -44,9 +44,9 @@ MAX_HISTORY=100
 opts=[]
 def generate(prompt, history,max_new_tokens,health,seed,temperature=temperature,top_p=top_p,repetition_penalty=repetition_penalty):
     opts.clear()
-    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+    #client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
     #client = InferenceClient("abacusai/Slerp-CM-mist-dpo")
-
+    client = AsyncInferenceClient()
     temperature = float(temperature)
     if temperature < 1e-2:
         temperature = 1e-2
@@ -79,10 +79,10 @@ def generate(prompt, history,max_new_tokens,health,seed,temperature=temperature,
     if cnt > MAX_HISTORY:
         history1 = compress_history(str(history), temperature, top_p, repetition_penalty)
     formatted_prompt = format_prompt(f"{GAME_MASTER.format(history=history1,stats=stats,dice=random.randint(1,10))}, {prompt}", history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    stream = await client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
 
-    for response in stream:
+    async for response in await stream:
         output += response.token.text
         if history:
             yield [(prompt,output)],stats,None,None
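
For context, a minimal sketch of how a streamed completion from huggingface_hub's AsyncInferenceClient is typically consumed. The model ID is borrowed from the commented-out line above; the prompt, generation parameters, and the stream_reply function name are illustrative, not the app's actual code:

from huggingface_hub import AsyncInferenceClient

# Model ID taken from the commented-out line in the diff; any text-generation
# model served by the Inference API could be substituted here.
client = AsyncInferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

async def stream_reply(prompt: str):
    # On the async client, text_generation() must be awaited; with stream=True the
    # awaited result is already an async iterator of token events, so it is
    # consumed with a plain `async for` (no second await).
    stream = await client.text_generation(
        prompt,
        max_new_tokens=256,   # illustrative value
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = ""
    async for response in stream:
        output += response.token.text
        yield output          # Gradio accepts async generators for streaming output

# Example driver for local testing:
# import asyncio
# async def main():
#     async for partial in stream_reply("You open the dungeon door..."):
#         print(partial)
# asyncio.run(main())

Because the client is now awaited, the enclosing Gradio event handler has to be declared async def so the event loop can drive the streamed tokens to the UI.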