|
import ast |
|
import base64 |
|
import os |
|
import argparse |
|
import sys |
|
import uuid |
|
|
|
|
|
def main():
    """Generate an image from a text prompt via an OpenAI-compatible images API.

    The target server comes from the IMAGEGEN_OPENAI_BASE_URL environment
    variable (required); the API key comes from IMAGEGEN_OPENAI_API_KEY
    (defaults to 'EMPTY'). Endpoint-specific CLI flags are registered per
    server type (h2oGPT, OpenAI/Azure, or generic OpenAI-compatible), the
    image is generated, and the result is written to a local file whose
    absolute path is printed on success.
    """
    parser = argparse.ArgumentParser(description="Generate images from text prompts")
    parser.add_argument("--prompt", "--query", type=str, required=True, help="User prompt or query")
    parser.add_argument("--model", type=str, required=False, help="Model name")
    parser.add_argument("--output", "--file", type=str, required=False, default="",
                        help="Name (unique) of the output file")
    parser.add_argument("--quality", type=str, required=False, choices=['standard', 'hd', 'quick', 'manual'],
                        default='standard',
                        help="Image quality")
    parser.add_argument("--size", type=str, required=False, default="1024x1024", help="Image size (height x width)")

    imagegen_url = os.getenv("IMAGEGEN_OPENAI_BASE_URL", '')
    # BUG FIX: the getenv default above is '' (not None), so the previous
    # `is not None` assertion could never fail; test truthiness instead so a
    # missing/empty env var actually aborts with the intended message.
    assert imagegen_url, "IMAGEGEN_OPENAI_BASE_URL environment variable is not set"
    server_api_key = os.getenv('IMAGEGEN_OPENAI_API_KEY', 'EMPTY')

    generation_params = {}

    is_openai = False
    if imagegen_url == "https://api.gpt.h2o.ai/v1":
        # h2oGPT image server: exposes diffusion-specific knobs.
        parser.add_argument("--guidance_scale", type=float, help="Guidance scale for image generation")
        parser.add_argument("--num_inference_steps", type=int, help="Number of inference steps")
        args = parser.parse_args()
        from openai import OpenAI
        client = OpenAI(base_url=imagegen_url, api_key=server_api_key)
        available_models = ['flux.1-schnell', 'playv2']
        if os.getenv('IMAGEGEN_OPENAI_MODELS'):
            # Env var is expected to hold a Python list literal, e.g. "['flux.1-schnell']".
            available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))
        if not args.model or args.model not in available_models:
            # Fall back to the first served model on a missing/unknown choice.
            args.model = available_models[0]
    elif imagegen_url == "https://api.openai.com/v1" or 'openai.azure.com' in imagegen_url:
        is_openai = True
        parser.add_argument("--style", type=str, choices=['vivid', 'natural', 'artistic'], default='vivid',
                            help="Image style")
        args = parser.parse_args()

        available_models = ['dall-e-3', 'dall-e-2']
        if os.getenv('IMAGEGEN_OPENAI_MODELS'):
            available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))
        if not args.model or args.model not in available_models:
            args.model = available_models[0]

        if 'openai.azure.com' in imagegen_url:
            # Azure needs its own client type and an API version per model.
            from openai import AzureOpenAI
            client = AzureOpenAI(
                api_version="2024-02-01" if args.model == 'dall-e-3' else '2023-06-01-preview',
                api_key=os.environ["IMAGEGEN_OPENAI_API_KEY"],
                azure_endpoint=os.environ['IMAGEGEN_OPENAI_BASE_URL']
            )
        else:
            from openai import OpenAI
            client = OpenAI(base_url=imagegen_url, api_key=server_api_key)

        # Clamp the prompt to each model family's prompt-length limit.
        dalle2aliases = ['dall-e-2', 'dalle2', 'dalle-2']
        max_chars = 1000 if args.model in dalle2aliases else 4000
        args.prompt = args.prompt[:max_chars]

        if args.model in dalle2aliases:
            valid_sizes = ['256x256', '512x512', '1024x1024']
        else:
            valid_sizes = ['1024x1024', '1792x1024', '1024x1792']

        if args.size not in valid_sizes:
            args.size = valid_sizes[0]

        # The OpenAI endpoint only accepts 'standard'/'hd' quality and
        # 'vivid'/'natural' style; coerce any other value back to defaults.
        args.quality = 'standard' if args.quality not in ['standard', 'hd'] else args.quality
        args.style = 'vivid' if args.style not in ['vivid', 'natural'] else args.style
        generation_params.update({
            "style": args.style,
        })
    else:
        # Generic OpenAI-compatible server: the model list must be supplied.
        parser.add_argument("--guidance_scale", type=float, help="Guidance scale for image generation")
        parser.add_argument("--num_inference_steps", type=int, help="Number of inference steps")
        args = parser.parse_args()

        from openai import OpenAI
        client = OpenAI(base_url=imagegen_url, api_key=server_api_key)
        assert os.getenv('IMAGEGEN_OPENAI_MODELS'), "IMAGEGEN_OPENAI_MODELS environment variable is not set"
        available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))
        assert available_models, "IMAGEGEN_OPENAI_MODELS environment variable is not set, must be for this server"
        if args.model is None or args.model not in available_models:
            args.model = available_models[0]

    generation_params.update({
        "prompt": args.prompt,
        "model": args.model,
        "quality": args.quality,
        "size": args.size,
        "response_format": "b64_json",
    })

    if not is_openai:
        # Diffusion knobs travel in extra_body. BUG FIX: use explicit None
        # checks so a legitimate 0 value is not silently dropped (0 is falsy).
        extra_body = {}
        if args.guidance_scale is not None:
            extra_body["guidance_scale"] = args.guidance_scale
        if args.num_inference_steps is not None:
            extra_body["num_inference_steps"] = args.num_inference_steps
        if extra_body:
            generation_params["extra_body"] = extra_body

    response = client.images.generate(**generation_params)

    # DALL-E 3 may rewrite the prompt server-side; surface that for transparency.
    if hasattr(response.data[0], 'revised_prompt') and response.data[0].revised_prompt:
        print("Image Generator revised the prompt (this is expected): %s" % response.data[0].revised_prompt)

    assert response.data[0].b64_json is not None or response.data[0].url is not None, "No image data returned"

    if response.data[0].b64_json:
        image_data_base64 = response.data[0].b64_json
        image_data = base64.b64decode(image_data_base64)
    else:
        # Server returned a URL instead of inline data: download, read, clean up.
        from openai_server.agent_tools.common.utils import download_simple
        dest = download_simple(response.data[0].url, overwrite=True)
        with open(dest, "rb") as f:
            image_data = f.read()
        os.remove(dest)

    # Derive the true format from the bytes and force the filename extension
    # to match it, generating a unique name when none was given.
    image_format = get_image_format(image_data)
    if not args.output:
        args.output = f"image_{str(uuid.uuid4())[:6]}.{image_format}"
    else:
        base, ext = os.path.splitext(args.output)
        if ext.lower() != f".{image_format}":
            args.output = f"{base}.{image_format}"

    with open(args.output, "wb") as img_file:
        img_file.write(image_data)

    full_path = os.path.abspath(args.output)
    print(f"Image successfully saved to the file: {full_path}")
|
|
|
|
|
|
|
|
|
def get_image_format(image_data):
    """Return the lowercase format name (e.g. 'png', 'jpeg') of raw image bytes."""
    import io

    from PIL import Image

    buffer = io.BytesIO(image_data)
    with Image.open(buffer) as image:
        return image.format.lower()
|
|
|
|
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|