File size: 4,349 Bytes
555b62c
38fcee7
eca8836
57e1146
38fcee7
6f710fc
1c74bd7
555b62c
a9fb110
 
 
 
 
 
 
 
 
 
97b9bcc
 
 
e620492
97b9bcc
 
 
40047db
555b62c
eca8836
 
1c74bd7
 
 
eca8836
 
 
1c74bd7
 
 
 
eca8836
 
 
 
 
 
 
 
 
 
 
 
97b9bcc
 
753ec7b
 
 
 
eca8836
 
 
 
 
 
 
 
 
 
40047db
555b62c
eca8836
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40047db
eca8836
 
38fcee7
1c74bd7
 
 
 
 
 
 
 
 
 
 
 
 
40047db
 
 
eca8836
1c74bd7
eca8836
 
 
40047db
 
eca8836
 
40047db
555b62c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import streamlit as st
from PIL import Image, ImageOps
from internal.api import APIClient
import os

# Client for the CollageAI backend. Set COLLAGE_AI_API_URL to point at another
# deployment (e.g. "http://localhost:3000" for local development) instead of
# editing this file; the hosted service is used when the variable is unset.
client = APIClient(
    os.environ.get("COLLAGE_AI_API_URL", "https://collage-ai.onrender.com")
)

def gallery(column, images):
    """Render *images* as a grid with *column* Streamlit columns per row.

    Every row, including a trailing partial one, is created with its own
    ``st.columns(column)`` call; images fill the row's slots left to right.

    Args:
        column: Number of columns in each row.
        images: Sequence of anything ``st.image`` accepts (PIL images, URLs, ...).
    """
    for start in range(0, len(images), column):
        row_slots = st.columns(column)
        for slot, picture in zip(row_slots, images[start:start + column]):
            slot.image(picture)

def resize_image(image_file):
    """Save a 224x224 copy of an uploaded image to disk and reopen it for upload.

    The copy is written to the current working directory under the upload's
    base filename, so its format follows that filename's extension.

    Args:
        image_file: Readable binary file-like object with a ``name`` attribute
            (here, a Streamlit upload).

    Returns:
        A binary file object opened on the saved copy. The caller must close it
        and delete the file at its ``name``.
    """
    # Apply the EXIF orientation before resizing so the uploaded copy matches
    # the upright preview, which is rendered through ImageOps.exif_transpose.
    upright = ImageOps.exif_transpose(Image.open(image_file))
    # Fixed square output; the aspect ratio is intentionally not preserved.
    resized = upright.resize((224, 224))
    # The upload name is client-supplied: keep only its last path component so
    # the copy cannot be written outside the working directory.
    # NOTE(review): uploads that share a filename still overwrite each other's copy.
    path = os.path.basename(image_file.name)
    resized.save(path)
    return open(path, 'rb')

st.title('CollageAI')

# Page state. The form below reassigns user_images, user_prompt and submitted
# on every script run; the remaining names are only filled in after submission.
user_images, user_prompt, uploaded_images = None, None, None
captions, keywords = [], []
submitted = False

with st.form("user_input_form"):
    user_images = st.file_uploader("Choose your photos", accept_multiple_files=True)
    user_prompt = st.text_area(
        "Describe the design you'd like to create:",
        placeholder=(
            "For our anniversary, I want to write a card to my partner to "
            "celebrate our love and share all the things I adore about them."
        ),
    )
    submitted = st.form_submit_button("Generate")

def _show_suggestions(spinner_text, heading, fetch, error_prefix):
    """Fetch suggested image URLs with *fetch* and render up to eight of them.

    Shows a warning when *fetch* returns nothing and an error message when it
    raises, so one failing suggestion call does not stop the rest of the page.

    Args:
        spinner_text: Message shown while *fetch* runs.
        heading: Subheader rendered above the gallery.
        fetch: Zero-argument callable returning a sequence of image URLs.
        error_prefix: Text placed before the exception in the error message.
    """
    with st.container():
        with st.spinner(spinner_text):
            try:
                image_urls = fetch()
                if image_urls:
                    st.subheader(heading)
                    gallery(4, image_urls[:8])
                else:
                    st.warning('No images were generated. Please try again with a different prompt.')
            except Exception as e:
                st.error(f"{error_prefix}: {e}")


# Check form
if submitted:
    if user_images:
        with st.container():
            with st.spinner('Uploading images...'):
                resized_images = []
                try:
                    for image in user_images:
                        resized_images.append(resize_image(image))
                    uploaded_images = client.upload_images(resized_images)
                except Exception as e:
                    st.error(f"Error uploading images: {e}")
                finally:
                    # Release the resized copies even when resizing or the upload
                    # failed part-way, so no handles or files are left behind.
                    for resized in resized_images:
                        resized.close()
                        try:
                            os.remove(resized.name)
                        except OSError:
                            # Best-effort cleanup: uploads with the same filename
                            # share one path, so it may already be gone.
                            pass

            # Display the photo gallery
            st.subheader('Your photos:')
            try:
                images = [ImageOps.exif_transpose(Image.open(image)) for image in user_images]
            except Exception as e:
                # PIL raises for files it cannot decode (e.g. non-image uploads);
                # report it instead of aborting the whole page.
                st.error(f"Error displaying images: {e}")
            else:
                gallery(4, images)
    else:
        st.warning('Please upload at least one image before submitting.')

    if user_prompt:
        if uploaded_images:
            # Analysis
            with st.spinner('Analyzing prompt...'):
                try:
                    analysis = client.analyze_prompt(user_prompt, uploaded_images)
                    keywords = analysis.get("keywords")
                    # Fall back to the empty-list default so the suggestion calls
                    # below never receive None.
                    captions = analysis.get("captions") or []
                    if captions:
                        st.subheader('Captions of your photos')
                        st.write(captions)

                    if keywords:
                        st.subheader('Keywords based on your photos and prompt')
                        st.write(keywords)
                    else:
                        st.warning('No keywords were generated. Please try again with a different prompt.')

                except Exception as e:
                    st.error(f"Error analyzing prompt: {e}")

        # Stickers
        _show_suggestions(
            'Generating stickers...',
            'Stickers suggestions',
            lambda: client.suggest_stickers(user_prompt, captions),
            "Error generating stickers",
        )

        # Templates
        _show_suggestions(
            'Generating templates...',
            'Template suggestions',
            lambda: client.suggest_templates(user_prompt, captions),
            "Error generating templates",
        )

    else:
        st.warning('Please enter a prompt before submitting.')