CR7CAD commited on
Commit
b038974
·
verified ·
1 Parent(s): 44b14a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -51
app.py CHANGED
@@ -4,6 +4,8 @@ from PIL import Image
4
  import io
5
  from gtts import gTTS
6
  import time
 
 
7
 
8
  # Set page title
9
  st.set_page_config(page_title="Kids Story Generator")
@@ -12,73 +14,121 @@ st.set_page_config(page_title="Kids Story Generator")
12
  st.title("Kids Story Generator")
13
  st.write("Upload a picture and let's create a magical story!")
14
 
15
- # Initialize models
16
  @st.cache_resource
17
  def load_models():
18
- image_to_text = pipeline("image-to-text", model="microsoft/git-base-coco")
19
- story_generator = pipeline("text-generation", model="gpt2")
20
- return image_to_text, story_generator
 
 
 
21
 
22
- image_to_text, story_generator = load_models()
 
 
 
 
 
 
23
 
24
  # Function to generate caption from image
25
  def generate_caption(image):
26
- caption = image_to_text(image)[0]['generated_text']
27
- return caption
 
 
 
 
 
 
28
 
29
  # Function to generate story from caption (less than 100 words)
30
  def generate_story(caption):
31
- prompt = f"Once upon a time, {caption} "
32
-
33
- # Set max_length to control story length (approximately 100 words)
34
- # Typical English word is ~5 characters, so ~500 characters ≈ 100 words
35
- story = story_generator(prompt, max_length=100, do_sample=True)[0]['generated_text']
36
-
37
- # Ensure story doesn't exceed 100 words
38
- words = story.split()
39
- if len(words) > 100:
40
- words = words[:100]
41
- story = " ".join(words)
42
- # Add period to the end if needed
43
- if not story.endswith(('.', '!', '?')):
44
- story += '.'
45
-
46
- return story
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  # Function to convert text to speech
49
  def text_to_speech(text):
50
- tts = gTTS(text=text, lang='en', slow=False)
51
- audio_file = "story_audio.mp3"
52
- tts.save(audio_file)
53
- return audio_file
 
 
 
54
 
55
  # File uploader
56
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
57
 
58
- if uploaded_file is not None:
59
  # Display the uploaded image
60
- image = Image.open(uploaded_file)
61
- st.image(image, caption='Uploaded Image', use_container_width=True)
62
-
63
- # Generate button
64
- if st.button("Generate Story"):
65
- with st.spinner("Generating your story..."):
66
- # Generate caption
67
- caption = generate_caption(image)
68
- st.write("Image caption:", caption)
69
-
70
- # Generate story
71
- story = generate_story(caption)
72
- word_count = len(story.split())
73
- st.write(f"### Your Story ({word_count} words)")
74
- st.write(story)
75
-
76
- # Generate audio
77
- audio_file = text_to_speech(story)
78
-
79
- # Display audio
80
- st.write("### Listen to your story")
81
- st.audio(audio_file)
 
 
 
 
 
 
 
 
 
 
82
 
83
  st.markdown("---")
84
- st.write("Created for ISOM5240 Assignment")
 
4
  import io
5
  from gtts import gTTS
6
  import time
7
+ import os
8
+ import traceback
9
 
10
  # Set page title
11
  st.set_page_config(page_title="Kids Story Generator")
 
14
  st.title("Kids Story Generator")
15
  st.write("Upload a picture and let's create a magical story!")
16
 
17
+ # Initialize models with better error handling
18
  @st.cache_resource
19
  def load_models():
20
+ try:
21
+ image_to_text = pipeline("image-to-text", model="microsoft/git-base-coco")
22
+ story_generator = pipeline("text-generation", model="gpt2")
23
+ return image_to_text, story_generator, None
24
+ except Exception as e:
25
+ return None, None, str(e)
26
 
27
+ # Load models with status indicator
28
+ with st.spinner("Loading models..."):
29
+ image_to_text, story_generator, error = load_models()
30
+ if error:
31
+ st.error(f"Failed to load models: {error}")
32
+ else:
33
+ st.success("Models loaded successfully!")
34
 
35
  # Function to generate caption from image
36
  def generate_caption(image):
37
+ try:
38
+ result = image_to_text(image)
39
+ if result and len(result) > 0:
40
+ caption = result[0]['generated_text']
41
+ return caption, None
42
+ return "An interesting image", "No caption generated"
43
+ except Exception as e:
44
+ return "An interesting image", str(e)
45
 
46
  # Function to generate story from caption (less than 100 words)
47
  def generate_story(caption):
48
+ try:
49
+ prompt = f"Once upon a time, {caption} "
50
+
51
+ # Debug output
52
+ st.write(f"Prompt: {prompt}")
53
+
54
+ # Generate with increased timeout and temperature
55
+ result = story_generator(
56
+ prompt,
57
+ max_length=100,
58
+ do_sample=True,
59
+ temperature=0.9,
60
+ top_p=0.95
61
+ )
62
+
63
+ # Debug output
64
+ st.write(f"Generation result: {result}")
65
+
66
+ if result and len(result) > 0:
67
+ story = result[0]['generated_text']
68
+
69
+ # Ensure story doesn't exceed 100 words
70
+ words = story.split()
71
+ if len(words) > 100:
72
+ words = words[:100]
73
+ story = " ".join(words)
74
+ # Add period to the end if needed
75
+ if not story.endswith(('.', '!', '?')):
76
+ story += '.'
77
+
78
+ return story, None
79
+ return "Story generation failed.", "No story generated"
80
+ except Exception as e:
81
+ st.error(f"Error in story generation: {str(e)}")
82
+ st.error(traceback.format_exc())
83
+ return "Once upon a time... (Story generation failed)", str(e)
84
 
85
  # Function to convert text to speech
86
  def text_to_speech(text):
87
+ try:
88
+ tts = gTTS(text=text, lang='en', slow=False)
89
+ audio_file = "story_audio.mp3"
90
+ tts.save(audio_file)
91
+ return audio_file, None
92
+ except Exception as e:
93
+ return None, str(e)
94
 
95
  # File uploader
96
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
97
 
98
+ if uploaded_file is not None and image_to_text is not None and story_generator is not None:
99
  # Display the uploaded image
100
+ try:
101
+ image = Image.open(uploaded_file)
102
+ st.image(image, caption='Uploaded Image', use_container_width=True)
103
+
104
+ # Generate button
105
+ if st.button("Generate Story"):
106
+ with st.spinner("Generating your story..."):
107
+ # Generate caption
108
+ caption, caption_error = generate_caption(image)
109
+ if caption_error:
110
+ st.warning(f"Caption generation issue: {caption_error}")
111
+ st.write("Image caption:", caption)
112
+
113
+ # Generate story
114
+ story, story_error = generate_story(caption)
115
+ if story_error:
116
+ st.warning(f"Story generation issue: {story_error}")
117
+ word_count = len(story.split())
118
+ st.write(f"### Your Story ({word_count} words)")
119
+ st.write(story)
120
+
121
+ # Generate audio
122
+ audio_file, audio_error = text_to_speech(story)
123
+ if audio_error:
124
+ st.warning(f"Audio generation issue: {audio_error}")
125
+ else:
126
+ # Display audio
127
+ st.write("### Listen to your story")
128
+ st.audio(audio_file)
129
+ except Exception as e:
130
+ st.error(f"Error processing image: {str(e)}")
131
+ st.error(traceback.format_exc())
132
 
133
  st.markdown("---")
134
+ st.write("Created for ISOM5240_Assignment1")