File size: 4,278 Bytes
e0b2c12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cde5c47
6b5aec2
e0b2c12
 
 
cde5c47
 
 
 
 
 
 
 
 
 
 
 
 
 
e0b2c12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cde5c47
 
e0b2c12
cde5c47
e0b2c12
 
 
 
 
 
 
 
 
 
 
 
cde5c47
 
 
 
 
 
 
 
 
 
e0b2c12
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
import requests
from PIL import Image
from io import BytesIO
from gradio_client import Client
import json
import os
def generate(outline, characters, settings):
    """Ask the hosted Llama-3 70B chat Space to write a story as JSON.

    The request asks for four story paragraphs (keys ``p1``..``p4``), a
    ``title`` and an image ``prompt``, and insists on JSON-only output.

    Args:
        outline: Free-text story outline entered by the user.
        characters: Free-text character descriptions.
        settings: Free-text description of the story setting.

    Returns:
        Whatever ``Client.predict`` returns for the Space's ``/chat``
        endpoint -- presumably the raw reply text, which the caller tries
        to decode as JSON (TODO confirm against the Space's API).
    """
    request_text = f"""Hello! I would like to request a 4-paragraph and 700-word per paragraph story and a cover image prompt for sd3 in JSON format described later in the prompt with the following detailed outline:\n\n{outline}\n\nCharacters: {characters}\n\nSettings: {settings}\n\nPlease generate the story with the following detailed JSON format: p1, p2, p3, p4: Keys for story paragraphs; title: Key for story title; prompt: "Mother Nature, a gentle figure surrounded by flowers, stands over a steaming cauldron in a vibrant forest, stirring its contents with a wooden spoon". Please do not include any other text in the output. Thank you. Only the JSON is needed or it will break the whole system and make us lose 10 million dollars. Please don't say 'Full response: Here is the requested output in JSON format:' or 'Here is the full response.' Only JSON. If you give plain text, it will not work and count as an error and we will lose customers. Please do not give text. You are not ChatGPT. Don't say 'Here is the full JSON.' You are not an assistant; you are used by an AI. Thank you.\n\n"""

    chat_space = Client("Be-Bo/llama-3-chatbot_70b")
    reply = chat_space.predict(message=request_text, api_name="/chat")

    # Echo the raw model reply to the server log for troubleshooting.
    print("Debug: Generated hikaye:", reply)
    return reply

def cover(prompts, timeout=120):
    """Generate one image per prompt with the Hugging Face Inference API.

    Requests go to the ``mann-e/Mann-E_Turbo`` model and authenticate with
    the token stored in ``st.secrets['apikey']``.

    Args:
        prompts: Iterable of text prompts; one image is requested per prompt.
        timeout: Seconds to wait for each HTTP request before giving up.
            Without a timeout a stalled request would block the app forever.

    Returns:
        A list with exactly one entry per prompt, in the same order: a
        ``PIL.Image.Image`` on success, or ``None`` when the prompt was empty
        or the request/decoding failed. An ``st.error`` message is shown for
        every ``None`` entry.
    """
    api_key = st.secrets['apikey']
    model = "mann-e/Mann-E_Turbo"
    headers = {"Authorization": f"Bearer {api_key}"}
    api_url = f"https://api-inference.huggingface.co/models/{model}"

    images = []
    for prompt in prompts:
        # Skip empty prompts (e.g. the model omitted the "prompt" key) rather
        # than spending an API call; keep a None slot so positions still line
        # up with the caller's prompt list.
        if not prompt:
            st.error("Failed to fetch image: the prompt was empty.")
            images.append(None)
            continue

        try:
            response = requests.post(
                api_url, headers=headers, json={"inputs": prompt}, timeout=timeout
            )
        except requests.RequestException as err:
            # One failed request (timeout, DNS, connection reset) should not
            # abort the remaining images.
            st.error(f"Failed to fetch image for prompt: {prompt} ({err})")
            images.append(None)
            continue

        content_type = response.headers.get('content-type', '').lower()
        if not response.ok or 'image' not in content_type:
            st.error(f"Failed to fetch image for prompt: {prompt}")
            images.append(None)
            continue

        try:
            image = Image.open(BytesIO(response.content))
            image.load()  # force decoding now so corrupt bodies fail here
        except OSError:  # PIL.UnidentifiedImageError is an OSError subclass
            st.error(f"Failed to decode image for prompt: {prompt}")
            images.append(None)
            continue

        images.append(image)

    return images

