# Control_Ability_Arena / model/models/openai_api_models.py
# Bbmyy — "first commit" (c92c0ec), 1.7 kB
from openai import OpenAI
from PIL import Image
import requests
import io
import os
import base64
class OpenaiModel:
    """Callable wrapper around OpenAI's hosted image-generation models.

    Only the ``"text2image"`` model type is implemented. The concrete OpenAI
    model is chosen by substring match on ``model_name`` against
    ``_MODEL_SPECS``.
    """

    # (substring of model_name, OpenAI API model id, output size).
    # Checked in order, so "Dalle-3" takes precedence over "Dalle-2".
    _MODEL_SPECS = (
        ("Dalle-3", "dall-e-3", "1024x1024"),
        ("Dalle-2", "dall-e-2", "512x512"),
    )

    def __init__(self, model_name, model_type, download_timeout=60):
        """Store the configuration; no network access happens here.

        Args:
            model_name: Name used to pick the backend, e.g. ``"Dalle-3"``.
            model_type: Task type; only ``"text2image"`` is accepted by
                ``__call__``.
            download_timeout: Seconds ``requests.get`` may wait when fetching
                the generated image from the URL returned by OpenAI.
        """
        self.model_name = model_name
        self.model_type = model_type
        self.download_timeout = download_timeout

    def _resolve_model(self):
        """Return ``(api_model, size)`` for ``self.model_name``.

        Raises:
            NotImplementedError: if no supported model matches the name.
        """
        for key, api_model, size in self._MODEL_SPECS:
            if key in self.model_name:
                return api_model, size
        raise NotImplementedError(
            f"Unsupported OpenAI image model: {self.model_name!r}"
        )

    def __call__(self, *args, **kwargs):
        """Generate one image from ``kwargs["prompt"]`` and return it as a PIL image.

        Positional ``args`` are accepted for interface compatibility and ignored.
        Errors raised by the OpenAI client are not caught here.

        Raises:
            ValueError: if ``model_type`` is not ``"text2image"``.
            AssertionError: if no ``prompt`` keyword argument is given.
            NotImplementedError: if ``model_name`` matches no supported model.
            requests.HTTPError: if downloading the generated image fails.
        """
        if self.model_type != "text2image":
            raise ValueError("model_type must be text2image")
        if "prompt" not in kwargs:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; AssertionError is kept for existing callers.
            raise AssertionError("prompt is required for text2image model")
        api_model, size = self._resolve_model()

        # Build the client only once the request is known to be valid.
        client = OpenAI()
        # NOTE(review): OpenAI documents ``quality`` mainly for dall-e-3;
        # confirm "standard" is accepted for dall-e-2 as well.
        generation = client.images.generate(
            model=api_model,
            prompt=kwargs["prompt"],
            size=size,
            quality="standard",
            n=1,
        )
        result_url = generation.data[0].url
        download = requests.get(result_url, timeout=self.download_timeout)
        # Fail with an HTTP error rather than handing an error body to PIL.
        download.raise_for_status()
        return Image.open(io.BytesIO(download.content))
def load_openai_model(model_name, model_type):
    """Construct an ``OpenaiModel`` for the given model name and task type.

    Construction performs no network I/O; the OpenAI client is only created
    when the returned object is called.
    """
    wrapper = OpenaiModel(model_name, model_type)
    return wrapper
if __name__ == "__main__":
    # Manual smoke test: performs a real image-generation request, so it
    # needs network access and credentials for the default OpenAI() client.
    generator = load_openai_model('Dalle-3', 'text2image')
    image = generator(prompt='draw a tiger')
    print(image)