project-baize committed
Commit 06a42ee
1 Parent(s): 46db63f

Update app.py

Files changed (1): app.py (+4 -3)
app.py CHANGED
@@ -17,7 +17,7 @@ base_model = "decapoda-research/llama-7b-hf"
 adapter_model = "project-baize/baize-lora-7B"
 tokenizer,model,device = load_tokenizer_and_model(base_model,adapter_model)
 
-
+total_cont = 0
 def predict(text,
             chatbot,
             history,
@@ -43,7 +43,8 @@ def predict(text,
     begin_length = len(prompt)
     torch.cuda.empty_cache()
     input_ids = inputs["input_ids"].to(device)
-
+    total_cont += 1
+    print(total_cont)
     with torch.no_grad():
         for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
             if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
@@ -221,4 +222,4 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     #)
     demo.title = "Baize"
 
-demo.queue(concurrency_count=1,).launch()
+demo.queue(concurrency_count=2).launch()
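
The four added lines introduce a module-level request counter, total_cont, that is incremented and printed at the start of each generation, and raise the Gradio queue's concurrency_count from 1 to 2. Because total_cont is a module-level name that is reassigned inside predict(), the increment needs a global declaration to avoid an UnboundLocalError, and with two queue workers the updates are no longer serialized. A minimal sketch of that counter pattern, where the global statement and the lock are illustrative additions not present in the commit:

import threading

total_cont = 0                    # running count of predict() calls, as in the commit
_count_lock = threading.Lock()    # illustrative: guards the counter once the queue runs more than one worker

def predict(text, chatbot, history, top_p, temperature,
            max_length_tokens, max_context_length_tokens):
    global total_cont             # needed so `+=` rebinds the module-level counter
    with _count_lock:
        total_cont += 1
        print(total_cont)
    # ... generation code from app.py continues here ...

With demo.queue(concurrency_count=2), Gradio can process two queued requests at the same time, which is why the sketch guards the shared counter with a lock.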