Spaces:
Sleeping
Sleeping
from openai import OpenAI | |
from PIL import Image | |
import requests | |
import io | |
import os | |
import base64 | |
class OpenaiModel:
    """Callable wrapper around OpenAI's image-generation endpoint.

    Only the ``"text2image"`` model type is implemented. The concrete OpenAI
    model and the requested image size are picked from a substring of
    ``model_name`` (see ``_MODEL_SPECS``).

    Example::

        pipe = OpenaiModel("Dalle-3", "text2image")
        image = pipe(prompt="draw a tiger")  # PIL image
    """

    # (substring looked for in model_name, OpenAI model id, requested size).
    # Checked in order, so a name containing both "Dalle-3" and "Dalle-2"
    # resolves to DALL-E 3.
    _MODEL_SPECS = (
        ("Dalle-3", "dall-e-3", "1024x1024"),
        ("Dalle-2", "dall-e-2", "512x512"),
    )

    # Seconds to wait while downloading the generated image. Without a
    # timeout, requests.get can block indefinitely on a stalled connection.
    download_timeout = 60

    def __init__(self, model_name, model_type):
        """Store the model identifiers; no network access happens here.

        Args:
            model_name: Name containing "Dalle-3" or "Dalle-2".
            model_type: Task type; only "text2image" is supported.
        """
        self.model_name = model_name
        self.model_type = model_type

    @classmethod
    def _api_settings(cls, model_name):
        """Return the ``(OpenAI model id, image size)`` pair for ``model_name``.

        Raises:
            NotImplementedError: if ``model_name`` matches no known model.
        """
        for marker, api_model, size in cls._MODEL_SPECS:
            if marker in model_name:
                return api_model, size
        supported = ", ".join(repr(marker) for marker, _, _ in cls._MODEL_SPECS)
        raise NotImplementedError(
            f"unsupported model_name {model_name!r}; "
            f"expected a name containing one of {supported}"
        )

    def _download_image(self, url):
        """Fetch ``url`` and decode the response body as a PIL image."""
        reply = requests.get(url, timeout=self.download_timeout)
        # On a 4xx/5xx response, surface requests' HTTPError instead of
        # letting PIL fail to decode an error page.
        reply.raise_for_status()
        return Image.open(io.BytesIO(reply.content))

    def __call__(self, *args, **kwargs):
        """Generate one image from the ``prompt`` keyword argument.

        Positional ``args`` are accepted for interface compatibility but are
        not used.

        Returns:
            The generated image, as opened by ``PIL.Image.open``.

        Raises:
            ValueError: if ``model_type`` is not "text2image".
            TypeError: if no ``prompt`` keyword argument is given.
            NotImplementedError: if ``model_name`` matches no supported model.
            requests.HTTPError: if downloading the generated image fails.
        """
        if self.model_type != "text2image":
            raise ValueError(
                f"model_type must be text2image, got {self.model_type!r}"
            )
        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if "prompt" not in kwargs:
            raise TypeError("prompt is required for text2image model")

        # Resolve the model before building the client so an unsupported
        # name fails fast without constructing an OpenAI client.
        api_model, size = self._api_settings(self.model_name)
        client = OpenAI()
        generation = client.images.generate(
            model=api_model,
            prompt=kwargs["prompt"],
            size=size,
            quality="standard",
            n=1,
        )
        return self._download_image(generation.data[0].url)
def load_openai_model(model_name, model_type):
    """Build an ``OpenaiModel`` for the given name and task type.

    This only constructs the wrapper. The OpenAI request is made later, when
    the returned object is called.
    """
    model = OpenaiModel(model_name=model_name, model_type=model_type)
    return model
if __name__ == "__main__":
    # Manual smoke run: generate one DALL-E 3 image and print the result object.
    demo_pipe = load_openai_model('Dalle-3', 'text2image')
    image = demo_pipe(prompt='draw a tiger')
    print(image)