sabssag commited on
Commit
1c861fc
·
verified ·
1 Parent(s): 5f37b42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -19,27 +19,29 @@ def generate_text(text):
19
 
20
  # Generate text
21
  output = model.generate(
22
- input_ids=encoded_input,
23
- max_length=100, # Specify the max length for the generated text
24
- num_return_sequences=1, # Number of sequences to generate
25
- no_repeat_ngram_size=2, # Avoid repeating n-grams of length 2
26
- top_k=50, # Limits the sampling pool to top_k tokens
27
- top_p=0.95, # Cumulative probability threshold for nucleus sampling
28
- temperature=0.7, # Controls the randomness of predictions
29
- do_sample=True, # Enable sampling
30
- attention_mask=encoded_input.new_ones(encoded_input.shape),
31
- pad_token_id=tokenizer.eos_token_id # Use the end-of-sequence token as padding
32
- )
33
 
34
  # Decode generated text
35
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
36
- return generate_text
37
 
38
- # Display the generated text
39
- st.subheader("Generated Blog Post")
40
- st.write(generated_text)
41
  except Exception as e:
42
  st.error(f"An error occurred: {e}")
 
 
43
  if st.button("Generate"):
44
  generated_text = generate_text(text)
45
- st.write(generated_text)
 
 
 
 
19
 
20
  # Generate text
21
  output = model.generate(
22
+ input_ids=encoded_input['input_ids'],
23
+ max_length=100, # Specify the max length for the generated text
24
+ num_return_sequences=1, # Number of sequences to generate
25
+ no_repeat_ngram_size=2, # Avoid repeating n-grams of length 2
26
+ top_k=50, # Limits the sampling pool to top_k tokens
27
+ top_p=0.95, # Cumulative probability threshold for nucleus sampling
28
+ temperature=0.7, # Controls the randomness of predictions
29
+ do_sample=True, # Enable sampling
30
+ attention_mask=encoded_input['attention_mask'], # Correct attention mask
31
+ pad_token_id=tokenizer.eos_token_id # Use the end-of-sequence token as padding
32
+ )
33
 
34
  # Decode generated text
35
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
36
+ return generated_text
37
 
 
 
 
38
  except Exception as e:
39
  st.error(f"An error occurred: {e}")
40
+ return None
41
+
42
  if st.button("Generate"):
43
  generated_text = generate_text(text)
44
+ if generated_text:
45
+ # Display the generated text
46
+ st.subheader("Generated Blog Post")
47
+ st.write(generated_text)