Spaces:
Sleeping
Sleeping
# This wrapper provides several features:
# 1. An `ImageGenerationParams` dataclass to handle parameters with default values
# 2. An `ImageGenerationResult` class to wrap the API response
# 3. The main `ImagenWrapper` class with:
#    - Proper initialization with error handling
#    - Logging support
#    - Two methods for generation:
#      - `generate()` using the `ImageGenerationParams` class
#      - `generate_simple()` for a more straightforward interface
# Here's how to use it:
#
#     # Example usage:
#     # Initialize the wrapper (replace the URL with your own deployment's URL)
#     wrapper = ImagenWrapper("https://bcdb8b7f9c4a57127c.gradio.live/")
#
#     # Method 1: Using ImageGenerationParams
#     params = ImageGenerationParams(
#         prompt="A beautiful sunset over mountains",
#         width=512,
#         height=512,
#     )
#     result = wrapper.generate(params)
#
#     # Method 2: Using generate_simple
#     result = wrapper.generate_simple(
#         prompt="A beautiful sunset over mountains",
#         width=512,
#         height=512,
#     )
#
#     # Access the results
#     print(f"Image URL: {result.image_url}")
#     print(f"Seed used: {result.seed}")
# The wrapper includes:
# - Type hints for better IDE support
# - Error handling and logging
# - Basic parameter validation (the prompt must be non-empty)
# - Flexible parameter input (either an `ImageGenerationParams` instance or a plain dictionary)
# - Clean result handling through a dedicated class
# You can also add error handling in your code:
#
#     try:
#         wrapper = ImagenWrapper("https://bcdb8b7f9c4a57127c.gradio.live/")
#         result = wrapper.generate_simple("A beautiful sunset")
#         print(f"Generated image: {result}")
#     except ConnectionError as e:
#         print(f"Failed to connect to API: {e}")
#     except RuntimeError as e:
#         # generate() also reports invalid parameters (e.g. an empty prompt) as RuntimeError
#         print(f"Generation failed: {e}")
#     except Exception as e:
#         print(f"Unexpected error: {e}")
from gradio_client import Client | |
from typing import Dict, Tuple, Optional, Union | |
from dataclasses import dataclass | |
import logging | |
# @dataclass generates the keyword __init__ (plus __repr__/__eq__) that
# ImagenWrapper.generate() and generate_simple() rely on when they call
# ImageGenerationParams(prompt=..., ...) or ImageGenerationParams(**a_dict).
@dataclass
class ImageGenerationParams:
    """Data class to hold image generation parameters.

    Every field except ``prompt`` has a default. The values are forwarded
    unchanged as keyword arguments to the deployment's ``/infer`` endpoint,
    so their exact meaning is defined by that deployment.
    """

    prompt: str  # text prompt; ImagenWrapper.generate rejects an empty one
    seed: float = 0  # presumably ignored server-side when randomize_seed is True — confirm
    randomize_seed: bool = True
    width: float = 1024  # presumably pixels — confirm against the deployment
    height: float = 1024  # presumably pixels — confirm against the deployment
    guidance_scale: float = 3.5
    num_inference_steps: float = 28
    lora_scale: float = 0.7
class ImageGenerationResult:
    """Class to handle the generation result.

    Wraps the ``(image, seed)`` pair returned by the ``/infer`` endpoint and
    exposes the image description as plain attributes.

    Args:
        image_data: Either a dict describing the generated file (keys ``path``,
            ``url``, ``size``, ``orig_name``, ``mime_type``, ``is_stream``,
            ``meta``; missing keys become ``None``, ``False`` for
            ``is_stream`` and ``{}`` for ``meta``), or a plain string, which is
            treated as the local file path.
            NOTE(review): depending on the ``gradio_client`` version, file
            outputs may come back as a downloaded local path string rather than
            a dict — confirm which form this deployment/client returns.
        seed: The seed value reported back by the API.
    """

    def __init__(self, image_data: Union[Dict, str], seed: float):
        if isinstance(image_data, str):
            # A bare string carries no metadata; it can only be the file path.
            image_data = {'path': image_data}
        self.image_path = image_data.get('path')
        self.image_url = image_data.get('url')
        self.size = image_data.get('size')
        self.orig_name = image_data.get('orig_name')
        self.mime_type = image_data.get('mime_type')
        self.is_stream = image_data.get('is_stream', False)
        self.meta = image_data.get('meta', {})
        self.seed = seed

    def __str__(self) -> str:
        return f"ImageGenerationResult(url={self.image_url}, seed={self.seed})"
