nawhgnuj committed on
Commit
e262200
·
verified ·
1 Parent(s): e3c8bb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -49,6 +49,11 @@ quantization_config = BitsAndBytesConfig(
49
  bnb_4bit_quant_type="nf4")
50
 
51
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
 
 
 
 
 
52
  model = AutoModelForCausalLM.from_pretrained(
53
  MODEL,
54
  torch_dtype=torch.bfloat16,
@@ -89,17 +94,20 @@ Crucially, always respond to and rebut the previous speaker's points in Harris's
89
  conversation.append({"role": "user", "content": message})
90
 
91
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
 
92
 
93
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
94
 
95
  generate_kwargs = dict(
96
  input_ids=input_ids,
 
97
  max_new_tokens=max_new_tokens,
98
  do_sample=True,
99
  top_p=top_p,
100
  top_k=top_k,
101
  temperature=temperature,
102
- eos_token_id=[128001,128008,128009],
 
103
  streamer=streamer,
104
  repetition_penalty=repetition_penalty,
105
  )
 
49
  bnb_4bit_quant_type="nf4")
50
 
51
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
52
+
53
+ if tokenizer.pad_token is None:
54
+ tokenizer.pad_token = tokenizer.eos_token
55
+ tokenizer.pad_token_id = tokenizer.eos_token_id
56
+
57
  model = AutoModelForCausalLM.from_pretrained(
58
  MODEL,
59
  torch_dtype=torch.bfloat16,
 
94
  conversation.append({"role": "user", "content": message})
95
 
96
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
97
+ attention_mask = torch.ones_like(input_ids)
98
 
99
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
100
 
101
  generate_kwargs = dict(
102
  input_ids=input_ids,
103
+ attention_mask=attention_mask,
104
  max_new_tokens=max_new_tokens,
105
  do_sample=True,
106
  top_p=top_p,
107
  top_k=top_k,
108
  temperature=temperature,
109
+ pad_token_id=tokenizer.pad_token_id,
110
+ eos_token_id=tokenizer.eos_token_id,
111
  streamer=streamer,
112
  repetition_penalty=repetition_penalty,
113
  )