"""Thin client around OpenAI-compatible / Azure OpenAI chat completion APIs.

Importing this module builds a shared ``GPT_CLIENT`` instance: an Azure client
when the ``endpoint``, ``api_key`` and ``api_version`` environment variables
are all set, otherwise an OpenRouter client.
"""

import base64
import logging
import os
from io import BytesIO
from typing import Optional

from openai import AzureOpenAI, OpenAI  # pip install openai
from PIL import Image
from tenacity import (
    retry,
    stop_after_attempt,
    stop_after_delay,
    wait_random_exponential,
)

from asset3d_gen.utils.process_media import combine_images_to_base64

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model used when the caller does not (or cannot) name one.
_DEFAULT_MODEL_NAME = "yfb-gpt-4o"

# System prompt used by `GPTclient.query` when `system_role` is omitted.
_DEFAULT_SYSTEM_ROLE = (
    "You are a highly knowledgeable assistant specializing in physics, "
    "engineering, and object properties."
)

# MIME type advertised in the data URL, keyed by lower-case file extension.
_MIME_BY_EXTENSION = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".webp": "image/webp",
    ".bmp": "image/bmp",
    ".gif": "image/gif",
}

# MIME type advertised in the data URL, keyed by upper-case PIL format name.
_MIME_BY_PIL_FORMAT = {
    "PNG": "image/png",
    "JPEG": "image/jpeg",
    "WEBP": "image/webp",
    "BMP": "image/bmp",
    "GIF": "image/gif",
}

# Used for raw base64 input and for formats missing from the tables above;
# this matches the label the client has always sent for such inputs.
_FALLBACK_MIME = "image/png"


class GPTclient:
    """A client to interact with the GPT model via OpenAI or Azure API."""

    def __init__(
        self,
        endpoint: str,
        api_key: str,
        model_name: str = "yfb-gpt-4o",
        api_version: Optional[str] = None,
        verbose: bool = False,
    ):
        """Create the underlying API client.

        Args:
            endpoint (str): Azure endpoint, or base URL of an
                OpenAI-compatible service.
            api_key (str): Credential passed to the API client.
            model_name (str): Model/deployment sent with every request.
                `None` falls back to the default model name.
            api_version (Optional[str]): When given, an `AzureOpenAI` client
                is created; otherwise a plain `OpenAI` client is used.
            verbose (bool): Log the prompt and response of each query.
        """
        if api_version is not None:
            self.client = AzureOpenAI(
                azure_endpoint=endpoint,
                api_key=api_key,
                api_version=api_version,
            )
        else:
            self.client = OpenAI(
                base_url=endpoint,
                api_key=api_key,
            )

        self.endpoint = endpoint
        # Guard against an explicit `None`, which would otherwise be sent
        # to the API as the model name.
        self.model_name = (
            model_name if model_name is not None else _DEFAULT_MODEL_NAME
        )
        self.image_formats = set(_MIME_BY_EXTENSION)
        self.verbose = verbose

    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=(stop_after_attempt(10) | stop_after_delay(30)),
    )
    def completion_with_backoff(self, **kwargs):
        """Call the chat completion API, retrying with random backoff.

        Retries stop after 10 attempts or 30 seconds, whichever comes first.
        All keyword arguments are forwarded to
        `client.chat.completions.create`.
        """
        return self.client.chat.completions.create(**kwargs)

    def _to_data_url(self, img) -> str:
        """Convert one image input into a `data:` URL.

        Accepts a PIL image, a path ending in one of `self.image_formats`,
        an existing `data:` URL, or a raw base64 string.

        Raises:
            FileNotFoundError: If `img` looks like an image path but the
                file does not exist.
        """
        if isinstance(img, Image.Image):
            fmt = img.format or "PNG"
            buffer = BytesIO()
            img.save(buffer, format=fmt)
            encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
            mime = _MIME_BY_PIL_FORMAT.get(fmt.upper(), _FALLBACK_MIME)
            return f"data:{mime};base64,{encoded}"

        # An input that is already a data URL must not be prefixed again.
        if isinstance(img, str) and img.startswith("data:"):
            return img

        suffix = os.path.splitext(img)[-1].lower()
        if suffix in self.image_formats:
            if not os.path.exists(img):
                raise FileNotFoundError(f"Image file not found: {img}")
            with open(img, "rb") as f:
                encoded = base64.b64encode(f.read()).decode("utf-8")
            mime = _MIME_BY_EXTENSION.get(suffix, _FALLBACK_MIME)
            return f"data:{mime};base64,{encoded}"

        # Anything else is treated as an already-encoded base64 payload.
        return f"data:{_FALLBACK_MIME};base64,{img}"

    def query(
        self,
        text_prompt: str,
        image_base64: Optional[list[str | Image.Image]] = None,
        system_role: Optional[str] = None,
    ) -> Optional[str]:
        """Queries the GPT model with a text and optional image prompts.

        Args:
            text_prompt (str): The main text input that the model responds to.
            image_base64 (Optional[List[str]]): A list of image base64 strings
                or local image paths or PIL.Image to accompany the text
                prompt. A single non-list value is also accepted.
            system_role (Optional[str]): Optional system-level instructions
                that specify the behavior of the assistant.

        Returns:
            Optional[str]: The response content generated by the model based
                on the prompt. Returns `None` if the API call fails.

        Raises:
            FileNotFoundError: If an entry of `image_base64` is an image path
                that does not exist.
        """
        if system_role is None:
            system_role = _DEFAULT_SYSTEM_ROLE

        content_user = [
            {
                "type": "text",
                "text": text_prompt,
            },
        ]

        # Process images if provided
        if image_base64 is not None:
            images = (
                image_base64
                if isinstance(image_base64, list)
                else [image_base64]
            )
            content_user.extend(
                {
                    "type": "image_url",
                    "image_url": {"url": self._to_data_url(img)},
                }
                for img in images
            )

        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": system_role},
                {"role": "user", "content": content_user},
            ],
            "temperature": 0.1,
            "max_tokens": 500,
            "top_p": 0.1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "stop": None,
        }

        # Best-effort by design: any API failure is logged and reported to
        # the caller as `None` instead of being raised.
        try:
            completion = self.completion_with_backoff(**payload)
            response = completion.choices[0].message.content
        except Exception as e:
            logger.error(
                "Error GPTclient %s API call: %s", self.endpoint, e
            )
            response = None

        if self.verbose:
            logger.info("Prompt: %s", text_prompt)
            logger.info("Response: %s", response)

        return response


