stibiumghost commited on
Commit
04e9e1a
·
1 Parent(s): 39f0451

Update text_gen.py

Browse files
Files changed (1) hide show
  1. text_gen.py +3 -2
text_gen.py CHANGED
@@ -15,11 +15,12 @@ model = [transformers.AutoModelForSeq2SeqLM.from_pretrained(model_names[0]),
15
 
16
 
17
  def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
 
 
18
  if 'GODEL' in model_name:
19
- text = f'Instruction: you need to response discreetly. [CONTEXT] {context} {text}'
20
  text.replace('\t', ' EOS ')
21
  else:
22
- text = f'{context} {text}'
23
  text = text.replace('\t', '\n')
24
  input_ids = tokenizer(text, return_tensors="pt").input_ids
25
  outputs = model.generate(input_ids, max_new_tokens=maximum, min_new_tokens=minimum, top_p=0.9, do_sample=True)
 
15
 
16
 
17
  def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
18
+ if context:
19
+ text = f'{context} {text}'
20
  if 'GODEL' in model_name:
21
+ text = f'Instruction: you need to response discreetly. [CONTEXT] {text}'
22
  text.replace('\t', ' EOS ')
23
  else:
 
24
  text = text.replace('\t', '\n')
25
  input_ids = tokenizer(text, return_tensors="pt").input_ids
26
  outputs = model.generate(input_ids, max_new_tokens=maximum, min_new_tokens=minimum, top_p=0.9, do_sample=True)