stibiumghost commited on
Commit
7c0199b
·
1 Parent(s): 1f50941

Upload text_gen.py

Browse files
Files changed (1) hide show
  1. text_gen.py +35 -0
text_gen.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import string
3
+
4
+ model_names = ['microsoft/GODEL-v1_1-base-seq2seq',
5
+ 'facebook/blenderbot-400M-distill',
6
+ 'facebook/blenderbot_small-90M']
7
+
8
+ tokenizers = [transformers.AutoTokenizer.from_pretrained(model_names[0]),
9
+ transformers.BlenderbotTokenizer.from_pretrained(model_names[1]),
10
+ transformers.BlenderbotSmallTokenizer.from_pretrained(model_names[2])]
11
+
12
+ model = [transformers.AutoModelForSeq2SeqLM.from_pretrained(model_names[0]),
13
+ transformers.BlenderbotForConditionalGeneration.from_pretrained(model_names[1]),
14
+ transformers.BlenderbotSmallForConditionalGeneration.from_pretrained(model_names[2])]
15
+
16
+
17
+ def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
18
+ text = f'{context} {text}'
19
+ if 'GODEL' in model_name:
20
+ text = 'Instruction: you need to response discreetly. [CONTEXT]' + text
21
+ else:
22
+ text = text.replace(' EOS ', '\n')
23
+ input_ids = tokenizer(text, return_tensors="pt").input_ids
24
+ outputs = model.generate(input_ids, max_new_tokens=maximum, min_new_tokens=minimum, top_p=0.9, do_sample=True)
25
+ output = tokenizer.decode(outputs[0], skip_special_tokens=True)
26
+ return capitalization(output)
27
+
28
+
29
+ def capitalization(line):
30
+ line, end = line[:-1], line[-1]
31
+ for mark in '.?!':
32
+ line = f'{mark} '.join([part.strip()[0].upper() + part.strip()[1:] for part in line.split(mark) if len(part) > 1])
33
+ line = ' '.join([word.capitalize() if word.translate(str.maketrans('', '', string.punctuation)) == 'i'
34
+ else word for word in line.split()])
35
+ return line + end