simonlee-cb commited on
Commit
eca8836
·
1 Parent(s): eb7c63c

feat: update UI for image captioning

Browse files
app.py CHANGED
@@ -1,10 +1,8 @@
1
  import streamlit as st
2
  from PIL import Image, ImageOps
 
3
 
4
- import requests
5
-
6
- API_URL = 'https://pic-gai.up.railway.app'
7
- # API_URL = 'http://localhost:8000'
8
 
9
  def gallery(column, images):
10
  groups = []
@@ -18,94 +16,83 @@ def gallery(column, images):
18
 
19
  st.title('CollageAI')
20
 
21
- # Input field for user prompt
22
- user_prompt = st.text_area(
23
- "Describe the design you'd like to create:",
24
- placeholder="For our anniversary, I want to write a card to my partner to celebrate our love and share all the things I adore about them."
25
- )
26
-
27
- uploaded_images = st.file_uploader("Choose photos", accept_multiple_files=True)
28
- if uploaded_images:
29
- images = [Image.open(image) for image in uploaded_images]
30
- images = [ImageOps.exif_transpose(image) for image in images]
31
- gallery(4, images)
32
-
33
- # pick number of photos
34
- photos_count = len(uploaded_images)
35
-
36
- # Submit buttons for templates and stickers
37
- generate_button = st.button('Generate')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- if generate_button:
40
  if user_prompt:
41
- # Prepare the params with the user prompt
42
- params = {
43
- 'prompt': user_prompt,
44
- 'photos_count': photos_count
45
- }
46
-
47
- # remove empty params
48
- params = {k: v for k, v in params.items() if v is not None}
 
 
 
 
 
 
 
 
49
 
50
- st.markdown("---")
 
51
 
52
  # Templates
53
  with st.container():
54
- # Define the FastAPI server URL for templates
55
- url = f"{API_URL}/api/templates"
56
-
57
  with st.spinner('Generating templates...'):
58
- # Make a request to the FastAPI server
59
- response = requests.get(url, params=params)
60
-
61
- # Display the response in the appropriate output container
62
- if response.status_code == 200:
63
- templates = response.json().get('result', [])
64
- image_urls = [template.get('image_medium') for template in templates]
65
-
66
- if image_urls:
67
- st.subheader('Generated templates')
68
- gallery(4, image_urls[:8])
69
  else:
70
  st.warning('No images were generated. Please try again with a different prompt.')
71
- else:
72
- st.error(f"Error: {response.status_code}")
73
 
 
74
  with st.container():
75
- # Define the FastAPI server URL for templates
76
- url = f"{API_URL}/api/stickers"
77
-
78
  with st.spinner('Generating stickers...'):
79
- # Make a request to the FastAPI server
80
- response = requests.get(url, params=params)
81
-
82
- # Display the response in the appropriate output container
83
- if response.status_code == 200:
84
- stickers = response.json().get('result', [])
85
- image_urls = [sticker.get('image_url') for sticker in stickers]
86
-
87
- if image_urls:
88
- st.subheader('Generated stickers')
89
- gallery(4, image_urls[:8])
90
  else:
91
  st.warning('No images were generated. Please try again with a different prompt.')
92
- else:
93
- st.error(f"Error: {response.status_code}")
94
-
95
- # Keywords
96
- with st.container():
97
- # Define the FastAPI server URL for keywords
98
- url = f"{API_URL}/api/analyze_prompt"
99
-
100
- # Make a request to the FastAPI server
101
- response = requests.get(url, params=params)
102
-
103
- # Display the response in the appropriate output container
104
- if response.status_code == 200:
105
- st.subheader('Keywords based on prompt')
106
- keywords = response.json().get('keywords', [])
107
- st.write(keywords)
108
- else:
109
- st.error(f"Error: {response.status_code}")
110
  else:
111
  st.warning('Please enter a prompt before submitting.')
 
1
  import streamlit as st
2
  from PIL import Image, ImageOps
3
+ from internal.api import APIClient
4
 
5
+ client = APIClient("http://localhost:3000")
 
 
 
6
 
7
  def gallery(column, images):
8
  groups = []
 
16
 
17
  st.title('CollageAI')
18
 
