import base64
import logging
import os
from io import BytesIO
from typing import Optional

from openai import AzureOpenAI, OpenAI  # pip install openai
from PIL import Image
from tenacity import (
    retry,
    stop_after_attempt,
    stop_after_delay,
    wait_random_exponential,
)

from asset3d_gen.utils.process_media import combine_images_to_base64

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GPTclient:
    """A client to interact with the GPT model via OpenAI or Azure API."""

    def __init__(
        self,
        endpoint: str,
        api_key: str,
        model_name: str = "yfb-gpt-4o",
        api_version: Optional[str] = None,
        verbose: bool = False,
    ):
        # An explicit `api_version` signals an Azure deployment; otherwise
        # fall back to a plain OpenAI-compatible client.
        if api_version is not None:
            self.client = AzureOpenAI(
                azure_endpoint=endpoint,
                api_key=api_key,
                api_version=api_version,
            )
        else:
            self.client = OpenAI(
                base_url=endpoint,
                api_key=api_key,
            )

        self.endpoint = endpoint
        self.model_name = model_name
        self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
        self.verbose = verbose
    # Retry transient API failures with randomized exponential backoff;
    # the concrete limits (5 attempts or 60 s, whichever comes first) are
    # illustrative defaults.
    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(5) | stop_after_delay(60),
    )
    def completion_with_backoff(self, **kwargs):
        return self.client.chat.completions.create(**kwargs)
    def query(
        self,
        text_prompt: str,
        image_base64: Optional[list[str | Image.Image]] = None,
        system_role: Optional[str] = None,
    ) -> Optional[str]:
        """Queries the GPT model with a text prompt and optional images.

        Args:
            text_prompt (str): The main text input that the model responds to.
            image_base64 (Optional[list[str | Image.Image]]): A list of image
                base64 strings, local image paths, or PIL.Image objects to
                accompany the text prompt.
            system_role (Optional[str]): Optional system-level instructions
                that specify the behavior of the assistant.

        Returns:
            Optional[str]: The response content generated by the model based
                on the prompt. Returns `None` if an error occurs.
        """
        if system_role is None:
            system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."  # noqa

        content_user = [
            {
                "type": "text",
                "text": text_prompt,
            },
        ]
        # Process images if provided: accept PIL images, local file paths
        # with a known image extension, or raw base64 strings.
        if image_base64 is not None:
            if not isinstance(image_base64, list):
                image_base64 = [image_base64]
            for img in image_base64:
                if isinstance(img, Image.Image):
                    # Encode the PIL image as PNG so the MIME type in the
                    # data URL below is accurate.
                    buffer = BytesIO()
                    img.save(buffer, format="PNG")
                    img = base64.b64encode(buffer.getvalue()).decode("utf-8")
                elif os.path.splitext(img)[-1].lower() in self.image_formats:
                    # `os.path.splitext` always returns a 2-tuple, so the
                    # extension check alone is sufficient here.
                    if not os.path.exists(img):
                        raise FileNotFoundError(f"Image file not found: {img}")
                    with open(img, "rb") as f:
                        img = base64.b64encode(f.read()).decode("utf-8")
                # Anything else is assumed to already be a base64 string.
                content_user.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{img}"},
                    }
                )
        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": system_role},
                {"role": "user", "content": content_user},
            ],
            "temperature": 0.1,
            "max_tokens": 500,
            "top_p": 0.1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "stop": None,
        }
        response = None
        try:
            response = self.completion_with_backoff(**payload)
            response = response.choices[0].message.content
        except Exception as e:
            logger.error(f"GPTclient {self.endpoint} API call failed: {e}")
            response = None

        if self.verbose:
            logger.info(f"Prompt: {text_prompt}")
            logger.info(f"Response: {response}")

        return response
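
# A minimal usage sketch (the endpoint, key variable, and model name below
# are illustrative assumptions; any OpenAI-compatible endpoint should work):
#
#   client = GPTclient(
#       endpoint="https://api.openai.com/v1",
#       api_key=os.environ["OPENAI_API_KEY"],
#       model_name="gpt-4o",
#   )
#   answer = client.query(
#       "Estimate the mass of this object.",
#       image_base64=[Image.open("object.png")],
#   )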
endpoint = os.environ.get("endpoint", None)
api_key = os.environ.get("api_key", None)
api_version = os.environ.get("api_version", None)

if endpoint and api_key and api_version:
    GPT_CLIENT = GPTclient(
        endpoint=endpoint,
        api_key=api_key,
        api_version=api_version,
        # Fall back to the class default rather than passing `None`,
        # which would silently unset the model name.
        model_name="yfb-gpt-4o-sweden" if "sweden" in endpoint else "yfb-gpt-4o",
    )
else:
    GPT_CLIENT = GPTclient(
        endpoint="https://openrouter.ai/api/v1",
        # Never hardcode secrets in source; read the key from the
        # environment instead (the variable name is a suggested convention).
        api_key=os.environ.get("OPENROUTER_API_KEY", ""),
        model_name="qwen/qwen2.5-vl-72b-instruct:free",
    )
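
# Example environment setup (the lowercase variable names match the
# `os.environ.get` lookups above; the Azure api_version value is an
# assumption, shown for illustration only):
#
#   export endpoint="https://<resource>.openai.azure.com/"
#   export api_key="<YOUR_API_KEY>"
#   export api_version="2024-02-01"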
if __name__ == "__main__":
    if "openrouter" in GPT_CLIENT.endpoint:
        response = GPT_CLIENT.query(
            text_prompt="What is the content in each image?",
            image_base64=combine_images_to_base64(
                [
                    "outputs/text2image/demo_objects/bed/sample_0.jpg",
                    "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png",  # noqa
                    "outputs/text2image/demo_objects/cardboard/sample_1.jpg",
                ]
            ),  # pass a raw image path instead if there is only one image
        )
        print(response)
    else:
        response = GPT_CLIENT.query(
            text_prompt="What is the content in the images?",
            image_base64=[
                Image.open("outputs/text2image/demo_objects/bed/sample_0.jpg"),
                Image.open(
                    "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png"  # noqa
                ),
            ],
        )
        print(response)
    # Test 2: text-only prompt.
    response = GPT_CLIENT.query(text_prompt="What is the capital of China?")
    print(response)