File size: 10,716 Bytes
7c2b30b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import os
import configparser
import PIL
from PIL import Image
from typing import Any, List, Optional, Literal
from pydantic import BaseModel, Field
import google.generativeai as genai
import vertexai
from vertexai.preview.vision_models import ImageGenerationModel, GeneratedImage
from vertexai.preview.vision_models import Image as VertexImage

# Load configuration
config = configparser.ConfigParser()
config.read('config.ini')

api_key = config['credentials']['gemini_api_key']
proj_id = config['credentials']['gcp_project_id']
info = {"PROJECT_ID": proj_id, "LOCATION": "us-central1", "API_KEY": api_key}

# Create directories for temporary and output images if they do not exist
os.makedirs('images/temp', exist_ok=True)
os.makedirs('images/output', exist_ok=True)

class BannerGenerator(BaseModel):
    """Class responsible for generating banners using generative models."""

    CONFIGS: dict
    topic: str
    images: List[str] | None = Field(default=None)
    aspect_ratio: Optional[Literal["1:1", "9:16", "16:9", "4:3", "3:4"]] = None  # Added aspect_ratio attribute
    text_model: str = "gemini-1.5-flash"
    image_model: str = "imagen-3.0-generate-001"
    edit_model: str = "imagegeneration@006"
    pm: Any = None
    im: Any = None
    em: Any = None
    text_v0: str = None
    text_v1: str = None
    text_v2: str = None
    text_v3: str = None
    img_response_v1: Any = None
    img_response_v2: Any = None
    img_response_v3: Any = None
    launch_state: bool = False

    def __launch(self):
        """Launches the generative models and sets up environment."""
        if not self.launch_state:
            vertexai.init(project=self.CONFIGS['PROJECT_ID'], location=self.CONFIGS['LOCATION'])
            genai.configure(api_key=self.CONFIGS['API_KEY'])
            self.pm = genai.GenerativeModel(self.text_model)
            self.im = ImageGenerationModel.from_pretrained(self.image_model)
            self.em = ImageGenerationModel.from_pretrained(self.edit_model)
            self.launch_state = True
            print("Model Launch successful!")

    def load_images(self) -> List[PIL.Image.Image]:
        """Loads images from file paths provided in the `images` attribute."""
        self.__launch()
        loaded_images = []

        for image_path in self.images:
            img = PIL.Image.open(image_path)
            if img.mode == 'RGBA':
                img = img.convert('RGB')  # Convert to RGB if needed
            loaded_images.append(img)

        return loaded_images

    def extract_image_information(self) -> str:
        """Extracts information from images using the generative text model."""
        images = self.load_images()
        extraction_prompt = '''Examine the set of images to extract information about the product (name, logo) in less than 80 words.'''
        model_input = [extraction_prompt] + images
        response = self.pm.generate_content(model_input)
        print("Attached images examined!")
        return response.text

    def extract_information(self) -> None:
        """Extracts information from the given topic and images using a detailed analysis."""
        self.__launch()

        out_text = f"""Deep analyze text from retail advertising, marketing psychology, and thoroughly researched marketing studies perspective: {self.topic}
        Extract the following information:
        0. Product: Product name, brand and supplier, logo, tagline, size, packaging if available
        1. Objective: 1 word for primary goal of the banner. Example - Awareness, Engagement, Conversion, Branding
        2. Festival: Event or occasion it may be tied to. Example - Christmas, Diwali, Black Friday, Summer Sale, New Year, Generic
        3. Headline: Suggest a main text that captures attention. Example - Discover [product] for [festival], Shop now and save!, Limited time offer, Innovate your life with [product]
        4. Subheadline: Optional additional supporting information to clarify the offer. Example - Get 50% off until [date], Exclusive deal for festive season, Hurry offer ends soon
        5. CTA: Add a call to action. Example - Buy now, Shop the collection, Discover More, Sign up today
        6. Color Scheme: Use color palette based on audience, occasion, or product tone. Example - Red & Gold (Festive, Urgency), Blue & White (Trust, Calm), Green & Brown (Eco-friendly, Natural), Black & White (Elegant, Minimal)
        7. Promotional offer: Suggest 1 best promotional offer. Example - MAX ₹99 OFF, UP TO 60% OFF, UNDER ₹999, MIN ₹10 OFF, MIN 20% OFF, STARTS @₹99, FLAT ₹100 OFF, FLAT 20% OFF, ₹499 STORE, BUY 2 GET 1 FREE
        8. Background color gradient: Dynamic color generation to match overall look and feel
        9. Background theme: Festival oriented or generic if no festival
        """
        self.text_v0 = self.pm.generate_content(out_text).text

        # Information consolidation
        out_text = f"Respond concisely and summarize in Python dictionary format only this: {self.text_v0}"
        if self.images:
            image_info = self.extract_image_information()
            out_text += ' Product insights: ' + image_info
            print(f"Product insights: {image_info}")

        self.text_v1 = self.pm.generate_content(out_text).text[9:-5]

        # Scrapper to ensure data integrity and consistency
        out_text = f"Respond concisely by scrapping all unavailable information in Python dictionary format only this: {self.text_v1}"
        self.text_v2 = self.pm.generate_content(out_text).text

        print("Information collection complete!")

    def create_text_prompt(self) -> None:
        """Creates a text prompt based on the extracted information."""
        out_text = f"""Task: Fill in the values in this json: {self.text_v2}
        Guidelines:
        1. It will be used to generate an ads banner.
        2. Ensure it has all details pair-wise meticulously captured.
        3. All unknown/missing/unprovided variables are replaced with the attributes of the most probable shopper for that product.
        4. Recheck and identify all ambiguity or any text that leads to uncertainty.
        5. Replace all uncertainty with targeted values that make the most sense for the given product.
        6. Quantify everything possible, like high, medium, and lows to percentage values based on marketing and psychometric research studies.
        7. All KPIs and qualitative measures are to be used subcontextually only. Remove any details about statistical testing or names of any performance KPIs.
        8. Avoid sentences and use only necessary keywords.
        9. Remove all redundant key-value pairs.
        """
        self.text_v3 = self.pm.generate_content(out_text).text
        print("Information processed!")

    def generate_image(self) -> str:
        """Generates an image based on the given prompt and saves it in images/temp/."""
        prompt = f"""Realistic, subcontextually implied qualitative attributes inspired, excellent image quality ad capturing every detail in json:{self.text_v3}"""

        # Adding the aspect ratio parameter to the image generation
        if self.aspect_ratio:
            self.img_response_v1 = self.im.generate_images(prompt=prompt, aspect_ratio=self.aspect_ratio)
        else:
            self.img_response_v1 = self.im.generate_images(prompt=prompt)

        # Save the generated image to images/temp/
        temp_img_path = 'images/temp/temp.jpg'
        if os.path.exists(temp_img_path):
            os.remove(temp_img_path)
        self.img_response_v1.images[0].save(temp_img_path)

        print("Image v1 generated!")
        return temp_img_path

    def identify_lags(self) -> str:
        """Identifies quality issues in the generated image and provides suggestions for improvement."""
        prompt = f"""Be direct. Quality check the banner out of 10 on:
        1. Promotional offer present as per instructions below
        2. Ensure ALL texts pass grammatical checks
        3. Color pallette as per instructions below
        4. Occassion or festival theme is present as per instructions below

        ONLY USE INFORMATION FROM {self.text_v3}. Don't DO NOT make up colors, promo or occasion. Make sure the promo and color pallete is followed as per above instructions.

        Precisely point out errors and corresponding actions to fix the image where score is below 8.
        Do not output anything about elements that need no change.
        """
        temp_img_path = 'images/temp/temp.jpg'
        response = self.pm.generate_content([prompt, PIL.Image.open(temp_img_path)])
        print(f'Lags identified: {response.text}')
        return response.text

    def fix_image(self, retest: bool = False) -> str:
        """Attempts to fix the identified lags in the generated image and saves it."""
        prompt = f'Realistic, subcontextually implied qualitative attributes inspired, excellent image quality ad by: {self.identify_lags()}'
        temp_img_path = 'images/temp/temp.jpg'

        base_image = VertexImage.load_from_file(location=temp_img_path)
        self.img_response_v2 = self.em.edit_image(
            base_image=base_image,
            prompt=prompt,
            edit_mode="inpainting-insert",
            mask_mode="background"
        )
        self.img_response_v2.images[0].save(temp_img_path)
        print("Image v2 generated!")

        if retest:
            prompt = f'Realistic, subcontextually implied qualitative attributes inspired, excellent image quality ad edit by: {self.identify_lags()}'
            self.img_response_v3 = self.em.edit_image(
                base_image=VertexImage.load_from_file(location=temp_img_path),
                prompt=prompt,
                edit_mode="inpainting-insert",
                mask_mode="background"
            )
            self.img_response_v3.images[0].save(temp_img_path)
            print("Image v3 generated!")

        # Save final image to output directory with a name corresponding to the topic
        output_filename = ''.join(e for e in self.topic if e.isalnum() or e == ' ')
        output_filename = output_filename[:50].strip().replace(' ', '_') + ".jpg"
        output_img_path = os.path.join('images/output', output_filename)
        self.img_response_v2.images[0].save(output_img_path)
        print(f"Final image saved as: {output_img_path}")

        return output_img_path

    def execute(self, QC=False) -> str:
        """Executes the entire workflow to generate and refine the banner."""
        self.extract_information()
        self.create_text_prompt()
        self.generate_image()
        final_image_path = self.fix_image(retest=QC)
        return final_image_path