nroggendorff commited on
Commit
ee70b76
·
verified ·
1 Parent(s): 7725a81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -13,8 +13,19 @@ model.config.pad_token_id = model.config.eos_token_id
13
  @spaces.GPU
14
  def chat(prompt):
15
  input_ids = tokenizer.encode(prompt, return_tensors="pt")
16
- output = model.generate(input_ids, max_length=1024, num_return_sequences=1, top_p=0.9, top_k=1)
17
- response = tokenizer.decode(output[0])
 
 
 
 
 
 
 
 
 
 
 
18
  return response
19
 
20
  demo = gr.Interface(
 
13
  @spaces.GPU
14
  def chat(prompt):
15
  input_ids = tokenizer.encode(prompt, return_tensors="pt")
16
+ attention_mask = torch.ones_like(input_ids)
17
+
18
+ output = model.generate(
19
+ input_ids,
20
+ attention_mask=attention_mask,
21
+ max_length=1024,
22
+ num_return_sequences=1,
23
+ top_p=0.9,
24
+ top_k=1,
25
+ pad_token_id=model.config.eos_token_id
26
+ )
27
+
28
+ response = tokenizer.decode(output[0], skip_special_tokens=True)
29
  return response
30
 
31
  demo = gr.Interface(