sabssag commited on
Commit
8c457b5
·
verified ·
1 Parent(s): e0b2a67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -17
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
2
 
3
  # Initialize the tokenizer and model
@@ -5,24 +6,32 @@ model_name = 'gpt2-large'
5
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
6
  model = GPT2LMHeadModel.from_pretrained(model_name)
7
 
 
 
 
8
  # Text input for the user
9
- text = "my cat"
 
 
 
 
 
10
 
11
- # Encode input text
12
- encoded_input = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
 
 
 
 
 
 
 
13
 
14
- # Generate text
15
- output = model.generate(
16
- input_ids=encoded_input['input_ids'],
17
- attention_mask=encoded_input['attention_mask'],
18
- max_length=200, # Adjust length as needed
19
- num_return_sequences=1,
20
- no_repeat_ngram_size=2,
21
- top_p=0.95,
22
- top_k=50,
23
- pad_token_id=tokenizer.eos_token_id # Set pad_token_id to eos_token_id
24
- )
25
 
26
- # Decode generated text
27
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
28
- generated_text
 
 
 
1
+ import streamlit as st
2
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
3
 
4
  # Initialize the tokenizer and model
 
6
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
7
  model = GPT2LMHeadModel.from_pretrained(model_name)
8
 
9
+ # Set the title for the Streamlit app
10
+ st.title("GPT-2 Blog Post Generator")
11
+
12
  # Text input for the user
13
+ text = st.text_area("Enter your Topic: ")
14
+
15
+ if text:
16
+ try:
17
+ # Encode input text
18
+ encoded_input = tokenizer(text, return_tensors='pt')
19
 
20
+ # Generate text
21
+ output = model.generate(
22
+ input_ids=encoded_input['input_ids'],
23
+ max_length=200, # Adjust length as needed
24
+ num_return_sequences=1,
25
+ no_repeat_ngram_size=2,
26
+ top_p=0.95,
27
+ top_k=50
28
+ )
29
 
30
+ # Decode generated text
31
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
32
 
33
+ # Display the generated text
34
+ st.subheader("Generated Blog Post")
35
+ st.write(generated_text)
36
+ except Exception as e:
37
+ st.error(f"An error occurred: {e}")