wetey commited on
Commit
8983492
·
1 Parent(s): f045f5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -8,19 +8,21 @@ headline = AutoModelForSeq2SeqLM.from_pretrained("wetey/content-summarizer")
8
  generate_long = AutoModelForSeq2SeqLM.from_pretrained("wetey/content-generator")
9
 
10
  def generate_headline(text):
11
- inputs = tokenizer(text, return_tensors="pt").input_ids
12
 
 
 
 
13
  generation_config = GenerationConfig(temperature = 1.2,
14
- encoder_no_repeat_ngram_size = 4)
15
-
16
- outputs = headline.generate(inputs,
17
- do_sample = True,
18
- generation_config = generation_config)
19
 
20
  return tokenizer.decode(outputs[0], skip_special_tokens = True)
21
 
22
  def generate_content(text):
23
- inputs = tokenizer(text, return_tensors="pt").input_ids
 
 
 
24
  generation_config = GenerationConfig(temperature = 1.2,
25
  encoder_no_repeat_ngram_size = 2,
26
  min_length = 50,
@@ -29,9 +31,7 @@ def generate_content(text):
29
  num_beams = 4,
30
  repetition_penalty = 1.5,
31
  no_repeat_ngram_size = 3)
32
- outputs = generate_long.generate(inputs,
33
- do_sample = True,
34
- generation_config = generation_config)
35
 
36
  return tokenizer.decode(outputs[0], skip_special_tokens = True)
37
 
 
8
  generate_long = AutoModelForSeq2SeqLM.from_pretrained("wetey/content-generator")
9
 
10
  def generate_headline(text):
 
11
 
12
+ prefix = "summarize "
13
+ input = prefix + text
14
+ inputs = tokenizer(input, return_tensors = "pt", max_length = 128, truncation = True).input_ids
15
  generation_config = GenerationConfig(temperature = 1.2,
16
+ encoder_no_repeat_ngram_size = 7)
17
+ outputs = headline.generate(inputs, do_sample = True, generation_config = generation_config)
 
 
 
18
 
19
  return tokenizer.decode(outputs[0], skip_special_tokens = True)
20
 
21
  def generate_content(text):
22
+
23
+ prefix = "generate_longer_text_from_headline: "
24
+ input = prefix + text
25
+ inputs = tokenizer(input, return_tensors="pt", max_length = 128, truncation = True).input_ids
26
  generation_config = GenerationConfig(temperature = 1.2,
27
  encoder_no_repeat_ngram_size = 2,
28
  min_length = 50,
 
31
  num_beams = 4,
32
  repetition_penalty = 1.5,
33
  no_repeat_ngram_size = 3)
34
+ outputs = generate_long.generate(inputs, do_sample = True, generation_config = generation_config)
 
 
35
 
36
  return tokenizer.decode(outputs[0], skip_special_tokens = True)
37