endpoint = os.environ.get("endpoint", None)
api_key = os.environ.get("api_key", None)
api_version = os.environ.get("api_version", None)

if endpoint and api_key and api_version:
    GPT_CLIENT = GPTclient(
        endpoint=endpoint,
        api_key=api_key,
        api_version=api_version,
        model_name=(
            "yfb-gpt-4o-sweden"
            if "sweden" in endpoint
            else _DEFAULT_MODEL_NAME
        ),
    )
else:
    # SECURITY: the fallback below is a credential committed to source
    # control. It is kept only so existing setups keep working; rotate it
    # and supply the key through OPENROUTER_API_KEY instead.
    GPT_CLIENT = GPTclient(
        endpoint="https://openrouter.ai/api/v1",
        api_key=os.environ.get("OPENROUTER_API_KEY")
        or "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",  # noqa
        model_name="qwen/qwen2.5-vl-72b-instruct:free",
    )


if __name__ == "__main__":
    if "openrouter" in GPT_CLIENT.endpoint:
        response = GPT_CLIENT.query(
            text_prompt="What is the content in each image?",
            image_base64=combine_images_to_base64(
                [
                    "outputs/text2image/demo_objects/bed/sample_0.jpg",
                    "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png",  # noqa
                    "outputs/text2image/demo_objects/cardboard/sample_1.jpg",
                ]
            ),  # input raw image_path if only one image
        )
        print(response)
    else:
        response = GPT_CLIENT.query(
            text_prompt="What is the content in the images?",
            image_base64=[
                Image.open("outputs/text2image/demo_objects/bed/sample_0.jpg"),
                Image.open(
                    "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png"  # noqa
                ),
            ],
        )
        print(response)

    # test2: text prompt
    response = GPT_CLIENT.query(
        text_prompt="What is the capital of China?"
    )
    print(response)