sabssag committed
Commit e0b2a67 · verified · 1 Parent(s): 9afd04f

Update app.py

Files changed (1): app.py +20 -37
app.py CHANGED
@@ -1,4 +1,3 @@
-import streamlit as st
 from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
 # Initialize the tokenizer and model
@@ -6,40 +5,24 @@ model_name = 'gpt2-large'
 tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 model = GPT2LMHeadModel.from_pretrained(model_name)
 
-# Set the title for the Streamlit app
-st.title("GPT-2 Blog Post Generator")
-
 # Text input for the user
-text = st.text_area("Enter your Topic: ")
-
-if text:
-    try:
-        # Encode input text
-        encoded_input = tokenizer(text, return_tensors='pt')
-
-        # Generate text
-        output = model.generate(
-            input_ids=encoded_input['input_ids'],
-            max_length=200,  # Adjust length as needed
-            num_return_sequences=1,
-            no_repeat_ngram_size=2,
-            top_p=0.95,
-            top_k=50
-        )
-
-        # Decode generated text
-        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-
-        # Display the generated text
-        st.subheader("Generated Blog Post")
-        st.write(generated_text)
-    except Exception as e:
-        st.error(f"An error occurred: {e}")
-
-# Add instructions
-st.write("""
-Enter a topic or a starting sentence in the text area above, and the GPT-2 model will generate a blog post for you.
-""")
-
-# Streamlit instructions
-st.write("To run this app, use the command: `streamlit run <script_name>.py`")
+text = "my cat"
+
+# Encode input text
+encoded_input = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
+
+# Generate text
+output = model.generate(
+    input_ids=encoded_input['input_ids'],
+    attention_mask=encoded_input['attention_mask'],
+    max_length=200,  # Adjust length as needed
+    num_return_sequences=1,
+    no_repeat_ngram_size=2,
+    top_p=0.95,
+    top_k=50,
+    pad_token_id=tokenizer.eos_token_id  # Set pad_token_id to eos_token_id
+)
+
+# Decode generated text
+generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+generated_text
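
For reference, a minimal runnable sketch of the updated generation flow, with three hedged adjustments beyond the committed file: `tokenizer.pad_token = tokenizer.eos_token` is set before encoding, because GPT-2's tokenizer ships without a pad token and `transformers` raises a ValueError when `padding=True` is requested without one; `do_sample=True` is added, since `top_p`/`top_k` only affect generation when sampling is enabled (the commit as written decodes greedily); and the trailing bare `generated_text` expression is replaced with `print`, since a bare expression only displays output in a notebook. These are assumptions about intent, not part of the commit.

from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load tokenizer and model (matches the committed file)
model_name = 'gpt2-large'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Assumption: GPT-2 has no pad token; reuse EOS so padding=True does not raise
tokenizer.pad_token = tokenizer.eos_token

# Encode the prompt (padding is a no-op for a single sequence)
text = "my cat"
encoded_input = tokenizer(text, return_tensors='pt', padding=True, truncation=True)

output = model.generate(
    input_ids=encoded_input['input_ids'],
    attention_mask=encoded_input['attention_mask'],
    max_length=200,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    do_sample=True,  # assumption: enables top_p/top_k sampling; the commit omits this
    top_p=0.95,
    top_k=50,
    pad_token_id=tokenizer.eos_token_id,
)

generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)  # assumption: print so the script shows output outside a notebook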