Spaces:
Runtime error
Runtime error
Commit
·
04e9e1a
1
Parent(s):
39f0451
Update text_gen.py
Browse files- text_gen.py +3 -2
text_gen.py
CHANGED
@@ -15,11 +15,12 @@ model = [transformers.AutoModelForSeq2SeqLM.from_pretrained(model_names[0]),
|
|
15 |
|
16 |
|
17 |
def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
|
|
|
|
|
18 |
if 'GODEL' in model_name:
|
19 |
-
text = f'Instruction: you need to response discreetly. [CONTEXT] {
|
20 |
text.replace('\t', ' EOS ')
|
21 |
else:
|
22 |
-
text = f'{context} {text}'
|
23 |
text = text.replace('\t', '\n')
|
24 |
input_ids = tokenizer(text, return_tensors="pt").input_ids
|
25 |
outputs = model.generate(input_ids, max_new_tokens=maximum, min_new_tokens=minimum, top_p=0.9, do_sample=True)
|
|
|
15 |
|
16 |
|
17 |
def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
|
18 |
+
if context:
|
19 |
+
text = f'{context} {text}'
|
20 |
if 'GODEL' in model_name:
|
21 |
+
text = f'Instruction: you need to response discreetly. [CONTEXT] {text}'
|
22 |
text.replace('\t', ' EOS ')
|
23 |
else:
|
|
|
24 |
text = text.replace('\t', '\n')
|
25 |
input_ids = tokenizer(text, return_tensors="pt").input_ids
|
26 |
outputs = model.generate(input_ids, max_new_tokens=maximum, min_new_tokens=minimum, top_p=0.9, do_sample=True)
|