Spaces:
Running
Running
File size: 6,006 Bytes
86b351a 4310b90 45fabe9 21eb680 85d7f84 86b351a 21eb680 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 d9ec72e 86b351a 21eb680 85d7f84 21eb680 85d7f84 21eb680 86b351a d9ec72e 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 86b351a c665bd4 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 86b351a 21eb680 85d7f84 21eb680 86b351a d9ec72e 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 86b351a 4310b90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
from google.genai import types
import os
from PIL import Image
from io import BytesIO
from datetime import datetime
import logging
import asyncio
import gradio as gr
from config import settings
from services.google import GoogleClientFactory
from agent.utils import with_retries
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Safety configuration passed with every generate_content request in this
# module: each listed harm category uses the "BLOCK_NONE" threshold.
safety_settings = [
    types.SafetySetting(category=harm_category, threshold="BLOCK_NONE")
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
async def generate_image(prompt: str) -> tuple[str, str] | tuple[None, None]:
    """
    Generate an image using Google's Gemini model and save it to generated/images directory.

    Args:
        prompt (str): The text prompt to generate the image from

    Returns:
        tuple: ``(filepath, prompt)`` where ``filepath`` is the path of the
        saved PNG and ``prompt`` is the input prompt echoed back, or
        ``(None, None)`` if the response contained no image or an error
        occurred. Errors are logged, never raised.
    """
    # Ensure the generated/images directory exists
    output_dir = "generated/images"
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Generating image with prompt: {prompt}")
    try:
        async with GoogleClientFactory.image() as client:
            response = await with_retries(
                lambda: client.models.generate_content(
                    model="gemini-2.0-flash-preview-image-generation",
                    contents=prompt,
                    config=types.GenerateContentConfig(
                        response_modalities=["TEXT", "IMAGE"],
                        safety_settings=safety_settings,
                    ),
                )
            )

            # candidates / content / parts may each be missing (e.g. when the
            # request is blocked); treat any missing level as "no parts" so the
            # user gets the warning below instead of a generic error.
            candidates = response.candidates or []
            content = candidates[0].content if candidates else None
            parts = (content.parts if content is not None else None) or []

            # Save the first part that carries image bytes and return.
            for part in parts:
                if part.inline_data is None:
                    continue
                # Create a filename with timestamp
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                filename = f"gemini_{timestamp}.png"
                filepath = os.path.join(output_dir, filename)
                # Save the image; the write runs in a worker thread so the
                # event loop is not blocked by file I/O.
                image = Image.open(BytesIO(part.inline_data.data))
                await asyncio.to_thread(image.save, filepath, "PNG")
                logger.info(f"Image saved to: {filepath}")
                return filepath, prompt

            # Reaching this point means no part contained image data.
            gr.Warning("Image was censored by Google!")
            logger.error("No image was generated in the response.")
            return None, None
    except Exception as e:
        # Best-effort boundary: log with traceback and report failure.
        logger.exception(f"Error generating image: {e}")
        return None, None
async def modify_image(
    image_path: str, modification_prompt: str
) -> tuple[str, str] | tuple[None, None]:
    """
    Modify an existing image using Google's Gemini model based on a text prompt.

    Args:
        image_path (str): Path to the existing image file
        modification_prompt (str): The text prompt describing how to modify the image

    Returns:
        tuple: ``(filepath, modification_prompt)`` where ``filepath`` is the
        path of the saved modified PNG, or ``(None, None)`` if the input file
        does not exist, the response contained no image, or an error occurred.
        Errors are logged, never raised.
    """
    # Ensure the generated/images directory exists
    output_dir = "generated/images"
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Modifying current scene image with prompt: {modification_prompt}")
    # Check if the input image exists
    if not os.path.exists(image_path):
        logger.error(f"Error: Image file not found at {image_path}")
        # Same 2-tuple shape as every other failure path, so callers that
        # unpack the result do not crash on a bare None.
        return None, None
    try:
        async with GoogleClientFactory.image() as client:
            # Load the input image; the context manager closes the underlying
            # file once the API call has finished with it.
            with Image.open(image_path) as input_image:
                # Make the API call with both text and image
                response = await with_retries(
                    lambda: client.models.generate_content(
                        model="gemini-2.0-flash-preview-image-generation",
                        contents=[modification_prompt, input_image],
                        config=types.GenerateContentConfig(
                            response_modalities=["TEXT", "IMAGE"],
                            safety_settings=safety_settings,
                        ),
                    ),
                )

            # candidates / content / parts may each be missing (e.g. when the
            # request is blocked); treat any missing level as "no parts" so the
            # user gets the warning below instead of a generic error.
            candidates = response.candidates or []
            content = candidates[0].content if candidates else None
            parts = (content.parts if content is not None else None) or []

            # Save the first part that carries image bytes and return.
            for part in parts:
                if part.inline_data is None:
                    continue
                # Create a filename with timestamp
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                filename = f"gemini_modified_{timestamp}.png"
                filepath = os.path.join(output_dir, filename)
                # Save the modified image; the write runs in a worker thread
                # so the event loop is not blocked by file I/O.
                modified_image = Image.open(BytesIO(part.inline_data.data))
                await asyncio.to_thread(modified_image.save, filepath, "PNG")
                logger.info(f"Modified image saved to: {filepath}")
                return filepath, modification_prompt

            # Reaching this point means no part contained image data.
            gr.Warning("Updated image was censored by Google!")
            logger.error("No modified image was generated in the response.")
            return None, None
    except Exception as e:
        # Best-effort boundary: log with traceback and report failure.
        logger.exception(f"Error modifying image: {e}")
        return None, None
if __name__ == "__main__":
    # Example usage
    sample_prompt = "A Luke Skywalker half height sprite with white background for visual novel game"
    # generate_image is a coroutine function, so it must be driven by an event
    # loop; calling it directly only creates a coroutine object and never runs.
    # It returns a (filepath, prompt) pair, or (None, None) on failure.
    generated_image_path, _ = asyncio.run(generate_image(sample_prompt))
    # if generated_image_path:
    #     # Example modification
    #     modification_prompt = "Now the house is destroyed, and the jawas are running away"
    #     modified_image_path, _ = asyncio.run(
    #         modify_image(generated_image_path, modification_prompt)
    #     )
    #     if modified_image_path:
    #         print(f"Successfully modified image: {modified_image_path}")
|