class ImagenWrapper:
    """Wrapper class for the Imagen Gradio deployment.

    Connects through ``gradio_client.Client`` and calls the deployment's
    ``/infer`` endpoint.
    """

    def __init__(self, api_url: str):
        """
        Initialize the wrapper with the API URL.

        Args:
            api_url (str): The URL of the Gradio deployment

        Raises:
            ConnectionError: If creating the client fails for any reason; the
                original exception is chained as ``__cause__``.
        """
        self.api_url = api_url
        self.logger = logging.getLogger(__name__)
        try:
            self.client = Client(api_url)
        except Exception as e:
            self.logger.error("Failed to connect to API at %s: %s", api_url, e)
            # Chain the cause so the real connection error stays inspectable.
            raise ConnectionError(f"Failed to connect to API: {str(e)}") from e
        self.logger.info("Successfully connected to API at %s", api_url)

    def generate(self,
                 params: Union[ImageGenerationParams, Dict],
                 ) -> ImageGenerationResult:
        """
        Generate an image using the provided parameters.

        Args:
            params: Either an ImageGenerationParams object or a dictionary whose
                keys are ImageGenerationParams field names.

        Returns:
            ImageGenerationResult: Object containing the generation results.

        Raises:
            RuntimeError: For every failure. This includes:
                - invalid parameters (unknown dict keys, or an empty prompt,
                  which raises ``ValueError`` internally);
                - a failing API call;
                - a response that is not an ``(image, seed)`` pair.
                The underlying exception is chained as ``__cause__``.
        """
        try:
            # Accept a plain dict for convenience; unknown keys raise TypeError here.
            if isinstance(params, dict):
                params = ImageGenerationParams(**params)

            # Validate parameters
            if not params.prompt:
                raise ValueError("Prompt cannot be empty")

            # Make the API call
            result = self.client.predict(
                prompt=params.prompt,
                seed=params.seed,
                randomize_seed=params.randomize_seed,
                width=params.width,
                height=params.height,
                guidance_scale=params.guidance_scale,
                num_inference_steps=params.num_inference_steps,
                lora_scale=params.lora_scale,
                api_name="/infer"
            )

            # Require a real (image, seed) pair. A length-2 str or dict would
            # otherwise be unpacked into nonsense instead of being rejected.
            if not isinstance(result, (tuple, list)) or len(result) != 2:
                raise RuntimeError("Invalid response from API")

            image_data, seed = result
            return ImageGenerationResult(image_data, seed)
        except Exception as e:
            self.logger.error("Error during image generation: %s", e)
            raise RuntimeError(f"Failed to generate image: {str(e)}") from e

    def generate_simple(self,
                        prompt: str,
                        **kwargs) -> ImageGenerationResult:
        """
        Simplified interface for generating images.

        Args:
            prompt (str): The prompt for image generation.
            **kwargs: Optional ImageGenerationParams fields overriding defaults.

        Returns:
            ImageGenerationResult: Object containing the generation results.

        Raises:
            TypeError: If a keyword is not an ImageGenerationParams field; this
                is raised directly by the constructor, before ``generate`` runs.
            RuntimeError: Propagated from ``generate`` for any other failure.
        """
        params = ImageGenerationParams(prompt=prompt, **kwargs)
        return self.generate(params)