Update app.py
Browse files
app.py
CHANGED
@@ -5,10 +5,10 @@ import gradio as gr
|
|
5 |
tokenizer = AutoTokenizer.from_pretrained('TrLOX/gpt2-tdk')
|
6 |
model = AutoModelForCausalLM.from_pretrained('TrLOX/gpt2-tdk')
|
7 |
|
8 |
-
def text_generation(
|
9 |
-
input_ids = tokenizer(
|
10 |
torch.manual_seed(seed) # Max value: 18446744073709551615
|
11 |
-
outputs = model.generate(input_ids, do_sample=True, min_length=50, max_length=
|
12 |
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
13 |
return generated_text
|
14 |
|
@@ -17,7 +17,7 @@ description = "Title and description generation by keywords"
|
|
17 |
|
18 |
gr.Interface(
|
19 |
text_generation,
|
20 |
-
[gr.inputs.Textbox(lines=2, label="Enter
|
21 |
[gr.outputs.Textbox(type="auto", label="Text Generated")],
|
22 |
title=title,
|
23 |
description=description,
|
|
|
5 |
tokenizer = AutoTokenizer.from_pretrained('TrLOX/gpt2-tdk')
|
6 |
model = AutoModelForCausalLM.from_pretrained('TrLOX/gpt2-tdk')
|
7 |
|
8 |
+
def text_generation(keywords, domain, seed):
|
9 |
+
input_ids = tokenizer('keyword ' + keywords + ' domain ' + domain + ' title', return_tensors="pt").input_ids
|
10 |
torch.manual_seed(seed) # Max value: 18446744073709551615
|
11 |
+
outputs = model.generate(input_ids, do_sample=True, min_length=50, max_length=250)
|
12 |
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
13 |
return generated_text
|
14 |
|
|
|
17 |
|
18 |
gr.Interface(
|
19 |
text_generation,
|
20 |
+
[gr.inputs.Textbox(default='test 1,test 2',lines=2, label="Enter keywords"), gr.inputs.Textbox(lines=2, default='test.com',label="Enter domain"), gr.inputs.Number(default=10, label="Enter seed number")],
|
21 |
[gr.outputs.Textbox(type="auto", label="Text Generated")],
|
22 |
title=title,
|
23 |
description=description,
|