Spaces:
Sleeping
Sleeping
Commit
·
0dc9007
1
Parent(s):
d312778
Update app.py
Browse files
app.py
CHANGED
@@ -2,8 +2,11 @@ import streamlit as st
|
|
2 |
from transformers import pipeline
|
3 |
|
4 |
@st.cache_resource
|
5 |
-
def load_model():
|
6 |
-
|
|
|
|
|
|
|
7 |
|
8 |
def main():
|
9 |
if "generated_widget_id" not in st.session_state:
|
@@ -27,6 +30,9 @@ def main():
|
|
27 |
# Create a text input field for custom prompt
|
28 |
custom_prompt = st.text_input("Enter a custom prompt:", "")
|
29 |
|
|
|
|
|
|
|
30 |
# Check if a default prompt is selected
|
31 |
if default_prompt:
|
32 |
user_input = default_prompt
|
@@ -36,7 +42,7 @@ def main():
|
|
36 |
# Check if the user input is empty
|
37 |
if user_input:
|
38 |
# Load the Hugging Face model
|
39 |
-
generator = load_model()
|
40 |
|
41 |
# Display 'Generating Output' message
|
42 |
output_placeholder = st.empty()
|
@@ -44,7 +50,7 @@ def main():
|
|
44 |
st.write("Generating Output...")
|
45 |
|
46 |
# Generate text based on user input
|
47 |
-
generated_text = generator(user_input, max_length=
|
48 |
|
49 |
# Clear 'Generating Output' message and display the generated text
|
50 |
output_placeholder.empty()
|
|
|
2 |
from transformers import pipeline
|
3 |
|
4 |
@st.cache_resource
|
5 |
+
def load_model(advanced_mode):
|
6 |
+
if advanced_mode:
|
7 |
+
return pipeline("text-generation", model="PeterBrendan/pbjsGPT2v2")
|
8 |
+
else:
|
9 |
+
return pipeline("text-generation", model="PeterBrendan/pbjs_gpt2")
|
10 |
|
11 |
def main():
|
12 |
if "generated_widget_id" not in st.session_state:
|
|
|
30 |
# Create a text input field for custom prompt
|
31 |
custom_prompt = st.text_input("Enter a custom prompt:", "")
|
32 |
|
33 |
+
# Create a checkbox for advanced mode
|
34 |
+
advanced_mode = st.checkbox("Advanced Mode")
|
35 |
+
|
36 |
# Check if a default prompt is selected
|
37 |
if default_prompt:
|
38 |
user_input = default_prompt
|
|
|
42 |
# Check if the user input is empty
|
43 |
if user_input:
|
44 |
# Load the Hugging Face model
|
45 |
+
generator = load_model(advanced_mode)
|
46 |
|
47 |
# Display 'Generating Output' message
|
48 |
output_placeholder = st.empty()
|
|
|
50 |
st.write("Generating Output...")
|
51 |
|
52 |
# Generate text based on user input
|
53 |
+
generated_text = generator(user_input, max_length=1500, num_return_sequences=1)[0]["generated_text"]
|
54 |
|
55 |
# Clear 'Generating Output' message and display the generated text
|
56 |
output_placeholder.empty()
|