File size: 6,006 Bytes
86b351a
 
 
 
 
 
4310b90
 
45fabe9
21eb680
85d7f84
86b351a
21eb680
86b351a
4310b90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86b351a
 
 
4310b90
86b351a
 
4310b90
86b351a
 
 
 
 
 
4310b90
d9ec72e
86b351a
 
21eb680
85d7f84
 
21eb680
 
 
 
 
 
85d7f84
21eb680
86b351a
 
 
 
d9ec72e
86b351a
 
 
 
4310b90
86b351a
 
4310b90
86b351a
 
4310b90
 
 
86b351a
4310b90
86b351a
 
4310b90
86b351a
 
 
 
 
c665bd4
86b351a
 
4310b90
86b351a
 
 
4310b90
86b351a
 
 
 
 
 
4310b90
 
 
86b351a
 
 
 
4310b90
86b351a
21eb680
 
 
 
 
85d7f84
 
21eb680
 
 
 
 
 
 
 
86b351a
 
 
 
d9ec72e
86b351a
 
 
 
4310b90
86b351a
 
4310b90
86b351a
 
4310b90
 
 
86b351a
4310b90
86b351a
 
4310b90
86b351a
 
 
 
 
 
 
 
 
4310b90
86b351a
 
 
 
 
4310b90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from google.genai import types
import os
from PIL import Image
from io import BytesIO
from datetime import datetime
import logging
import asyncio
import gradio as gr
from config import settings
from services.google import GoogleClientFactory
from agent.utils import with_retries

# Logger for this module (named after the module's import path).
logger = logging.getLogger(__name__)

# Harm categories whose Gemini safety filter is relaxed for every request
# issued from this module.
_RELAXED_HARM_CATEGORIES = (
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
)

# One SafetySetting per category above, each with threshold "BLOCK_NONE"
# (block none), passed to every generate_content call below.
safety_settings = [
    types.SafetySetting(category=harm_category, threshold="BLOCK_NONE")
    for harm_category in _RELAXED_HARM_CATEGORIES
]


async def generate_image(prompt: str) -> tuple[str, str] | tuple[None, None]:
    """
    Generate an image using Google's Gemini model and save it to generated/images directory.

    Only the first image part of the response is saved (as PNG).

    Args:
        prompt (str): The text prompt to generate the image from

    Returns:
        tuple: ``(filepath, prompt)`` where ``filepath`` is the saved PNG, or
        ``(None, None)`` if the call failed or the response contained no image
        (a Gradio warning is shown in the latter case).
    """
    # Ensure the generated/images directory exists
    output_dir = "generated/images"
    os.makedirs(output_dir, exist_ok=True)

    logger.info(f"Generating image with prompt: {prompt}")

    try:
        async with GoogleClientFactory.image() as client:
            response = await with_retries(
                lambda: client.models.generate_content(
                    model="gemini-2.0-flash-preview-image-generation",
                    contents=prompt,
                    config=types.GenerateContentConfig(
                        response_modalities=["TEXT", "IMAGE"],
                        safety_settings=safety_settings,
                    ),
                )
            )

        # A filtered response can lack candidates, content or parts; treat all
        # of those as "no image" rather than raising into the generic handler
        # (which would hide the censorship warning below).
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts if content is not None else None) or []

        for part in parts:
            if part.inline_data is None:
                continue  # TEXT parts may accompany the image; skip them

            # Microseconds keep concurrent generations from overwriting
            # each other's files.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            filepath = os.path.join(output_dir, f"gemini_{timestamp}.png")

            # PIL decodes lazily, so the save (decode + disk write) is the
            # expensive step; run it in a worker thread to keep the loop free.
            image = Image.open(BytesIO(part.inline_data.data))
            await asyncio.to_thread(image.save, filepath, "PNG")
            logger.info(f"Image saved to: {filepath}")
            return filepath, prompt

        gr.Warning("Image was censored by Google!")
        logger.error("No image was generated in the response.")
        return None, None

    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"Error generating image: {e}")
        return None, None


async def modify_image(
    image_path: str, modification_prompt: str
) -> tuple[str, str] | tuple[None, None]:
    """
    Modify an existing image using Google's Gemini model based on a text prompt.

    Only the first image part of the response is saved (as PNG).

    Args:
        image_path (str): Path to the existing image file
        modification_prompt (str): The text prompt describing how to modify the image

    Returns:
        tuple: ``(filepath, modification_prompt)`` where ``filepath`` is the
        saved PNG, or ``(None, None)`` on any failure (missing input file, API
        error, or no image in the response). Always a 2-tuple, so callers can
        unpack the result unconditionally.
    """
    # Ensure the generated/images directory exists
    output_dir = "generated/images"
    os.makedirs(output_dir, exist_ok=True)

    logger.info(f"Modifying current scene image with prompt: {modification_prompt}")

    # Check if the input image exists
    if not os.path.isfile(image_path):
        logger.error(f"Error: Image file not found at {image_path}")
        # Same shape as every other failure path (was a bare None).
        return None, None

    try:
        # Image.open only reads the header; force the full decode in a worker
        # thread so the event loop isn't blocked. Pillow also releases the
        # file after loading a single-frame image.
        input_image = Image.open(image_path)
        await asyncio.to_thread(input_image.load)

        async with GoogleClientFactory.image() as client:
            # Make the API call with both text and image
            response = await with_retries(
                lambda: client.models.generate_content(
                    model="gemini-2.0-flash-preview-image-generation",
                    contents=[modification_prompt, input_image],
                    config=types.GenerateContentConfig(
                        response_modalities=["TEXT", "IMAGE"],
                        safety_settings=safety_settings,
                    ),
                ),
            )

        # A filtered response can lack candidates, content or parts; treat all
        # of those as "no image" rather than raising into the generic handler
        # (which would hide the censorship warning below).
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts if content is not None else None) or []

        for part in parts:
            if part.inline_data is None:
                continue  # TEXT parts may accompany the image; skip them

            # Microseconds keep concurrent edits from overwriting each other.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            filepath = os.path.join(output_dir, f"gemini_modified_{timestamp}.png")

            # Save the modified image; decode + write happen in a worker thread.
            modified_image = Image.open(BytesIO(part.inline_data.data))
            await asyncio.to_thread(modified_image.save, filepath, "PNG")
            logger.info(f"Modified image saved to: {filepath}")
            return filepath, modification_prompt

        gr.Warning("Updated image was censored by Google!")
        logger.error("No modified image was generated in the response.")
        return None, None

    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"Error modifying image: {e}")
        return None, None


if __name__ == "__main__":

    async def _demo() -> None:
        """Generate a sample image end to end and report where it was saved."""
        sample_prompt = "A Luke Skywalker half height sprite with white background for visual novel game"
        generated_image_path, _ = await generate_image(sample_prompt)
        if generated_image_path:
            print(f"Successfully generated image: {generated_image_path}")

        # Optional follow-up edit (disabled: it makes a second API call):
        # if generated_image_path:
        #     modification_prompt = "Now the house is destroyed, and the jawas are running away"
        #     modified_image_path, _ = await modify_image(generated_image_path, modification_prompt)
        #     if modified_image_path:
        #         print(f"Successfully modified image: {modified_image_path}")

    # Show the module's info/error logs when run as a script.
    logging.basicConfig(level=logging.INFO)
    # generate_image is a coroutine function: calling it directly only creates
    # a never-awaited coroutine, so it must be driven by an event loop.
    asyncio.run(_demo())