19
+ user_images = None
20
+ user_prompt = None
21
+ uploaded_images = []
22
+ image_captions_dict = {}
23
+ submitted = False
24
+
25
+ with st.form("user_input_form"):
26
+ user_images = st.file_uploader("Choose your photos", accept_multiple_files=True)
27
+ user_prompt = st.text_area(
28
+ "Describe the design you'd like to create:",
29
+ placeholder="For our anniversary, I want to write a card to my partner to celebrate our love and share all the things I adore about them."
30
+ )
31
+ submitted = st.form_submit_button("Generate")
32
+
33
+ # Check form
34
+ if submitted:
35
+ if user_images:
36
+ with st.container():
37
+ with st.spinner('Uploading images...'):
38
+ try:
39
+ uploaded_images = client.upload_images(user_images)
40
+ except Exception as e:
41
+ st.error(f"Error uploading images: {e}")
42
+
43
+ # Display the photo gallery
44
+ st.subheader('Your photos:')
45
+ images = [Image.open(image) for image in user_images]
46
+ images = [ImageOps.exif_transpose(image) for image in images]
47
+ gallery(4, images)
48
+ else:
49
+ st.warning('Please upload at least one image before submitting.')
50
 
 
51
  if user_prompt:
52
+ if uploaded_images:
53
+ # Analysis
54
+ with st.spinner('Analyzing prompt...'):
55
+ try:
56
+ analysis = client.analyze_prompt(user_prompt, uploaded_images)
57
+ keywords = analysis.get("keywords")
58
+ captions = analysis.get("captions")
59
+ if captions:
60
+ st.subheader('Captions of your photos')
61
+ st.write(captions)
62
+
63
+ if keywords:
64
+ st.subheader('Keywords based on your photos and prompt')
65
+ st.write(keywords)
66
+ else:
67
+ st.warning('No keywords were generated. Please try again with a different prompt.')
68
 
69
+ except Exception as e:
70
+ st.error(f"Error analyzing prompt: {e}")
71
 
72
  # Templates
73
  with st.container():
 
 
 
74
  with st.spinner('Generating templates...'):
75
+ try:
76
+ template_image_urls = client.suggest_templates(user_prompt, uploaded_images)
77
+ if template_image_urls:
78
+ st.subheader('Template suggestions')
79
+ gallery(4, template_image_urls[:8])
 
 
 
 
 
 
80
  else:
81
  st.warning('No images were generated. Please try again with a different prompt.')
82
+ except Exception as e:
83
+ st.error(f"Error generating templates: {e}")
84
 
85
+ # Stickers
86
  with st.container():
 
 
 
87
  with st.spinner('Generating stickers...'):
88
+ try:
89
+ sticker_image_urls = client.suggest_stickers(user_prompt, uploaded_images)
90
+ if sticker_image_urls:
91
+ st.subheader('Stickers suggestions')
92
+ gallery(4, sticker_image_urls[:8])
 
 
 
 
 
 
93
  else:
94
  st.warning('No images were generated. Please try again with a different prompt.')
95
+ except Exception as e:
96
+ st.error(f"Error generating stickers: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  else:
98
  st.warning('Please enter a prompt before submitting.')
internal/__init__.py ADDED
File without changes
internal/__pycache__/APIClient.cpython-311.pyc ADDED
Binary file (4.7 kB). View file
 
internal/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (174 Bytes). View file
 
internal/__pycache__/api.cpython-311.pyc ADDED
Binary file (4.42 kB). View file
 
internal/api.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ class APIClient():
4
+ def __init__(self, host):
5
+ self.host = host
6
+
7
+ def url_with_path(self, path):
8
+ return f"{self.host}{path}"
9
+
10
+ def upload_images(self, images):
11
+ uploaded_images = []
12
+ url = self.url_with_path("/api/upload_images")
13
+ images_metadata = ([("images", (image.name, image)) for image in images])
14
+ print("Uploading images...")
15
+ print(f"url: {url}")
16
+ print(f"metadata: {images_metadata}")
17
+ response = requests.post(url, files=images_metadata)
18
+ uploaded_images = response.json().get('result', [])
19
+ return uploaded_images
20
+
21
+ def suggest_templates(self, prompt, images):
22
+ url = self.url_with_path("/api/templates")
23
+ image_urls = [image.get('image_url') for image in images]
24
+
25
+ response = requests.get(url, json={'prompt': prompt, 'image_urls': image_urls})
26
+ templates = response.json().get('result', [])
27
+ template_image_urls = [template.get('image_medium') for template in templates]
28
+ return template_image_urls
29
+
30
+ def suggest_stickers(self, prompt, images):
31
+ url = self.url_with_path("/api/stickers")
32
+ image_urls = [image.get('image_url') for image in images]
33
+
34
+ response = requests.get(url, json={'prompt': prompt, 'image_urls': image_urls})
35
+ stickers = response.json().get('result', [])
36
+ sticker_image_urls = [sticker.get('image_url') for sticker in stickers]
37
+ return sticker_image_urls
38
+
39
+ def analyze_prompt(self, prompt, images):
40
+ url = self.url_with_path("/api/analyze_prompt")
41
+ image_urls = [image.get('image_url') for image in images]
42
+
43
+ response = requests.get(url, json={'prompt': prompt, 'image_urls': image_urls})
44
+ return response.json()