stibiumghost commited on
Commit
d3dc6c7
·
1 Parent(s): 5a50f61

Update text_gen.py

Browse files
Files changed (1) hide show
  1. text_gen.py +3 -6
text_gen.py CHANGED
@@ -3,24 +3,21 @@ import string
3
 
4
  model_names = ['microsoft/GODEL-v1_1-base-seq2seq',
5
  'facebook/blenderbot-1B-distill',
6
- 'PygmalionAI/pygmalion-1.3b']
7
 
8
  tokenizers = [transformers.AutoTokenizer.from_pretrained(model_names[0]),
9
  transformers.BlenderbotTokenizer.from_pretrained(model_names[1]),
10
- transformers.GPTNeoXTokenizerFast.from_pretrained(model_names[2])]
11
 
12
  model = [transformers.AutoModelForSeq2SeqLM.from_pretrained(model_names[0]),
13
  transformers.BlenderbotForConditionalGeneration.from_pretrained(model_names[1]),
14
- transformers.GPTNeoXForCausalLM.from_pretrained(model_names[2])]
15
 
16
 
17
  def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
18
  text = f'{context} {text}'
19
  if 'GODEL' in model_name:
20
  text = 'Instruction: you need to response discreetly. [CONTEXT] ' + text
21
- elif 'pygmalion' in model_name:
22
- text = 'ROBOTS\'s Persona: Discreet and polite humanoid robot. ' + text
23
- text = text.replace(' EOS ', '\n')
24
  else:
25
  text = text.replace(' EOS ', '\n')
26
  input_ids = tokenizer(text, return_tensors="pt").input_ids
 
3
 
4
  model_names = ['microsoft/GODEL-v1_1-base-seq2seq',
5
  'facebook/blenderbot-1B-distill',
6
+ 'microsoft/DialoGPT-medium']
7
 
8
  tokenizers = [transformers.AutoTokenizer.from_pretrained(model_names[0]),
9
  transformers.BlenderbotTokenizer.from_pretrained(model_names[1]),
10
+ transformers.GPT2Tokenizer.from_pretrained(model_names[2])]
11
 
12
  model = [transformers.AutoModelForSeq2SeqLM.from_pretrained(model_names[0]),
13
  transformers.BlenderbotForConditionalGeneration.from_pretrained(model_names[1]),
14
+ transformers.GPT2Model.from_pretrained(model_names[2])]
15
 
16
 
17
  def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
18
  text = f'{context} {text}'
19
  if 'GODEL' in model_name:
20
  text = 'Instruction: you need to response discreetly. [CONTEXT] ' + text
 
 
 
21
  else:
22
  text = text.replace(' EOS ', '\n')
23
  input_ids = tokenizer(text, return_tensors="pt").input_ids