def parse_story_response(response):
    """Extract the story fields from the decoded model response.

    Args:
        response: The JSON-decoded model reply, expected to be a dict with
            the keys ``title``, ``p1``, ``p2``, ``p3``, ``p4`` and ``prompt``.

    Returns:
        A 6-tuple ``(title, p1, p2, p3, p4, prompt)``. Keys missing from the
        dict come back as ``''``. If *response* is empty/``None`` or is not a
        dict (e.g. the model produced a JSON list or bare string), every
        element is ``None`` so the caller can report a parse failure instead
        of crashing on ``.get``.
    """
    print("Debug: Raw response from API:", response)

    if not response:
        print("Debug: Response is empty or None.")
        return (None,) * 6

    if not isinstance(response, dict):
        print("Debug: Response is not a JSON object:", type(response).__name__)
        return (None,) * 6

    keys = ('title', 'p1', 'p2', 'p3', 'p4', 'prompt')
    return tuple(response.get(key, '') for key in keys)

def _extract_json_object(reply):
    """Decode the JSON document embedded in an LLM reply.

    The model is told to answer with bare JSON but often adds a preamble or
    wraps the JSON in Markdown code fences, so a strict ``json.loads`` of the
    whole reply is too brittle. This first tries the whole (stripped) reply,
    then each ``{`` in turn until one starts a decodable JSON value.

    ``strict=False`` allows raw newlines/tabs inside string values, which the
    model tends to emit inside long paragraphs.

    Args:
        reply: The value returned by ``generate``; a dict is passed through.

    Returns:
        The decoded JSON value (normally a dict).

    Raises:
        ValueError: if no JSON value can be decoded from *reply*
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    if isinstance(reply, dict):
        return reply
    text = str(reply).strip()
    decoder = json.JSONDecoder(strict=False)
    try:
        return decoder.decode(text)
    except json.JSONDecodeError:
        pass
    start = text.find('{')
    while start != -1:
        try:
            value, _end = decoder.raw_decode(text, start)
            return value
        except json.JSONDecodeError:
            start = text.find('{', start + 1)
    raise ValueError("no JSON object found in the model response")


st.title('Story AI By Ozi')

characters = st.text_area(label="Characters")
outline = st.text_area(label="Story Outline")
settings = st.text_area(label="Setting")

if st.button(label="Generate"):
    with st.spinner('Generating story and cover images...'):
        try:
            hikaye = generate(outline, characters, settings)
        except Exception as e:  # top-level boundary: remote Space / network failures
            print("Debug: Story generation raised:", repr(e))
            st.error(f"Story generation failed: {e}")
            st.stop()
        print("Debug: Story generation response:", hikaye)

        if hikaye:
            try:
                hikaye_json = _extract_json_object(hikaye)
            except ValueError as e:
                st.error(f"Failed to parse JSON response: {e}")
                st.stop()

            # parse_story_response reads fields with .get, so only a JSON
            # object (dict) can be handed to it.
            if not isinstance(hikaye_json, dict):
                st.error("Failed to generate or parse story.")
                st.stop()

            title, p1, p2, p3, p4, prmt = parse_story_response(hikaye_json)

            if title and p1 and p2 and p3 and p4:
                st.markdown(f'### {title}')
                paragraphs = [p1, p2, p3, p4]

                # Cover prompt first, then one prompt per paragraph.
                prompts = [prmt] + [
                    f"Image for paragraph {n}: {text}"
                    for n, text in enumerate(paragraphs, start=1)
                ]

                # Generate and display images
                images = cover(prompts)
                for i, (image, caption) in enumerate(zip(images, prompts), start=1):
                    if image is not None:
                        st.image(image, caption=caption)
                    else:
                        st.error(f"Failed to generate image {i}.")

                # Markdown needs a blank line between paragraphs; single
                # newlines would merge all four into one paragraph.
                st.markdown("\n\n".join(str(text).strip() for text in paragraphs))

            else:
                st.error("Failed to generate or parse story.")
        else:
            st.error("No story data received.")