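"""Ask a vision-capable OpenAI(-compatible) model a question about an image.

The script accepts a plain image file (e.g., PNG/JPEG), an SVG (rasterized via
cairosvg), or a PDF (each page rendered via pdf2image), sends it together with the
user prompt to the chat-completions endpoint configured by H2OGPT_OPENAI_BASE_URL,
and prints the model's response (streamed by default).

Example invocation (the script filename shown here is only illustrative):
    python ask_question_about_image.py --prompt "What does this chart show?" --file report.pdf
"""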
import os
import argparse
import tempfile
import logging
import time
# Set up logging
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
# avoid logging that reveals urls
logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)


def convert_svg_to_png(svg_path):
    """Render an SVG file to a temporary PNG and return the new PNG path."""
    import cairosvg
    png_path = tempfile.mktemp(suffix='.png')
    cairosvg.svg2png(url=svg_path, write_to=png_path)
    return png_path


def convert_pdf_to_images(pdf_path):
    """Render each page of a PDF to a temporary PNG and return the list of paths."""
    from pdf2image import convert_from_path
    images = convert_from_path(pdf_path)
    image_paths = []
    for i, image in enumerate(images):
        image_path = tempfile.mktemp(suffix=f'_page_{i + 1}.png')
        image.save(image_path, 'PNG')
        image_paths.append(image_path)
    return image_paths


def process_file(file_path):
    """Normalize an input file into a list of image paths suitable for the vision API."""
    _, file_extension = os.path.splitext(file_path)
    if file_extension.lower() == '.svg':
        png_path = convert_svg_to_png(file_path)
        return [png_path] if png_path else []
    elif file_extension.lower() == '.pdf':
        return convert_pdf_to_images(file_path)
    else:
        # For standard image files, just return the original file path
        return [file_path]


def main():
    default_max_time = int(os.getenv('H2OGPT_AGENT_OPENAI_TIMEOUT', "120"))

    parser = argparse.ArgumentParser(description="OpenAI Vision API Script")
    parser.add_argument("--timeout", type=int, default=60, help="Timeout for API calls")
    parser.add_argument("--system_prompt", type=str,
                        default="""You are a highly capable AI assistant with advanced vision capabilities.
* Analyze the provided image thoroughly and provide detailed, accurate descriptions or answers based on what you see.
* Consider various aspects such as objects, people, actions, text, colors, composition, and any other relevant details.
* If asked a specific question about the image, focus your response on addressing that question directly.
* Ensure you add a critique of the image, if anything seems wrong, or if anything requires improvement.""",
                        help="System prompt")
    parser.add_argument("--prompt", "--query", type=str, required=True, help="User prompt")
    parser.add_argument("--url", type=str, help="URL of the image")
    parser.add_argument("--file", type=str,
                        help="Path to the image file. Accepts standard image formats (e.g., PNG, JPEG, JPG), SVG, and PDF files.")
    parser.add_argument("--model", type=str, help="OpenAI or Open Source model to use")
    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for the model")
    parser.add_argument("--max_tokens", type=int, default=1024, help="Maximum tokens for the model")
    parser.add_argument("--stream_output", help="Whether to stream output", default=True, action='store_true')
    parser.add_argument("--max_time", type=float, default=default_max_time, help="Maximum time to wait for response")
    args = parser.parse_args()
    if not args.model:
        args.model = os.getenv('H2OGPT_OPENAI_VISION_MODEL')
    if not args.model:
        raise ValueError("Model name must be provided via --model or H2OGPT_OPENAI_VISION_MODEL environment variable")

    base_url = os.getenv('H2OGPT_OPENAI_BASE_URL')
    assert base_url is not None, "H2OGPT_OPENAI_BASE_URL environment variable is not set"
    server_api_key = os.getenv('H2OGPT_OPENAI_API_KEY', 'EMPTY')

    from openai import OpenAI
    client = OpenAI(base_url=base_url, api_key=server_api_key, timeout=args.timeout)

    assert args.url or args.file, "Either --url or --file must be provided"
    assert not (args.url and args.file), "--url and --file cannot be used together"
    # If --file actually points at a URL, treat it as --url instead
    from openai_server.agent_tools.common.utils import filename_is_url
    if filename_is_url(args.file):
        args.url = args.file
        args.file = None

    if args.file:
        from openai_server.openai_client import file_to_base64
        image_paths = process_file(args.file)
        if not image_paths:
            raise ValueError(f"Unsupported file type: {args.file}")
        image_contents = [
            {
                'type': 'image_url',
                'image_url': {
                    'url': file_to_base64(image_path)[image_path],
                    'detail': 'high',
                },
            } for image_path in image_paths
        ]
    else:
        image_paths = []
        image_contents = [{
            'type': 'image_url',
            'image_url': {
                'url': args.url,
                'detail': 'high',
            },
        }]
    messages = [
        {"role": "system", "content": args.system_prompt},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': args.prompt},
            ] + image_contents,
        }
    ]

    # extra_body forwards server-specific options (here, image rotation/alignment/resizing)
    # that are not part of the standard OpenAI chat-completions parameters
    responses = client.chat.completions.create(
        messages=messages,
        model=args.model,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        extra_body=dict(rotate_align_resize_image=True),
        stream=args.stream_output,
    )
    if args.stream_output:
        text = ''
        first_delta = True
        tgen0 = time.time()
        verbose = True
        for chunk in responses:
            delta = chunk.choices[0].delta.content if chunk.choices else None
            if delta:
                text += delta
                if first_delta:
                    first_delta = False
                    print("**Vision Model Response:**\n\n", flush=True)
                print(delta, flush=True, end='')
            if time.time() - tgen0 > args.max_time:
                if verbose:
                    print("Took too long for OpenAI or VLLM Chat: %s" % (time.time() - tgen0),
                          flush=True)
                break
        if not text:
            print("**Vision Model returned an empty response**", flush=True)
    else:
        text = responses.choices[0].message.content if responses.choices else ''
        if text:
            print("**Vision Model Response:**\n\n", text, flush=True)
        else:
            print("**Vision Model returned an empty response**", flush=True)
    # Cleanup temporary files
    for image_path in image_paths:
        if image_path != args.file:  # Don't delete the original file
            try:
                os.remove(image_path)
            except Exception as e:
                logger.warning(f"Failed to delete temporary file {image_path}: {str(e)}")


if __name__ == "__main__":
    main()