CR7CAD committed on
Commit
1fb1e8e
·
verified ·
1 Parent(s): c99f8fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -47
app.py CHANGED
@@ -5,22 +5,21 @@ import io
5
  from gtts import gTTS
6
  import time
7
  import os
 
8
 
9
  # Set page title
10
- st.set_page_config(page_title="Story Generator for Kids")
11
 
12
  # Title and introduction
13
- st.title("Story Generator for Kids")
14
  st.write("Upload a picture and let's create a magical story!")
15
 
16
- # Initialize models with better performance
17
  @st.cache_resource
18
  def load_models():
19
  try:
20
- # Use smaller, faster models
21
  image_to_text = pipeline("image-to-text", model="microsoft/git-base-coco")
22
- # Use distilgpt2 which is smaller and faster than gpt2
23
- story_generator = pipeline("text-generation", model="distilgpt2")
24
  return image_to_text, story_generator, None
25
  except Exception as e:
26
  return None, None, str(e)
@@ -30,6 +29,8 @@ with st.spinner("Loading models..."):
30
  image_to_text, story_generator, error = load_models()
31
  if error:
32
  st.error(f"Failed to load models: {error}")
 
 
33
 
34
  # Function to generate caption from image
35
  def generate_caption(image):
@@ -38,7 +39,7 @@ def generate_caption(image):
38
  if result and len(result) > 0:
39
  caption = result[0]['generated_text']
40
  return caption, None
41
- return "An interesting image", None
42
  except Exception as e:
43
  return "An interesting image", str(e)
44
 
@@ -47,16 +48,21 @@ def generate_story(caption):
47
  try:
48
  prompt = f"Once upon a time, {caption} "
49
 
50
- # Use more efficient parameters for faster generation
 
 
 
51
  result = story_generator(
52
  prompt,
53
- max_length=50, # Reduce max length for faster generation
54
  do_sample=True,
55
- temperature=0.7, # Lower temperature for faster results
56
- top_p=0.9,
57
- num_return_sequences=1
58
  )
59
 
 
 
 
60
  if result and len(result) > 0:
61
  story = result[0]['generated_text']
62
 
@@ -64,15 +70,17 @@ def generate_story(caption):
64
  words = story.split()
65
  if len(words) > 100:
66
  words = words[:100]
67
- story = " ".join(words)
68
- # Add period to the end if needed
69
- if not story.endswith(('.', '!', '?')):
70
- story += '.'
71
 
72
  return story, None
73
- return "Story generation failed.", None
74
  except Exception as e:
75
- return f"Once upon a time, {caption}. The end.", str(e)
 
 
76
 
77
  # Function to convert text to speech
78
  def text_to_speech(text):
@@ -95,36 +103,32 @@ if uploaded_file is not None and image_to_text is not None and story_generator i
95
 
96
  # Generate button
97
  if st.button("Generate Story"):
98
- # Use progress indicator instead of spinner for better UX
99
- progress_bar = st.progress(0)
100
-
101
- # Generate caption
102
- progress_bar.progress(25)
103
- caption, caption_error = generate_caption(image)
104
- st.write("Image caption:", caption)
105
-
106
- # Generate story
107
- progress_bar.progress(50)
108
- story, story_error = generate_story(caption)
109
- word_count = len(story.split())
110
- st.write(f"### Your Story ({word_count} words)")
111
- st.write(story)
112
-
113
- # Generate audio
114
- progress_bar.progress(75)
115
- audio_file, audio_error = text_to_speech(story)
116
-
117
- if audio_file:
118
- # Display audio
119
- progress_bar.progress(100)
120
- st.write("### Listen to your story")
121
- st.audio(audio_file)
122
-
123
- # Clear progress bar when done
124
- progress_bar.empty()
125
-
126
  except Exception as e:
127
- st.error("An error occurred. Please try again with a different image.")
 
128
 
129
  st.markdown("---")
130
  st.write("Created for ISOM5240 Assignment 1")
 
5
  from gtts import gTTS
6
  import time
7
  import os
8
+ import traceback
9
 
10
  # Set page title
11
+ st.set_page_config(page_title="Image to Audio Story Generator")
12
 
13
  # Title and introduction
14
+ st.title("Image to Audio Story Generator")
15
  st.write("Upload a picture and let's create a magical story!")
16
 
17
+ # Initialize models with better error handling
18
  @st.cache_resource
19
  def load_models():
20
  try:
 
21
  image_to_text = pipeline("image-to-text", model="microsoft/git-base-coco")
22
+ story_generator = pipeline("text-generation", model="gpt2")
 
23
  return image_to_text, story_generator, None
24
  except Exception as e:
25
  return None, None, str(e)
 
29
  image_to_text, story_generator, error = load_models()
30
  if error:
31
  st.error(f"Failed to load models: {error}")
32
+ else:
33
+ st.success("Models loaded successfully!")
34
 
35
  # Function to generate caption from image
36
  def generate_caption(image):
 
39
  if result and len(result) > 0:
40
  caption = result[0]['generated_text']
41
  return caption, None
42
+ return "An interesting image", "No caption generated"
43
  except Exception as e:
44
  return "An interesting image", str(e)
45
 
 
48
  try:
49
  prompt = f"Once upon a time, {caption} "
50
 
51
+ # Debug output
52
+ st.write(f"Prompt: {prompt}")
53
+
54
+ # Generate with increased timeout and temperature
55
  result = story_generator(
56
  prompt,
57
+ max_length=100,
58
  do_sample=True,
59
+ temperature=0.9,
60
+ top_p=0.95
 
61
  )
62
 
63
+ # Debug output
64
+ st.write(f"Generation result: {result}")
65
+
66
  if result and len(result) > 0:
67
  story = result[0]['generated_text']
68
 
 
70
  words = story.split()
71
  if len(words) > 100:
72
  words = words[:100]
73
+ story = " ".join(words)
74
+ # Add period to the end if needed
75
+ if not story.endswith(('.', '!', '?')):
76
+ story += '.'
77
 
78
  return story, None
79
+ return "Story generation failed.", "No story generated"
80
  except Exception as e:
81
+ st.error(f"Error in story generation: {str(e)}")
82
+ st.error(traceback.format_exc())
83
+ return "Once upon a time... (Story generation failed)", str(e)
84
 
85
  # Function to convert text to speech
86
  def text_to_speech(text):
 
103
 
104
  # Generate button
105
  if st.button("Generate Story"):
106
+ with st.spinner("Generating your story..."):
107
+ # Generate caption
108
+ caption, caption_error = generate_caption(image)
109
+ if caption_error:
110
+ st.warning(f"Caption generation issue: {caption_error}")
111
+ st.write("Image caption:", caption)
112
+
113
+ # Generate story
114
+ story, story_error = generate_story(caption)
115
+ if story_error:
116
+ st.warning(f"Story generation issue: {story_error}")
117
+ word_count = len(story.split())
118
+ st.write(f"### Your Story ({word_count} words)")
119
+ st.write(story)
120
+
121
+ # Generate audio
122
+ audio_file, audio_error = text_to_speech(story)
123
+ if audio_error:
124
+ st.warning(f"Audio generation issue: {audio_error}")
125
+ else:
126
+ # Display audio
127
+ st.write("### Listen to your story")
128
+ st.audio(audio_file)
 
 
 
 
 
129
  except Exception as e:
130
+ st.error(f"Error processing image: {str(e)}")
131
+ st.error(traceback.format_exc())
132
 
133
  st.markdown("---")
134
  st.write("Created for ISOM5240 Assignment 1")