import ast
import base64
import os
import argparse
import sys
import uuid


def main():
    parser = argparse.ArgumentParser(description="Generate images from text prompts")
    parser.add_argument("--prompt", "--query", type=str, required=True, help="User prompt or query")
    parser.add_argument("--model", type=str, required=False, help="Model name")
    parser.add_argument("--output", "--file", type=str, required=False, default="",
                        help="Name (unique) of the output file")
    parser.add_argument("--quality", type=str, required=False,
                        choices=['standard', 'hd', 'quick', 'manual'], default='standard', help="Image quality")
    parser.add_argument("--size", type=str, required=False, default="1024x1024",
                        help="Image size (width x height)")

    imagegen_url = os.getenv("IMAGEGEN_OPENAI_BASE_URL", '')
    assert imagegen_url, "IMAGEGEN_OPENAI_BASE_URL environment variable is not set"
    server_api_key = os.getenv('IMAGEGEN_OPENAI_API_KEY', 'EMPTY')

    generation_params = {}
    is_openai = False
    if imagegen_url == "https://api.gpt.h2o.ai/v1":
        # h2oGPT image-generation server
        parser.add_argument("--guidance_scale", type=float, help="Guidance scale for image generation")
        parser.add_argument("--num_inference_steps", type=int, help="Number of inference steps")
        args = parser.parse_args()

        from openai import OpenAI
        client = OpenAI(base_url=imagegen_url, api_key=server_api_key)

        available_models = ['flux.1-schnell', 'playv2']
        if os.getenv('IMAGEGEN_OPENAI_MODELS'):
            # allow override
            available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))
        if not args.model or args.model not in available_models:
            args.model = available_models[0]
    elif imagegen_url == "https://api.openai.com/v1" or 'openai.azure.com' in imagegen_url:
        # OpenAI or Azure OpenAI (DALL-E)
        is_openai = True
        parser.add_argument("--style", type=str, choices=['vivid', 'natural', 'artistic'], default='vivid',
                            help="Image style")
        args = parser.parse_args()

        # https://platform.openai.com/docs/api-reference/images/create
        available_models = ['dall-e-3', 'dall-e-2']
        # assumes deployment name matches model name, unless overridden
        if os.getenv('IMAGEGEN_OPENAI_MODELS'):
            # allow override
            available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))
        if not args.model or args.model not in available_models:
            args.model = available_models[0]

        if 'openai.azure.com' in imagegen_url:
            # https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line%2Ctypescript&pivots=programming-language-python
            from openai import AzureOpenAI
            client = AzureOpenAI(
                api_version="2024-02-01" if args.model == 'dall-e-3' else '2023-06-01-preview',
                api_key=os.environ["IMAGEGEN_OPENAI_API_KEY"],
                # like base_url, but an Azure endpoint like https://PROJECT.openai.azure.com/
                azure_endpoint=os.environ['IMAGEGEN_OPENAI_BASE_URL'],
            )
        else:
            from openai import OpenAI
            client = OpenAI(base_url=imagegen_url, api_key=server_api_key)

        # DALL-E 2 prompts are limited to 1000 characters, DALL-E 3 to 4000
        dalle2aliases = ['dall-e-2', 'dalle2', 'dalle-2']
        max_chars = 1000 if args.model in dalle2aliases else 4000
        args.prompt = args.prompt[:max_chars]

        if args.model in dalle2aliases:
            valid_sizes = ['256x256', '512x512', '1024x1024']
        else:
            valid_sizes = ['1024x1024', '1792x1024', '1024x1792']
        if args.size not in valid_sizes:
            args.size = valid_sizes[0]

        args.quality = 'standard' if args.quality not in ['standard', 'hd'] else args.quality
        args.style = 'vivid' if args.style not in ['vivid', 'natural'] else args.style

        generation_params.update({
            "style": args.style,
        })
    else:
        # Any other OpenAI-compatible image-generation server
        parser.add_argument("--guidance_scale", type=float, help="Guidance scale for image generation")
        parser.add_argument("--num_inference_steps", type=int, help="Number of inference steps")
        args = parser.parse_args()

        from openai import OpenAI
        client = OpenAI(base_url=imagegen_url, api_key=server_api_key)

        assert os.getenv('IMAGEGEN_OPENAI_MODELS'), "IMAGEGEN_OPENAI_MODELS environment variable is not set"
        # must be a string containing a Python list of strings
        available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))
        assert available_models, "IMAGEGEN_OPENAI_MODELS must list at least one model for this server"
        if not args.model or args.model not in available_models:
            args.model = available_models[0]

    # For Azure, args.model assumes the deployment name matches the model name
    # (i.e. dall-e-3, not dalle3) unless IMAGEGEN_OPENAI_MODELS is set
    generation_params.update({
        "prompt": args.prompt,
        "model": args.model,
        "quality": args.quality,
        "size": args.size,
        "response_format": "b64_json",
    })

    if not is_openai:
        # Non-OpenAI servers may accept extra diffusion parameters
        extra_body = {}
        if args.guidance_scale:
            extra_body["guidance_scale"] = args.guidance_scale
        if args.num_inference_steps:
            extra_body["num_inference_steps"] = args.num_inference_steps
        if extra_body:
            generation_params["extra_body"] = extra_body

    response = client.images.generate(**generation_params)

    if hasattr(response.data[0], 'revised_prompt') and response.data[0].revised_prompt:
        print("Image Generator revised the prompt (this is expected): %s" % response.data[0].revised_prompt)

    assert response.data[0].b64_json is not None or response.data[0].url is not None, "No image data returned"
    if response.data[0].b64_json:
        image_data_base64 = response.data[0].b64_json
        image_data = base64.b64decode(image_data_base64)
    else:
        # Some servers return a URL instead of inline base64 data
        from openai_server.agent_tools.common.utils import download_simple
        dest = download_simple(response.data[0].url, overwrite=True)
        with open(dest, "rb") as f:
            image_data = f.read()
        os.remove(dest)

    # Determine file type and name
    image_format = get_image_format(image_data)
    if not args.output:
        args.output = f"image_{str(uuid.uuid4())[:6]}.{image_format}"
    else:
        # If an output path is provided, ensure it has the correct extension
        base, ext = os.path.splitext(args.output)
        if ext.lower() != f".{image_format}":
            args.output = f"{base}.{image_format}"

    # Write the image data to a file
    with open(args.output, "wb") as img_file:
        img_file.write(image_data)

    full_path = os.path.abspath(args.output)
    print(f"Image successfully saved to the file: {full_path}")
    # NOTE: Could provide stats like image size, etc.


def get_image_format(image_data):
    from PIL import Image
    import io
    # Use PIL to determine the image format
    with Image.open(io.BytesIO(image_data)) as img:
        return img.format.lower()


if __name__ == "__main__":
    main()