stibiumghost commited on
Commit
42ecfe8
·
1 Parent(s): f67d229

Update text_gen.py

Browse files
Files changed (1) hide show
  1. text_gen.py +4 -1
text_gen.py CHANGED
@@ -17,7 +17,10 @@ model = [transformers.AutoModelForSeq2SeqLM.from_pretrained(model_names[0]),
17
  def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
18
  text = f'{context} {text}'
19
  if 'GODEL' in model_name:
20
- text = 'Instruction: you need to response discreetly. [CONTEXT]' + text
 
 
 
21
  else:
22
  text = text.replace(' EOS ', '\n')
23
  input_ids = tokenizer(text, return_tensors="pt").input_ids
 
17
  def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
18
  text = f'{context} {text}'
19
  if 'GODEL' in model_name:
20
+ text = 'Instruction: you need to response discreetly. [CONTEXT] ' + text
21
+ elif 'pygmalion' in model_name:
22
+ text = 'ROBOTS\'s Persona: Discreet and polite humanoid robot. ' + text
23
+ text = text.replace(' EOS ', '\n')
24
  else:
25
  text = text.replace(' EOS ', '\n')
26
  input_ids = tokenizer(text, return_tensors="pt").input_ids