CR7CAD commited on
Commit
c99f8fd
·
verified ·
1 Parent(s): b8a0fec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -43
app.py CHANGED
@@ -5,7 +5,6 @@ import io
5
  from gtts import gTTS
6
  import time
7
  import os
8
- import traceback
9
 
10
  # Set page title
11
  st.set_page_config(page_title="Story Generator for Kids")
@@ -14,12 +13,14 @@ st.set_page_config(page_title="Story Generator for Kids")
14
  st.title("Story Generator for Kids")
15
  st.write("Upload a picture and let's create a magical story!")
16
 
17
- # Initialize models with better error handling
18
  @st.cache_resource
19
  def load_models():
20
  try:
 
21
  image_to_text = pipeline("image-to-text", model="microsoft/git-base-coco")
22
- story_generator = pipeline("text-generation", model="gpt2")
 
23
  return image_to_text, story_generator, None
24
  except Exception as e:
25
  return None, None, str(e)
@@ -37,7 +38,7 @@ def generate_caption(image):
37
  if result and len(result) > 0:
38
  caption = result[0]['generated_text']
39
  return caption, None
40
- return "An interesting image", "No caption generated"
41
  except Exception as e:
42
  return "An interesting image", str(e)
43
 
@@ -46,13 +47,14 @@ def generate_story(caption):
46
  try:
47
  prompt = f"Once upon a time, {caption} "
48
 
49
- # Generate with increased timeout and temperature
50
  result = story_generator(
51
  prompt,
52
- max_length=100,
53
  do_sample=True,
54
- temperature=0.9,
55
- top_p=0.95
 
56
  )
57
 
58
  if result and len(result) > 0:
@@ -62,15 +64,15 @@ def generate_story(caption):
62
  words = story.split()
63
  if len(words) > 100:
64
  words = words[:100]
65
- story = " ".join(words)
66
- # Add period to the end if needed
67
- if not story.endswith(('.', '!', '?')):
68
- story += '.'
69
 
70
  return story, None
71
- return "Story generation failed.", "No story generated"
72
  except Exception as e:
73
- return "Once upon a time... (Story generation failed)", str(e)
74
 
75
  # Function to convert text to speech
76
  def text_to_speech(text):
@@ -93,36 +95,36 @@ if uploaded_file is not None and image_to_text is not None and story_generator i
93
 
94
  # Generate button
95
  if st.button("Generate Story"):
96
- with st.spinner("Generating your story..."):
97
- # Generate caption
98
- caption, caption_error = generate_caption(image)
99
- if caption_error:
100
- st.warning(f"Caption generation issue: {caption_error}", icon="⚠️")
101
-
102
- # Display the caption (without debug information)
103
- st.write("Image caption:", caption)
104
-
105
- # Generate story
106
- story, story_error = generate_story(caption)
107
- if story_error and not st.session_state.get("deployed", True):
108
- st.warning(f"Story generation issue: {story_error}", icon="⚠️")
109
-
110
- # Display the story (without debug information)
111
- word_count = len(story.split())
112
- st.write(f"### Your Story ({word_count} words)")
113
- st.write(story)
114
-
115
- # Generate audio
116
- audio_file, audio_error = text_to_speech(story)
117
- if audio_error and not st.session_state.get("deployed", True):
118
- st.warning(f"Audio generation issue: {audio_error}", icon="⚠️")
119
- elif audio_file:
120
- # Display audio
121
- st.write("### Listen to your story")
122
- st.audio(audio_file)
 
123
  except Exception as e:
124
- if not st.session_state.get("deployed", True):
125
- st.error(f"Error processing image: {str(e)}")
126
 
127
  st.markdown("---")
128
  st.write("Created for ISOM5240 Assignment 1")
 
5
  from gtts import gTTS
6
  import time
7
  import os
 
8
 
9
  # Set page title
10
  st.set_page_config(page_title="Story Generator for Kids")
 
13
  st.title("Story Generator for Kids")
14
  st.write("Upload a picture and let's create a magical story!")
15
 
16
+ # Initialize models with better performance
17
  @st.cache_resource
18
  def load_models():
19
  try:
20
+ # Use smaller, faster models
21
  image_to_text = pipeline("image-to-text", model="microsoft/git-base-coco")
22
+ # Use distilgpt2 which is smaller and faster than gpt2
23
+ story_generator = pipeline("text-generation", model="distilgpt2")
24
  return image_to_text, story_generator, None
25
  except Exception as e:
26
  return None, None, str(e)
 
38
  if result and len(result) > 0:
39
  caption = result[0]['generated_text']
40
  return caption, None
41
+ return "An interesting image", None
42
  except Exception as e:
43
  return "An interesting image", str(e)
44
 
 
47
  try:
48
  prompt = f"Once upon a time, {caption} "
49
 
50
+ # Use more efficient parameters for faster generation
51
  result = story_generator(
52
  prompt,
53
+ max_length=50, # Reduce max length for faster generation
54
  do_sample=True,
55
+ temperature=0.7, # Lower temperature for faster results
56
+ top_p=0.9,
57
+ num_return_sequences=1
58
  )
59
 
60
  if result and len(result) > 0:
 
64
  words = story.split()
65
  if len(words) > 100:
66
  words = words[:100]
67
+ story = " ".join(words)
68
+ # Add period to the end if needed
69
+ if not story.endswith(('.', '!', '?')):
70
+ story += '.'
71
 
72
  return story, None
73
+ return "Story generation failed.", None
74
  except Exception as e:
75
+ return f"Once upon a time, {caption}. The end.", str(e)
76
 
77
  # Function to convert text to speech
78
  def text_to_speech(text):
 
95
 
96
  # Generate button
97
  if st.button("Generate Story"):
98
+ # Use progress indicator instead of spinner for better UX
99
+ progress_bar = st.progress(0)
100
+
101
+ # Generate caption
102
+ progress_bar.progress(25)
103
+ caption, caption_error = generate_caption(image)
104
+ st.write("Image caption:", caption)
105
+
106
+ # Generate story
107
+ progress_bar.progress(50)
108
+ story, story_error = generate_story(caption)
109
+ word_count = len(story.split())
110
+ st.write(f"### Your Story ({word_count} words)")
111
+ st.write(story)
112
+
113
+ # Generate audio
114
+ progress_bar.progress(75)
115
+ audio_file, audio_error = text_to_speech(story)
116
+
117
+ if audio_file:
118
+ # Display audio
119
+ progress_bar.progress(100)
120
+ st.write("### Listen to your story")
121
+ st.audio(audio_file)
122
+
123
+ # Clear progress bar when done
124
+ progress_bar.empty()
125
+
126
  except Exception as e:
127
+ st.error("An error occurred. Please try again with a different image.")
 
128
 
129
  st.markdown("---")
130
  st.write("Created for ISOM5240 Assignment 1")