mayf committed on
Commit
6949ffc
·
verified ·
1 Parent(s): c876f7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -32
app.py CHANGED
@@ -5,34 +5,34 @@ from PIL import Image
5
  from transformers import pipeline
6
  from gtts import gTTS
7
  import tempfile
 
 
 
 
8
 
9
  # —––––––– Page Setup —–––––––
10
  st.set_page_config(page_title="Magic Story Generator", layout="centered")
11
  st.title("📖✨ Turn Images into Children's Stories")
12
 
13
- # —––––––– Load Pipelines (cached) —–––––––
14
  @st.cache_resource(show_spinner=False)
15
- def load_pipelines():
16
- # 1) Image captioning pipeline
17
  captioner = pipeline(
18
- task="image-to-text",
19
- model="Salesforce/blip-image-captioning-base",
20
- device=-1
21
  )
22
 
23
- # 2) Story generation pipeline using verified model
24
- storyteller = pipeline(
25
- task="text2text-generation",
26
- model="laxya007/story-generator-t5-small",
27
- tokenizer="t5-small",
28
- device=-1,
29
- max_length=200,
30
- temperature=0.7,
31
- do_sample=True
32
  )
33
  return captioner, storyteller
34
 
35
- captioner, storyteller = load_pipelines()
36
 
37
  # —––––––– Main App —–––––––
38
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
@@ -43,29 +43,31 @@ if uploaded:
43
  # Generate caption
44
  with st.spinner("🔍 Generating caption..."):
45
  cap = captioner(img)
46
- caption = cap[0].get("generated_text", "").strip()
47
- if not caption:
48
- st.error("😢 Couldn't understand this image. Try another one!")
49
- st.stop()
50
  st.success(f"**Caption:** {caption}")
51
 
52
  # Generate story
53
- prompt = f"generate story: {caption}"
54
- with st.spinner("📝 Writing story..."):
 
 
 
 
55
  start = time.time()
56
- story = storyteller(prompt)[0]['generated_text']
 
 
 
 
 
 
57
  gen_time = time.time() - start
 
58
  st.text(f"⏱ Generated in {gen_time:.1f}s")
59
 
60
- # Format story output
61
- story = story.replace("<pad>", "").replace("</s>", "").strip()
62
- if story.startswith("generate story:"):
63
- story = story[15:].strip()
64
 
65
- # Word limit enforcement
66
- words = story.split()
67
- story = " ".join(words[:100]) if len(words) > 100 else story
68
-
69
  # Display story
70
  st.subheader("📚 Your Magical Story")
71
  st.write(story)
@@ -82,4 +84,3 @@ if uploaded:
82
 
83
  # Footer
84
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
85
-
 
5
  from transformers import pipeline
6
  from gtts import gTTS
7
  import tempfile
8
+ from llama_cpp import Llama
9
+
10
+ # First install required package:
11
+ # pip install llama-cpp-python
12
 
13
  # —––––––– Page Setup —–––––––
14
  st.set_page_config(page_title="Magic Story Generator", layout="centered")
15
  st.title("📖✨ Turn Images into Children's Stories")
16
 
17
+ # —––––––– Load Models (cached) —–––––––
18
  @st.cache_resource(show_spinner=False)
19
+ def load_models():
20
+ # 1) Image captioning model
21
  captioner = pipeline(
22
+ "image-to-text",
23
+ model="Salesforce/blip-image-captioning-base"
 
24
  )
25
 
26
+ # 2) GGUF Story Model
27
+ storyteller = Llama(
28
+ model_path="DavidAU/L3-Grand-Story-Darkness-MOE-4X8-24.9B-e32-GGUF",
29
+ n_ctx=2048,
30
+ n_threads=4,
31
+ n_gpu_layers=0 # Set based on your GPU capacity
 
 
 
32
  )
33
  return captioner, storyteller
34
 
35
+ captioner, storyteller = load_models()
36
 
37
  # —––––––– Main App —–––––––
38
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
 
43
  # Generate caption
44
  with st.spinner("🔍 Generating caption..."):
45
  cap = captioner(img)
46
+ caption = cap[0]['generated_text']
 
 
 
47
  st.success(f"**Caption:** {caption}")
48
 
49
  # Generate story
50
+ prompt = f"""Below is an image description. Write a children's story based on it.
51
+
52
+ Image Description: {caption}
53
+ Story:"""
54
+
55
+ with st.spinner("📝 Crafting magical story..."):
56
  start = time.time()
57
+ output = storyteller(
58
+ prompt=prompt,
59
+ max_tokens=500,
60
+ temperature=0.7,
61
+ top_p=0.9,
62
+ repeat_penalty=1.1
63
+ )
64
  gen_time = time.time() - start
65
+ story = output['choices'][0]['text'].strip()
66
  st.text(f"⏱ Generated in {gen_time:.1f}s")
67
 
68
+ # Post-process story
69
+ story = story.split("###")[0].strip() # Remove any trailing artifacts
 
 
70
 
 
 
 
 
71
  # Display story
72
  st.subheader("📚 Your Magical Story")
73
  st.write(story)
 
84
 
85
  # Footer
86
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")