dtm95 committed on
Commit
f7f2eec
Β·
verified Β·
1 Parent(s): 635f631

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -24
app.py CHANGED
@@ -1,32 +1,22 @@
1
  import streamlit as st
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
 
5
  # Set up the Streamlit app
6
  st.title("πŸ§šβ€β™€οΈ Magic Story Buddy πŸ“š")
7
  st.markdown("Let's create a magical story just for you!")
8
 
9
- # Initialize the model
10
- @st.cache(allow_output_mutation=True)
11
- def load_model():
12
- model_name = "blockblockblock/Young-Children-Storyteller-Mistral-7B-bpw6"
13
- try:
14
- # Load model and tokenizer
15
- tokenizer = AutoTokenizer.from_pretrained(model_name)
16
- model = AutoModelForCausalLM.from_pretrained(model_name)
17
 
18
- # Get configuration
19
- config = AutoConfig.from_pretrained(model_name)
20
- model_type = config.model_type
21
 
22
- print(f"Loaded model type: {model_type}")
23
 
24
- return model, tokenizer
25
-
26
- except Exception as e:
27
- st.error(f"Error loading model: {e}")
28
-
29
- model, tokenizer = load_model()
30
 
31
  # User input
32
  child_name = st.text_input("What's your name, young storyteller?")
@@ -38,9 +28,9 @@ story_length = st.slider("How long should the story be?", 50, 200, 100)
38
  include_moral = st.checkbox("Include a moral lesson?")
39
 
40
  def generate_story(prompt, max_length=500):
41
- inputs = tokenizer(prompt, return_tensors="pt")
42
- outputs = model.generate(**inputs, max_length=max_length, num_return_sequences=1, do_sample=True, temperature=0.7)
43
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
44
 
45
  if st.button("Create My Story!"):
46
  if child_name and story_theme:
@@ -51,7 +41,6 @@ if st.button("Create My Story!"):
51
  - Length: About {story_length} words
52
  - Audience: Children aged 5-10
53
  - Tone: Friendly, educational, and imaginative
54
-
55
  Story:
56
  Once upon a time, in a {story_theme.lower()}, there was a brave child named {child_name}. """
57
 
 
1
  import streamlit as st
2
+ from transformers import pipeline
 
3
 
4
  # Set up the Streamlit app
5
  st.title("πŸ§šβ€β™€οΈ Magic Story Buddy πŸ“š")
6
  st.markdown("Let's create a magical story just for you!")
7
 
8
@st.cache_resource
def initialize_model():
    """Load and cache the text-generation pipeline.

    Returns:
        A transformers ``text-generation`` pipeline backed by
        EleutherAI/gpt-neo-2.7B.

    Passing the model identifier straight to ``pipeline()`` lets it load
    both the model and its tokenizer internally. This fixes the original
    NameError: the file only imports ``pipeline`` from transformers, yet
    the body referenced the unimported ``AutoTokenizer`` and
    ``AutoModelForCausalLM`` names.

    ``@st.cache_resource`` restores the caching the previous revision had
    (via the deprecated ``st.cache``), so the ~2.7B-parameter model is
    loaded once per server process instead of on every script rerun.
    """
    text_generator = pipeline("text-generation", model="EleutherAI/gpt-neo-2.7B")
    return text_generator

# Initialize the model and pipeline once at module level so the
# Streamlit callbacks below can reuse it.
pipe = initialize_model()
 
 
 
 
20
 
21
  # User input
22
  child_name = st.text_input("What's your name, young storyteller?")
 
28
  include_moral = st.checkbox("Include a moral lesson?")
29
 
30
def generate_story(prompt, max_length=500):
    """Generate a story continuation for *prompt*.

    Args:
        prompt: The seed text the model should continue.
        max_length: Maximum total token length of the generated output.

    Returns:
        The generated text as a single string.

    GPT-Neo is a plain causal language model with no chat template, so
    the prompt must be passed to the pipeline as a raw string. The
    original wrapped it in a chat-style ``messages`` list, which either
    fails (no chat template to apply) or makes ``generated_text`` a list
    of message dicts rather than the string the caller expects.
    """
    result = pipe(
        prompt,
        max_length=max_length,
        num_return_sequences=1,  # we only ever display one story
        do_sample=True,          # sampling keeps each story fresh
        temperature=0.7,
    )
    return result[0]["generated_text"]
34
 
35
  if st.button("Create My Story!"):
36
  if child_name and story_theme:
 
41
  - Length: About {story_length} words
42
  - Audience: Children aged 5-10
43
  - Tone: Friendly, educational, and imaginative
 
44
  Story:
45
  Once upon a time, in a {story_theme.lower()}, there was a brave child named {child_name}. """
46