Amit Gazal committed
Commit 666c963 · 1 Parent(s): e25fc54
.gradio/flagged/First Output/f0fa5bc42029b76a623f/image.webp ADDED
.gradio/flagged/Second Output/1c38f7ceb54aaa9b3c55/image.webp ADDED
.gradio/flagged/Third Output/bc5418b161ee6a0f583f/image.webp ADDED
.gradio/flagged/dataset1.csv ADDED
@@ -0,0 +1,2 @@
+ input_image,"Holiday (e.g. Christmas, New Year's, etc.)",Optional Message,First Output,Second Output,Third Output,timestamp
+ .gradio/flagged/input_image/aeff2e0970e0cc20c2f3/Linda-Sobolewski-Photography-Family-Session-00002-1024x683.jpg,,,.gradio/flagged/First Output/f0fa5bc42029b76a623f/image.webp,.gradio/flagged/Second Output/1c38f7ceb54aaa9b3c55/image.webp,.gradio/flagged/Third Output/bc5418b161ee6a0f583f/image.webp,2024-12-10 15:59:53.431810
.gradio/flagged/input_image/aeff2e0970e0cc20c2f3/Linda-Sobolewski-Photography-Family-Session-00002-1024x683.jpg ADDED
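For quick inspection of the flagged samples, the new CSV can be read directly; a minimal sketch (pandas is assumed to be available, it is not listed in requirements.txt, and paths are relative to the repo root):

# Minimal sketch for inspecting the flagged Gradio samples (assumes pandas is installed).
import pandas as pd

flagged = pd.read_csv(".gradio/flagged/dataset1.csv")
# Columns: input_image, the holiday text, the optional message, the three output image paths, and a timestamp.
print(flagged[["input_image", "First Output", "Second Output", "Third Output", "timestamp"]])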
app.py CHANGED
@@ -4,6 +4,21 @@ import matplotlib.pyplot as plt
  import torch
  from torchvision import transforms
  from transformers import AutoModelForImageSegmentation
+ from openai import OpenAI
+ import os
+ import base64
+ import io
+ import requests
+ import numpy as np
+ from scipy import ndimage
+
+ IDEOGRAM_API_KEY = os.getenv('IDEOGRAM_API_KEY')
+ IDEOGRAM_URL = "https://api.ideogram.ai/edit"
+
+ client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
+ # Constants should be in UPPERCASE
+ GPT_MODEL_NAME = "gpt-4o"
+ GPT_MAX_TOKENS = 500
 
  model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
  torch.set_float32_matmul_precision(['high', 'highest'][0])
@@ -11,7 +26,51 @@ if torch.cuda.is_available():
      model = model.to('cuda')
  model.eval()
 
- def remove_background(input_image, holiday, message):
+ GPT_PROMPT = '''
+ I work with a tool that knows how to edit backgrounds.
+ I want your help with a prompt.
+ I want to adjust their background to have Christmas vibes.
+ For example, if you see a tree there, cover it in snow,
+ add Christmas lights to some of the objects in the background, and maybe add a few elements like a Christmas tree, but take into consideration the perspective and the logic of the image.
+ '''
+
+ def image_to_prompt(image: Image.Image) -> str:
+     base64_image = encode_image(image)
+
+     messages = [{
+         "role": "user",
+         "content": [
+             {"type": "text", "text": GPT_PROMPT},
+             {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}
+         ]
+     }]
+
+     response = client.chat.completions.create(
+         model=GPT_MODEL_NAME,
+         messages=messages,
+         max_tokens=GPT_MAX_TOKENS
+     )
+
+     full_response = response.choices[0].message.content
+     return full_response
+
+ def encode_image(image: Image.Image) -> str:
+     """Convert a PIL Image to a base64-encoded string.
+
+     Args:
+         image (PIL.Image.Image): The PIL Image to encode
+
+     Returns:
+         str: Base64-encoded image string
+     """
+     # Create a temporary buffer to save the image
+     buffer = io.BytesIO()
+     # Save the image as PNG to the buffer
+     image.save(buffer, format='PNG')
+     # Get the bytes from the buffer and encode to base64
+     return base64.b64encode(buffer.getvalue()).decode('utf-8')
+
+ def remove_background(input_image):
      image_size = (1024, 1024)
      # Transform the input image
      transform_image = transforms.Compose([
@@ -41,15 +100,72 @@ def remove_background(input_image, holiday, message):
      inverted_mask = Image.eval(mask, lambda x: 255 - x)  # Invert the mask
      only_background_image.putalpha(inverted_mask)
 
-     first_output_image = result_image
-     second_output_image = only_background_image
-     third_output_image = result_image
+     return result_image, only_background_image, mask
+
+ def modify_background(image: Image.Image, mask: Image.Image, prompt: str) -> Image.Image:
+     # Convert PIL images to bytes
+     image_buffer = io.BytesIO()
+     image.save(image_buffer, format='PNG')
+     image_bytes = image_buffer.getvalue()
+
+     mask_buffer = io.BytesIO()
+     mask.save(mask_buffer, format='PNG')
+     mask_bytes = mask_buffer.getvalue()
+
+     # Create the files dictionary with actual bytes data
+     files = {
+         "image_file": ("image.png", image_bytes, "image/png"),
+         "mask": ("mask.png", mask_bytes, "image/png")  # You might want to send a different mask file
+     }
+
+     payload = {
+         "prompt": prompt,  # Use the actual prompt parameter
+         "model": "V_2",
+         "magic_prompt_option": "ON",
+         "num_images": 1,
+         "style_type": "REALISTIC"
+     }
+     headers = {"Api-Key": IDEOGRAM_API_KEY}
+
+     response = requests.post(IDEOGRAM_URL, data=payload, files=files, headers=headers)
+
+     if response.status_code == 200:
+         # Assuming the API returns an image in the response
+         response_data = response.json()
+         # You'll need to handle the response according to Ideogram's API specification
+         # This is a placeholder - adjust according to the actual API response format
+         result_image_url = response_data.get('data')[0].get('url')
+         if result_image_url:
+             result_response = requests.get(result_image_url)
+             return Image.open(io.BytesIO(result_response.content))
+
+     raise Exception(f"Failed to modify background: {response.text}")
+
+ def dilate_mask(mask: Image.Image) -> Image.Image:
+     # Convert mask to numpy array
+     mask_array = np.array(mask)
+
+     # Apply maximum filter using scipy.ndimage
+     dilated_mask = ndimage.maximum_filter(mask_array, size=20)
+
+     # Convert back to PIL Image
+     return Image.fromarray(dilated_mask.astype(np.uint8))
+
+ def run_flow(input_image, holiday, message):
+     prompt = image_to_prompt(input_image)
+     print(prompt)
+     result_image, only_background_image, mask = remove_background(input_image)
+     dilated_mask = dilate_mask(mask)
+     modified_image = modify_background(input_image, dilated_mask, prompt)
+     first_output_image = mask
+     second_output_image = dilated_mask
+     third_output_image = modified_image
 
      return first_output_image, second_output_image, third_output_image
 
  # Replace the demo interface
  demo = gr.Interface(
-     fn=remove_background,
+     fn=run_flow,
      inputs=[
          gr.Image(type="pil"),
          gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
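One detail worth noting in the new run_flow path is the dilate_mask step: the segmentation mask is grown with a grayscale maximum filter before being sent along to the Ideogram edit call, presumably to leave some margin around the segmentation boundary. A small self-contained sketch of that operation (toy mask and illustrative sizes only, not part of the commit):

# Standalone sketch mirroring dilate_mask in app.py: a maximum filter with a 20-pixel
# window grows the white region of an "L"-mode mask by roughly 10 px on each side.
import numpy as np
from PIL import Image
from scipy import ndimage

mask = Image.new("L", (64, 64), 0)
mask.paste(255, (24, 24, 40, 40))  # toy 16x16 "subject" blob

dilated = ndimage.maximum_filter(np.array(mask), size=20)
print(np.array(mask).sum() // 255, dilated.sum() // 255)  # dilated area is noticeably larger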
requirements.txt CHANGED
@@ -5,4 +5,6 @@ pillow
  kornia
  transformers
  timm
- matplotlib
+ matplotlib
+ openai
+ requests
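Note that the updated app.py also imports numpy and scipy, which are not added here; they may already be available transitively in the Space image, but a quick import check (a hypothetical helper, not part of the commit) can confirm the environment before launching:

# Hypothetical environment check: verify that the runtime imports used by app.py resolve.
import importlib

for pkg in ["openai", "requests", "matplotlib", "numpy", "scipy"]:
    importlib.import_module(pkg)
print("all runtime imports available")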