File size: 4,057 Bytes
555b62c
38fcee7
 
555b62c
 
 
40047db
555b62c
a9fb110
 
 
 
 
 
 
 
 
 
40047db
555b62c
 
40047db
 
 
 
 
a9fb110
 
38fcee7
 
 
a9fb110
40047db
a9fb110
555b62c
 
40047db
 
 
555b62c
40047db
 
 
 
 
 
 
 
 
38fcee7
 
40047db
 
555b62c
 
 
40047db
 
 
 
 
 
 
 
 
 
38fcee7
40047db
 
 
 
 
 
 
 
555b62c
 
40047db
 
 
 
 
 
 
 
 
 
38fcee7
40047db
 
 
 
 
38fcee7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
555b62c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import streamlit as st
from PIL import Image, ImageOps

import requests

API_URL = 'https://pic-gai.up.railway.app'
# API_URL = 'http://localhost:8000'

def gallery(column, images):
    """Render `images` on the page as a grid with `column` items per row.

    `images` must support `len()` and slicing. Each row gets a fresh set of
    `column` Streamlit columns; a final short row leaves its trailing
    columns empty.
    """
    for start in range(0, len(images), column):
        row = images[start:start + column]
        cells = st.columns(column)
        # zip stops at the shorter sequence, so a partial last row is fine.
        for cell, picture in zip(cells, row):
            cell.image(picture)

st.title('CollageAI')

# Free-text description of the desired design; sent to the API as `prompt`.
user_prompt = st.text_area(
    "Describe the design you'd like to create:",
    placeholder="For our anniversary, I want to write a card to my partner to celebrate our love and share all the things I adore about them."
)

# The uploader has no `type=` filter, so arbitrary files can be selected;
# anything Pillow cannot decode is reported and skipped below.
uploaded_images = st.file_uploader("Choose photos", accept_multiple_files=True)

images = []
# `or []` tolerates an empty/None result when nothing has been uploaded.
for uploaded in uploaded_images or []:
    try:
        # exif_transpose applies the EXIF orientation tag so that previews
        # are displayed upright.
        images.append(ImageOps.exif_transpose(Image.open(uploaded)))
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # data it cannot identify; truncated files also surface as OSError.
        st.warning(f"Skipping '{uploaded.name}': not a readable image.")

if images:
    gallery(4, images)

# Number of successfully decoded photos; sent to the API as `photos_count`.
photos_count = len(images)

# Single submit button; one click requests templates, stickers and keywords.
generate_button = st.button('Generate')

# Seconds to wait for the API before giving up. Without a timeout,
# requests.get() can block the script indefinitely on a stalled connection.
REQUEST_TIMEOUT = 60


def _fetch_json(endpoint, params):
    """GET `API_URL + endpoint` and return the decoded JSON body.

    Any failure (network error, non-200 status, undecodable body) is shown
    on the page with `st.error` and reported to the caller as `None`.
    """
    url = f"{API_URL}{endpoint}"
    try:
        response = requests.get(url, params=params, timeout=REQUEST_TIMEOUT)
    except requests.RequestException as exc:
        st.error(f"Request failed: {exc}")
        return None

    if response.status_code != 200:
        st.error(f"Error: {response.status_code}")
        return None

    try:
        return response.json()
    except ValueError:
        # requests' JSON decode errors derive from ValueError.
        st.error("Error: the server returned an invalid response.")
        return None


def _show_image_results(endpoint, params, *, url_key, noun):
    """Fetch image results from `endpoint` and render up to 8 in a gallery.

    `url_key` is the field of each result item that holds the image URL;
    `noun` ("templates" / "stickers") is used in the spinner and heading.
    """
    with st.container():
        with st.spinner(f'Generating {noun}...'):
            payload = _fetch_json(endpoint, params)
            if payload is None:
                return

            # `or []` also covers an explicit null `result`.
            items = payload.get('result') or []
            # Drop items without a URL; st.image cannot render None.
            image_urls = [item.get(url_key) for item in items]
            image_urls = [link for link in image_urls if link]

            if image_urls:
                st.subheader(f'Generated {noun}')
                gallery(4, image_urls[:8])
            else:
                st.warning('No images were generated. Please try again with a different prompt.')


if generate_button:
    # A whitespace-only prompt is treated the same as an empty one.
    if user_prompt and user_prompt.strip():
        # Query parameters shared by all three API calls.
        params = {
            'prompt': user_prompt,
            'photos_count': photos_count
        }

        # Drop parameters that have no value.
        params = {k: v for k, v in params.items() if v is not None}

        st.markdown("---")

        # Templates
        _show_image_results(
            '/api/templates', params, url_key='image_medium', noun='templates'
        )

        # Stickers
        _show_image_results(
            '/api/stickers', params, url_key='image_url', noun='stickers'
        )

        # Keywords
        with st.container():
            payload = _fetch_json('/api/analyze_prompt', params)
            if payload is not None:
                st.subheader('Keywords based on prompt')
                keywords = payload.get('keywords', [])
                st.write(keywords)
    else:
        st.warning('Please enter a prompt before submitting.')