Spaces:
Running
Running
from google.genai import types | |
import os | |
from PIL import Image | |
from io import BytesIO | |
from datetime import datetime | |
import logging | |
import asyncio | |
import gradio as gr | |
from config import settings | |
from services.google import GoogleClientFactory | |
from agent.utils import with_retries | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Safety configuration handed to the generate_content calls in this module:
# one SafetySetting per harm category below, each with the "BLOCK_NONE"
# threshold.
safety_settings = [
    types.SafetySetting(category=harm_category, threshold="BLOCK_NONE")
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
async def generate_image(prompt: str) -> tuple[str | None, str | None]:
    """
    Generate an image using Google's Gemini model and save it to generated/images directory.

    Args:
        prompt (str): The text prompt to generate the image from

    Returns:
        tuple: ``(filepath, prompt)`` for the first image part found in the
        response, or ``(None, None)`` if the response contained no image or
        an error occurred. Errors are logged, never raised to the caller.
    """
    # Ensure the generated/images directory exists
    output_dir = "generated/images"
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Generating image with prompt: {prompt}")
    try:
        async with GoogleClientFactory.image() as client:
            response = await with_retries(
                lambda: client.models.generate_content(
                    model="gemini-2.0-flash-preview-image-generation",
                    contents=prompt,
                    config=types.GenerateContentConfig(
                        response_modalities=["TEXT", "IMAGE"],
                        safety_settings=safety_settings,
                    ),
                )
            )
            # candidates, content and parts may each be missing/empty (e.g. a
            # blocked response). Treat that as "no image" so the user-facing
            # warning below fires instead of an IndexError/TypeError being
            # reported as a generic failure.
            candidates = response.candidates or []
            content = candidates[0].content if candidates else None
            parts = (content.parts if content is not None else None) or []
            for part in parts:
                if part.inline_data is None:
                    continue
                # Microsecond-resolution timestamp: concurrent calls within
                # the same second must not overwrite each other's file.
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                filename = f"gemini_{timestamp}.png"
                filepath = os.path.join(output_dir, filename)
                # Save the image off the event loop; encoding/writing blocks.
                image = Image.open(BytesIO(part.inline_data.data))
                await asyncio.to_thread(image.save, filepath, "PNG")
                logger.info(f"Image saved to: {filepath}")
                return filepath, prompt
            # Reaching here means no part carried image data.
            gr.Warning("Image was censored by Google!")
            logger.error("No image was generated in the response.")
            return None, None
    except Exception as e:
        # Best-effort boundary: report failure via the return value, but keep
        # the traceback in the log for diagnosis.
        logger.exception(f"Error generating image: {e}")
        return None, None
async def modify_image(
    image_path: str, modification_prompt: str
) -> tuple[str | None, str | None]:
    """
    Modify an existing image using Google's Gemini model based on a text prompt.

    Args:
        image_path (str): Path to the existing image file
        modification_prompt (str): The text prompt describing how to modify the image

    Returns:
        tuple: ``(filepath, modification_prompt)`` for the first image part
        found in the response, or ``(None, None)`` if the input file is
        missing, the response contained no image, or an error occurred.
        Errors are logged, never raised to the caller.
    """
    # Ensure the generated/images directory exists
    output_dir = "generated/images"
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Modifying current scene image with prompt: {modification_prompt}")
    # Check if the input image exists
    if not os.path.exists(image_path):
        logger.error(f"Error: Image file not found at {image_path}")
        # Same 2-tuple shape as every other exit so callers can always unpack.
        return None, None
    try:
        async with GoogleClientFactory.image() as client:
            # Keep the input image open only for the duration of the API call
            # so its file handle is released afterwards.
            with Image.open(image_path) as input_image:
                # Make the API call with both text and image
                response = await with_retries(
                    lambda: client.models.generate_content(
                        model="gemini-2.0-flash-preview-image-generation",
                        contents=[modification_prompt, input_image],
                        config=types.GenerateContentConfig(
                            response_modalities=["TEXT", "IMAGE"],
                            safety_settings=safety_settings,
                        ),
                    ),
                )
            # candidates, content and parts may each be missing/empty (e.g. a
            # blocked response). Treat that as "no image" so the user-facing
            # warning below fires instead of an IndexError/TypeError being
            # reported as a generic failure.
            candidates = response.candidates or []
            content = candidates[0].content if candidates else None
            parts = (content.parts if content is not None else None) or []
            for part in parts:
                if part.inline_data is None:
                    continue
                # Microsecond-resolution timestamp: concurrent calls within
                # the same second must not overwrite each other's file.
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                filename = f"gemini_modified_{timestamp}.png"
                filepath = os.path.join(output_dir, filename)
                # Save the modified image off the event loop.
                modified_image = Image.open(BytesIO(part.inline_data.data))
                await asyncio.to_thread(modified_image.save, filepath, "PNG")
                logger.info(f"Modified image saved to: {filepath}")
                return filepath, modification_prompt
            # Reaching here means no part carried image data.
            gr.Warning("Updated image was censored by Google!")
            logger.error("No modified image was generated in the response.")
            return None, None
    except Exception as e:
        # Best-effort boundary: report failure via the return value, but keep
        # the traceback in the log for diagnosis.
        logger.exception(f"Error modifying image: {e}")
        return None, None
if __name__ == "__main__":
    # Example usage
    sample_prompt = "A Luke Skywalker half height sprite with white background for visual novel game"
    # generate_image is a coroutine function: calling it only creates a
    # coroutine object, so it has to be run on an event loop to do anything.
    # It returns a (filepath, prompt) pair, (None, None) on failure.
    generated_image_path, _ = asyncio.run(generate_image(sample_prompt))
    # if generated_image_path:
    #     # Example modification
    #     modification_prompt = "Now the house is destroyed, and the jawas are running away"
    #     modified_image_path = asyncio.run(
    #         modify_image(generated_image_path, modification_prompt)
    #     )[0]
    #     if modified_image_path:
    #         print(f"Successfully modified image: {modified_image_path}")