PeterBrendan commited on
Commit
0dc9007
·
1 Parent(s): d312778

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -2,8 +2,11 @@ import streamlit as st
2
  from transformers import pipeline
3
 
4
  @st.cache_resource
5
- def load_model():
6
- return pipeline("text-generation", model="PeterBrendan/pbjs_gpt2")
 
 
 
7
 
8
  def main():
9
  if "generated_widget_id" not in st.session_state:
@@ -27,6 +30,9 @@ def main():
27
  # Create a text input field for custom prompt
28
  custom_prompt = st.text_input("Enter a custom prompt:", "")
29
 
 
 
 
30
  # Check if a default prompt is selected
31
  if default_prompt:
32
  user_input = default_prompt
@@ -36,7 +42,7 @@ def main():
36
  # Check if the user input is empty
37
  if user_input:
38
  # Load the Hugging Face model
39
- generator = load_model()
40
 
41
  # Display 'Generating Output' message
42
  output_placeholder = st.empty()
@@ -44,7 +50,7 @@ def main():
44
  st.write("Generating Output...")
45
 
46
  # Generate text based on user input
47
- generated_text = generator(user_input, max_length=700, num_return_sequences=1)[0]["generated_text"]
48
 
49
  # Clear 'Generating Output' message and display the generated text
50
  output_placeholder.empty()
 
2
  from transformers import pipeline
3
 
4
  @st.cache_resource
5
+ def load_model(advanced_mode):
6
+ if advanced_mode:
7
+ return pipeline("text-generation", model="PeterBrendan/pbjsGPT2v2")
8
+ else:
9
+ return pipeline("text-generation", model="PeterBrendan/pbjs_gpt2")
10
 
11
  def main():
12
  if "generated_widget_id" not in st.session_state:
 
30
  # Create a text input field for custom prompt
31
  custom_prompt = st.text_input("Enter a custom prompt:", "")
32
 
33
+ # Create a checkbox for advanced mode
34
+ advanced_mode = st.checkbox("Advanced Mode")
35
+
36
  # Check if a default prompt is selected
37
  if default_prompt:
38
  user_input = default_prompt
 
42
  # Check if the user input is empty
43
  if user_input:
44
  # Load the Hugging Face model
45
+ generator = load_model(advanced_mode)
46
 
47
  # Display 'Generating Output' message
48
  output_placeholder = st.empty()
 
50
  st.write("Generating Output...")
51
 
52
  # Generate text based on user input
53
+ generated_text = generator(user_input, max_length=1500, num_return_sequences=1)[0]["generated_text"]
54
 
55
  # Clear 'Generating Output' message and display the generated text
56
  output_placeholder.empty()