stibiumghost committed on
Commit
b24eb89
·
1 Parent(s): e5cb5e2

Update text_gen.py

Browse files
Files changed (1) hide show
  1. text_gen.py +7 -6
text_gen.py CHANGED
@@ -1,9 +1,9 @@
1
  import transformers
2
  import string
3
 
4
# Hub identifiers of the checkpoints this module loads; list order is
# significant — tokenizers/models below are built by index.
model_names = [
    'microsoft/GODEL-v1_1-base-seq2seq',
    'facebook/blenderbot-1B-distill',
    'microsoft/DialoGPT-medium',
]
7
 
8
  tokenizers = [transformers.AutoTokenizer.from_pretrained(model_names[0]),
9
  transformers.BlenderbotTokenizer.from_pretrained(model_names[1]),
@@ -15,15 +15,16 @@ model = [transformers.AutoModelForSeq2SeqLM.from_pretrained(model_names[0]),
15
 
16
 
17
def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
    """Generate a reply to *text* given *context* with the supplied model/tokenizer pair.

    GODEL checkpoints get an instruction prefix; other models get ' EOS '
    markers swapped for the tokenizer's EOS token plus a trailing EOS.
    Returns the decoded generation passed through capitalization().
    """
    prompt = f'{context} {text}'
    if 'GODEL' in model_name:
        # GODEL expects an instruction + [CONTEXT] framing around the dialogue.
        prompt = 'Instruction: you need to response discreetly. [CONTEXT] ' + prompt
    else:
        # Other models separate turns with their own EOS token.
        prompt = prompt.replace(' EOS ', tokenizer.eos_token) + tokenizer.eos_token
    encoded = tokenizer(prompt, return_tensors="pt").input_ids
    generated = model.generate(encoded, max_new_tokens=maximum, min_new_tokens=minimum, top_p=0.9, do_sample=True)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return capitalization(decoded)
27
 
28
 
29
  def capitalization(line):
 
1
  import transformers
2
  import string
3
 
4
# Hub identifiers of the checkpoints this module loads; list order is
# significant — tokenizers/models below are built by index.
# FIX: the first entry was missing its closing quote and comma
# ('microsoft/GODEL-v1_1-large-seq2seq), which is a SyntaxError as written.
model_names = ['microsoft/GODEL-v1_1-large-seq2seq',
               'facebook/blenderbot-1B-distill',
               'satvikag/chatbot']
7
 
8
  tokenizers = [transformers.AutoTokenizer.from_pretrained(model_names[0]),
9
  transformers.BlenderbotTokenizer.from_pretrained(model_names[1]),
 
15
 
16
 
17
def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
    """Generate a reply to *text* given *context* with the supplied model/tokenizer pair.

    GODEL checkpoints get an instruction + [CONTEXT] framing and have tabs
    rewritten to ' EOS ' turn separators; other models get tabs rewritten to
    newlines. Returns the model name concatenated with the decoded,
    capitalization()-normalized generation.
    """
    if 'GODEL' in model_name:
        text = f'Instruction: you need to response discreetly. [CONTEXT] {context} {text}'
        # BUG FIX: str.replace returns a new string — the original call
        # discarded its result, so the tab -> ' EOS ' substitution never
        # took effect. Assign it back.
        text = text.replace('\t', ' EOS ')
    else:
        text = f'{context} {text}'
        text = text.replace('\t', '\n')
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    outputs = model.generate(input_ids, max_new_tokens=maximum, min_new_tokens=minimum, top_p=0.9, do_sample=True)
    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return model_name + capitalization(output)
28
 
29
 
30
  def capitalization(line):