import base64
import functools
import os
import tempfile
import time
import types
import uuid
from functools import partial
from io import BytesIO

import numpy as np
from PIL.Image import Resampling

from gradio_utils.grclient import check_job
from src.enums import valid_imagegen_models, valid_imagechange_models, valid_imagestyle_models, docs_joiner_default, \
    llava16_model_max_length, llava16_image_tokens, llava16_image_fudge, VIDEO_EXTENSIONS, IMAGE_EXTENSIONS
from src.image_utils import fix_image_file
from src.utils import is_gradio_version4, get_docs_tokens, get_limited_text, makedirs, call_subprocess_onetask, \
    have_fiftyone, sanitize_filename

def is_animated_gif(file_path):
    """Return True if file_path is an animated (multi-frame) GIF."""
    if not file_path.endswith('.gif'):
        return False
    from PIL import Image, UnidentifiedImageError
    try:
        gif = Image.open(file_path)
    except (FileNotFoundError, UnidentifiedImageError):
        return False
    try:
        # Seeking to frame 1 only succeeds if there is more than one frame.
        gif.seek(1)
    except EOFError:
        return False
    else:
        return True

def gif_to_mp4(gif_path):
    """
    Convert an animated GIF to an MP4 video.

    :param gif_path: Path to the input GIF file.
    :return: Path to the output MP4 file (same location, with .mp4 extension).
    """
    from moviepy.editor import VideoFileClip
    clip = VideoFileClip(gif_path)
    mp4_path = gif_path.replace('.gif', '.mp4')
    clip.write_videofile(mp4_path, codec='libx264')
    return mp4_path

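# Example sketch (not part of the pipeline), assuming ./animation.gif exists:
# static GIFs are passed over, since is_animated_gif() returns False for them.
def _example_gif_to_mp4():
    path = 'animation.gif'
    if is_animated_gif(path):
        print(gif_to_mp4(path))  # -> 'animation.mp4'
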
def is_video_file(file_path):
    """
    Determine if the file is a video by checking its extension, frame count, and frame rate.

    :param file_path: Path to the file.
    :return: True if the file is a video, False otherwise.
    """
    ext = os.path.splitext(file_path)[-1].lower()
    if ext not in VIDEO_EXTENSIONS:
        return False

    import cv2
    video = cv2.VideoCapture(file_path)
    frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
    frame_rate = video.get(cv2.CAP_PROP_FPS)
    video.release()

    return frame_count >= 1 and frame_rate > 0

def img_to_base64(image_file, resolution=None, output_format=None, str_bytes=True):
    """
    Convert an image file to a base64 data URI string.

    :param image_file: Path to the input image file.
    :param resolution: Optional (width, height) tuple to resize to.
    :param output_format: Optional output format (e.g. 'jpeg', 'png'); defaults to the input
                          format, falling back to JPEG for formats other than JPEG/PNG.
    :param str_bytes: If True, return str(bytes(...)), i.e. wrapped as "b'...'", which
                      base64_to_img strips back off; otherwise return a plain data URI string.
    """
    from PIL import Image
    from pathlib import Path

    ext = Path(image_file).suffix
    iformat = IMAGE_EXTENSIONS.get(ext)
    assert iformat is not None, "Invalid file extension %s for file %s" % (ext, image_file)

    image = Image.open(image_file)

    if resolution:
        image = image.resize(resolution, resample=Resampling.BICUBIC)

    if output_format:
        oformat = output_format.upper()
    elif iformat not in ['JPEG', 'PNG']:
        oformat = 'JPEG'
    else:
        oformat = iformat

    buffered = BytesIO()
    image.save(buffered, format=oformat)
    img_str = base64.b64encode(buffered.getvalue())

    if str_bytes:
        img_str = str(bytes("data:image/%s;base64," % oformat.lower(), encoding='utf-8') + img_str)
    else:
        img_str = f"data:image/{oformat.lower()};base64,{img_str.decode('utf-8')}"

    return img_str

def base64_to_img(img_str, output_path):
    """
    Convert a base64 string to an image or video file.

    :param img_str: The base64 encoded string with the image or video data.
    :param output_path: The path (without extension) where the output file will be saved.
    :return: The path to the saved file.
    """
    if img_str.startswith("b'"):
        # Strip the "b'...'" wrapper produced by img_to_base64(str_bytes=True).
        img_str = img_str[2:-1]

    meta, base64_data = img_str.split(",", 1)
    # e.g. "data:image/jpeg;base64" -> "jpeg"
    img_format = meta.split(';')[0].split('/')[-1]

    img_bytes = base64.b64decode(base64_data)
    output_file = f"{output_path}.{img_format}"

    with open(output_file, "wb") as f:
        f.write(img_bytes)
    print(f"Image saved to {output_file} with format {img_format}")
    return output_file

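# Example sketch (not part of the pipeline), assuming ./photo.jpg exists: round-trip
# an image through img_to_base64/base64_to_img; the output path prefix is hypothetical
# and base64_to_img appends the extension recovered from the data URI.
def _example_base64_roundtrip():
    data_uri = img_to_base64('photo.jpg', resolution=(256, 256), str_bytes=False)
    out = base64_to_img(data_uri, os.path.join(tempfile.gettempdir(), 'roundtrip'))
    print(out)  # e.g. /tmp/roundtrip.jpeg
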
def video_to_base64frames(video_path):
    """Read every frame of a video and return them as base64-encoded JPEG strings."""
    import cv2
    video = cv2.VideoCapture(video_path)

    base64Frames = []
    while video.isOpened():
        success, frame = video.read()
        if not success:
            break
        _, buffer = cv2.imencode(".jpg", frame)
        base64Frames.append(base64.b64encode(buffer).decode("utf-8"))

    video.release()
    print(len(base64Frames), "frames read.")
    return base64Frames

@functools.lru_cache(maxsize=10000, typed=False)
def video_to_frames(video_path, output_dir, resolution=None, image_format="jpg", video_frame_period=None,
                    extract_frames=None,
                    verbose=False):
    """
    Convert video to frames, save them as image files in the specified format, and return the list of file names.

    :param video_path: Path to the input video file.
    :param output_dir: Directory where the output frames will be saved; if None, an auto-named
           temporary directory derived from video_path is used.
    :param resolution: Tuple specifying the desired resolution (width, height) or None to keep the original resolution.
    :param image_format: String specifying the desired image format (e.g., "jpg", "png").
    :param video_frame_period: How often to sample frames from the video.  If None or < 1, a period is
           computed automatically so that roughly ``extract_frames`` frames are saved.
           e.g. for non-real-time video, can set to 1 to save all frames, to mimic passing actual frames separately.
    :param extract_frames: Approximate number of frames to extract in automatic-period mode (capped at 20).
    :param verbose: Boolean to control whether to print progress messages.
    :return: List of file names for the saved frames.

    Example usage:
    file_names = video_to_frames("input_video.mp4", "output_frames", resolution=(640, 480), image_format="png", verbose=True)
    print(file_names)
    """
    import cv2
    if output_dir is None:
        output_dir = os.path.join(tempfile.gettempdir(), 'image_path_%s' % sanitize_filename(video_path))

    enable_fiftyone = True
    if enable_fiftyone and \
            have_fiftyone and \
            (video_frame_period is not None and video_frame_period < 1 or not os.path.isfile(video_path)):
        # Use fiftyone-based unique-frame extraction for URLs or sub-unit frame periods.
        from src.vision.extract_movie import extract_unique_frames
        args = ()
        urls = [video_path] if not os.path.isfile(video_path) else None
        file = video_path if os.path.isfile(video_path) else None
        kwargs = {'urls': urls, 'file': file, 'download_dir': None, 'export_dir': output_dir,
                  'extract_frames': extract_frames}
        # Could instead isolate extraction in a subprocess via:
        #   func_new = partial(call_subprocess_onetask, extract_unique_frames, args, kwargs)
        func_new = functools.partial(extract_unique_frames, *args, **kwargs)
        export_dir = func_new()
        return [os.path.join(export_dir, x) for x in os.listdir(export_dir)]

    if video_frame_period and video_frame_period < 1:
        video_frame_period = None
    if video_frame_period in [None, 0]:
        # Automatic period: sample so that roughly extract_frames frames are saved.
        total_frames = count_frames(video_path)
        extract_frames = min(20, extract_frames or 20)
        video_frame_period = max(1, total_frames // extract_frames)

    video = cv2.VideoCapture(video_path)
    makedirs(output_dir)

    image_format = image_format or 'jpg'

    frame_count = 0
    file_names = []
    while True:
        success, frame = video.read()
        if not success:
            break

        if frame_count % video_frame_period != 0:
            frame_count += 1
            continue
        if resolution:
            frame = cv2.resize(frame, resolution)

        frame_filename = os.path.join(output_dir, f"frame_{frame_count:04d}.{image_format}")
        cv2.imwrite(frame_filename, frame)
        file_names.append(frame_filename)
        frame_count += 1
    video.release()

    if verbose:
        print(f"{frame_count} frames saved to {output_dir}.")

    return file_names

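# Example sketch (not part of the pipeline), assuming ./clip.mp4 exists.  Because
# video_to_frames is wrapped in functools.lru_cache, repeated calls with the same
# (hashable) arguments return the cached file list without re-decoding the video;
# resolution must therefore be a tuple, not a list.
def _example_video_to_frames():
    frames = video_to_frames('clip.mp4', None, resolution=(640, 480),
                             extract_frames=10, verbose=True)
    print(len(frames), "frames extracted")
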
def count_frames(video_path):
    """Return the total frame count of a video, or -1 if the video cannot be opened."""
    import cv2

    video = cv2.VideoCapture(video_path)

    if not video.isOpened():
        print("Error: Could not open video.")
        return -1

    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

    video.release()

    return total_frames

def process_file_list(file_list, output_dir, resolution=None, image_format="jpg",
                      rotate_align_resize_image=True,
                      video_frame_period=None,
                      extract_frames=None,
                      verbose=False):
    """
    Process a list of files, converting any videos to frames and updating the list to only contain image files.

    :param file_list: List of file paths to be processed.
    :param output_dir: Directory where the output frames would be saved (currently unused;
           frames go to an auto-named temporary directory chosen by video_to_frames).
    :param resolution: Tuple specifying the desired resolution (width, height) or None to keep the original resolution.
           Does not affect images as inputs; those are handled elsewhere when converting to base64 for the LLM.
    :param image_format: String specifying the desired image format (e.g., "jpg", "png").
    :param rotate_align_resize_image: Whether to apply rotation, alignment, and resize before giving to the LLM.
    :param video_frame_period: Period to save frames; if < 1, the period is chosen automatically.
    :param extract_frames: How many frames to extract in automatic-period mode.
    :param verbose: Boolean to control whether to print progress messages.
    :return: Updated list of file names containing only image files.
    """
    if file_list is None:
        file_list = []
    if image_format is None:
        image_format = 'jpg'

    image_files = []

    for file in file_list:
        # Treat as (possibly) a video if it is a local video file, a non-local path
        # (e.g. a URL), or an animated GIF.
        is_maybe_video = (os.path.isfile(file) and is_video_file(file)) or \
                         not os.path.isfile(file) or \
                         is_animated_gif(file)
        if is_animated_gif(file):
            # For animated GIFs, let the frame extractor choose frames automatically.
            extract_frames = None
            if video_frame_period is not None and video_frame_period < 1:
                video_frame_period = None

        if is_maybe_video:
            if verbose:
                print(f"Processing video file: {file}")
            frame_files = video_to_frames(file, None, resolution, image_format, video_frame_period,
                                          extract_frames, verbose)
            image_files.extend(frame_files)
        else:
            if rotate_align_resize_image:
                file_fixed = fix_image_file(file, do_align=True, do_rotate=True, do_pad=False, relaxed_resize=True)
            else:
                file_fixed = file
            image_files.append(file_fixed)

    return image_files

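# Example sketch (not part of the pipeline), assuming the listed files exist: videos
# are expanded into sampled frame images while still images pass through (optionally
# rotated/aligned/resized), so the result contains image files only.
def _example_process_file_list():
    files = process_file_list(['clip.mp4', 'photo.jpg'], None,
                              extract_frames=10, verbose=True)
    print(files)  # frames for clip.mp4 followed by the (fixed) photo.jpg
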
def fix_llava_prompt(file,
                     prompt=None,
                     allow_prompt_auto=True,
                     ):
    if prompt in ['auto', None] and allow_prompt_auto:
        prompt = "Describe the image and what does the image say?"

    if file in ['', None]:
        # No image, so no prompt is required.
        prompt = ''

    if prompt is None:
        if os.environ.get('HARD_ASSERTS'):
            raise ValueError('prompt is None')
        else:
            prompt = ''
    return prompt

def llava_prep(file_list,
               llava_model,
               image_model='llava-v1.6-vicuna-13b',
               client=None):
    """Prepare each file in file_list for LLaVA, creating a shared gradio client if needed."""
    assert client is not None or len(file_list) == 1

    file_list_new = []
    image_model_list_new = []
    for file in file_list:
        image_model_new, client, file_new = _llava_prep(file,
                                                        llava_model,
                                                        image_model=image_model,
                                                        client=client)
        file_list_new.append(file_new)
        image_model_list_new.append(image_model_new)
    assert len(image_model_list_new) >= 1
    assert len(file_list_new) >= 1
    return image_model_list_new[0], client, file_list_new

def _llava_prep(file,
                llava_model,
                image_model='llava-v1.6-vicuna-13b',
                client=None):
    # llava_model takes the form [http(s)://]host:port[:image_model]; an optional
    # third field overrides image_model.
    prefix = ''
    if llava_model.startswith('http://'):
        prefix = 'http://'
    if llava_model.startswith('https://'):
        prefix = 'https://'
    llava_model = llava_model[len(prefix):]

    llava_model_split = llava_model.split(':')
    assert len(llava_model_split) >= 2

    if len(llava_model_split) >= 3:
        image_model = llava_model_split[2]
    llava_model = ':'.join(llava_model_split[:2])

    llava_model = prefix + llava_model

    if client is None:
        from gradio_utils.grclient import GradioClient
        client = GradioClient(llava_model, check_hash=False, serialize=is_gradio_version4)
        client.setup()

    if not is_gradio_version4 and file and os.path.isfile(file):
        # Older gradio clients cannot upload files directly, so pass a base64 data URI.
        file = img_to_base64(file)

    assert image_model, "No image model specified"

    if isinstance(file, np.ndarray):
        # Raw numpy images are saved to a temporary JPEG so the client can upload them.
        from PIL import Image
        im = Image.fromarray(file)
        file = "%s.jpeg" % str(uuid.uuid4())
        im.save(file)

    return image_model, client, file

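# Example sketch (not part of the pipeline): the llava_model address takes the form
# [http(s)://]host:port[:image_model].  Assumes a LLaVA gradio server is actually
# reachable at the (hypothetical) address below.
def _example_llava_prep():
    image_model, client, files = llava_prep(
        ['photo.jpg'], 'http://192.168.0.10:7860:llava-v1.6-vicuna-13b')
    print(image_model)  # -> 'llava-v1.6-vicuna-13b', parsed from the address
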
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

def get_prompt_with_texts(texts, prompt, max_new_tokens, min_max_new_tokens, tokenizer):
    """Combine per-image answers into a single reduction prompt that fits the model context."""
    if tokenizer is None:
        raise RuntimeError("Not setup for multi-image without tokenizer")

    if hasattr(tokenizer, 'model_max_length'):
        model_max_length = tokenizer.model_max_length
    else:
        model_max_length = llava16_model_max_length

    user_part = '\n\nReduce the above information into single correct answer to the following question: ' + prompt
    user_part_tokens = len(tokenizer.encode(user_part))

    text_context_list = ['Answer #%s:\n\n%s' % (ii, text) for ii, text in enumerate(texts)]

    # If everything does not fit, shrink the generation budget and trim the answers.
    text_tokens_trial = len(tokenizer.encode(docs_joiner_default.join(text_context_list)))
    if user_part_tokens + text_tokens_trial + max_new_tokens >= model_max_length:
        max_new_tokens = min_max_new_tokens
    fudge = llava16_image_fudge
    max_input_tokens = model_max_length - max_new_tokens - fudge

    top_k_docs, one_doc_size, num_doc_tokens = \
        get_docs_tokens(tokenizer, text_context_list=text_context_list, max_input_tokens=max_input_tokens)
    text_context_list_cut = text_context_list[:top_k_docs]
    texts_joined = docs_joiner_default.join(text_context_list_cut)

    prompt_with_texts = '\n"""\n' + texts_joined + '\n"""\n'
    prompt_with_texts += user_part

    # Refer to the answers as documents so the model does not expect an attached image.
    return prompt_with_texts.replace('image', 'document').replace('Image', 'Document')

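# Example sketch (not part of the pipeline) of the multi-image reduction prompt.
# Assumes a HuggingFace tokenizer; the model name is illustrative, and anything
# with .encode() and .model_max_length should behave similarly.
def _example_get_prompt_with_texts():
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained('lmsys/vicuna-13b-v1.5')
    answers = ['A cat sitting on a mat.', 'A cat on a red mat near a window.']
    combined = get_prompt_with_texts(answers, 'What animal is shown?',
                                     max_new_tokens=512, min_max_new_tokens=256,
                                     tokenizer=tok)
    print(combined)  # quoted answers followed by the reduction question
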
def get_llava_response(file=None,
                       llava_model=None,
                       prompt=None,
                       chat_conversation=[],
                       allow_prompt_auto=False,
                       image_model='llava-v1.6-vicuna-13b', temperature=0.2,
                       top_p=0.7, max_new_tokens=512,
                       min_max_new_tokens=512,
                       tokenizer=None,
                       image_process_mode="Default",
                       include_image=False,
                       client=None,
                       max_time=None,
                       force_stream=True,
                       verbose=False,
                       ):
    max_new_tokens = min(max_new_tokens, 1024)

    # Snapshot the call arguments so the streaming path can be entered unchanged.
    kwargs = locals().copy()

    # Multi-image requests always go through the streaming path.
    force_stream |= isinstance(file, list) and len(file) > 1
    if isinstance(file, str):
        file_list = [file]
    elif isinstance(file, list):
        file_list = file
        if len(file_list) == 0:
            file_list = [None]
    else:
        file_list = [None]

    if force_stream:
        text = ''
        for res in get_llava_stream(**kwargs):
            text = res
        return text, prompt

    image_model = os.path.basename(image_model)
    prompt = fix_llava_prompt(file_list, prompt, allow_prompt_auto=allow_prompt_auto)
    # With many images, per-image responses are kept short so the reduction step fits.
    max_new_tokens1 = max_new_tokens if len(file_list) <= 4 else min(max_new_tokens, min_max_new_tokens)
    if tokenizer:
        model_max_length = tokenizer.model_max_length
    else:
        model_max_length = llava16_model_max_length
    image_tokens = llava16_image_tokens if len(file_list) >= 1 and file_list[0] is not None else 0
    fudge = llava16_image_fudge
    hard_limit_tokens = model_max_length - max_new_tokens1 - fudge - image_tokens
    prompt = get_limited_text(hard_limit_tokens, prompt, tokenizer, verbose=False)

    image_model, client, file_list = \
        llava_prep(file_list, llava_model,
                   image_model=image_model,
                   client=client)

    reses = []
    for file in file_list:
        res = client.predict(prompt,
                             chat_conversation if len(file_list) == 1 else [],
                             file,
                             image_process_mode,
                             include_image,
                             image_model,
                             temperature,
                             top_p,
                             max_new_tokens1,
                             api_name='/textbox_api_submit')
        reses.append(res)

    if len(reses) > 1:
        # Drop failed responses, then reduce the per-image answers into one final answer.
        reses = [x for x in reses if server_error_msg not in x]
        prompt_with_texts = get_prompt_with_texts(reses, prompt, max_new_tokens, min_max_new_tokens, tokenizer)
        res = client.predict(prompt_with_texts,
                             chat_conversation,
                             None,
                             image_process_mode,
                             include_image,
                             image_model,
                             temperature,
                             top_p,
                             max_new_tokens,
                             api_name='/textbox_api_submit')
    else:
        res = reses[0]

    return res, prompt

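# Example sketch (not part of the pipeline) of a one-shot, non-streaming call.
# Assumes a LLaVA gradio server at the (hypothetical) address and a local image file.
def _example_get_llava_response():
    text, prompt = get_llava_response(file='photo.jpg',
                                      llava_model='http://192.168.0.10:7860',
                                      prompt='What is in this image?',
                                      force_stream=False)
    print(text)
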
def get_llava_stream(file, llava_model,
                     prompt=None,
                     chat_conversation=[],
                     allow_prompt_auto=False,
                     image_model='llava-v1.6-vicuna-13b', temperature=0.2,
                     top_p=0.7, max_new_tokens=512,
                     min_max_new_tokens=512,
                     tokenizer=None,
                     image_process_mode="Default",
                     include_image=False,
                     client=None,
                     verbose_level=0,
                     max_time=None,
                     force_stream=True,
                     verbose=False,
                     ):
    max_new_tokens = min(max_new_tokens, 1024)

    if isinstance(file, str):
        file_list = [file]
    elif isinstance(file, list):
        file_list = file
        if len(file_list) == 0:
            file_list = [None]
    else:
        file_list = [None]

    image_model = os.path.basename(image_model)
    prompt = fix_llava_prompt(file_list, prompt, allow_prompt_auto=allow_prompt_auto)
    # With many images, per-image responses are kept short so the reduction step fits.
    max_new_tokens1 = max_new_tokens if len(file_list) <= 4 else min(max_new_tokens, min_max_new_tokens)
    if tokenizer:
        model_max_length = tokenizer.model_max_length
    else:
        model_max_length = llava16_model_max_length
    image_tokens = llava16_image_tokens if len(file_list) >= 1 and file_list[0] is not None else 0
    fudge = llava16_image_fudge
    hard_limit_tokens = model_max_length - max_new_tokens1 - fudge - image_tokens
    prompt = get_limited_text(hard_limit_tokens, prompt, tokenizer)

    image_model, client, file_list = \
        llava_prep(file_list, llava_model,
                   image_model=image_model,
                   client=client)

    # Submit one job per file so all images are processed concurrently.
    jobs = []
    for file in file_list:
        job = client.submit(prompt,
                            chat_conversation,
                            file,
                            image_process_mode,
                            include_image,
                            image_model,
                            temperature,
                            top_p,
                            max_new_tokens1,
                            api_name='/textbox_api_submit')
        jobs.append(job)

    t0 = time.time()
    job_outputs_nums = [0] * len(jobs)
    texts = [''] * len(jobs)
    done_all = False
    reses = [''] * len(jobs)
    while True:
        for ji, job in enumerate(jobs):
            if verbose_level == 2:
                print("Inside: %s" % llava_model, time.time() - t0, flush=True)
            e = check_job(job, timeout=0, raise_exception=False)
            if e is not None:
                continue
            if max_time is not None and time.time() - t0 > max_time:
                done_all = True
                break
            outputs_list = job.outputs().copy()
            job_outputs_num_new = len(outputs_list[job_outputs_nums[ji]:])
            for num in range(job_outputs_num_new):
                reses[ji] = outputs_list[job_outputs_nums[ji] + num]
                if verbose_level == 2:
                    print('Stream %d: %s' % (num, reses[ji]), flush=True)
                elif verbose_level == 1:
                    print('Stream %d' % (job_outputs_nums[ji] + num), flush=True)
                if reses[ji]:
                    texts[ji] = reses[ji]
                    if len(jobs) == 1:
                        # Single image: stream partial text straight through.
                        yield texts[ji]
            job_outputs_nums[ji] += job_outputs_num_new
            time.sleep(0.005)
        if done_all or all([job.done() for job in jobs]):
            break

    # Drain any outputs that arrived after the jobs finished.
    for ji, job in enumerate(jobs):
        e = check_job(job, timeout=0, raise_exception=False)
        if e is not None:
            continue
        outputs_list = job.outputs().copy()
        job_outputs_num_new = len(outputs_list[job_outputs_nums[ji]:])
        for num in range(job_outputs_num_new):
            reses[ji] = outputs_list[job_outputs_nums[ji] + num]
            if verbose_level == 2:
                print('Final Stream %d: %s' % (num, reses[ji]), flush=True)
            elif verbose_level == 1:
                print('Final Stream %d' % (job_outputs_nums[ji] + num), flush=True)
            if reses[ji]:
                texts[ji] = reses[ji]
                if len(jobs) == 1:
                    yield texts[ji]
        job_outputs_nums[ji] += job_outputs_num_new
        if verbose_level == 1:
            print("total job_outputs_num=%d" % job_outputs_nums[ji], flush=True)

    if len(jobs) > 1:
        # Multi-image: reduce per-image answers into one final streamed answer.
        ntexts_before = len(texts)
        texts = [x for x in texts if server_error_msg not in x]
        ntexts_after = len(texts)
        if ntexts_after != ntexts_before:
            print("texts: %s -> %s" % (ntexts_before, ntexts_after))
        prompt_with_texts = get_prompt_with_texts(texts, prompt, max_new_tokens, min_max_new_tokens, tokenizer)
        text = ''
        max_new_tokens = max_new_tokens if len(jobs) > 4 else min(max_new_tokens, min_max_new_tokens)
        for res in get_llava_stream(None,
                                    llava_model,
                                    prompt=prompt_with_texts,
                                    chat_conversation=chat_conversation,
                                    allow_prompt_auto=allow_prompt_auto,
                                    image_model=image_model,
                                    temperature=temperature,
                                    top_p=top_p,
                                    max_new_tokens=max_new_tokens,
                                    min_max_new_tokens=min_max_new_tokens,
                                    tokenizer=tokenizer,
                                    image_process_mode=image_process_mode,
                                    include_image=include_image,
                                    client=client,
                                    verbose_level=verbose_level,
                                    max_time=max_time,
                                    force_stream=force_stream,
                                    verbose=verbose,
                                    ):
            text = res
            yield text
    else:
        assert len(texts) == 1
        text = texts[0]

    return text

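# Example sketch (not part of the pipeline): consume partial answers as they stream.
# Same hypothetical server address as above.
def _example_get_llava_stream():
    for partial in get_llava_stream('photo.jpg', 'http://192.168.0.10:7860',
                                    prompt='Describe this image.'):
        print(partial, flush=True)
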
def get_image_model_dict(enable_image,
                         image_models,
                         image_gpu_ids,
                         ):
    image_dict = {}
    if not enable_image:
        return image_dict

    if not image_gpu_ids:
        # Covers both None and empty list.
        image_gpu_ids = ['auto'] * len(image_models)

    for image_model_name in valid_imagegen_models + valid_imagechange_models + valid_imagestyle_models:
        if image_model_name in image_models:
            imagegen_index = image_models.index(image_model_name)
            if image_model_name == 'sdxl_turbo':
                from src.vision.sdxl_turbo import get_pipe_make_image, make_image
            elif image_model_name == 'playv2':
                from src.vision.playv2 import get_pipe_make_image, make_image
            elif image_model_name == 'sdxl':
                from src.vision.stable_diffusion_xl import get_pipe_make_image, make_image
            elif image_model_name == 'sd3':
                from src.vision.stable_diffusion_xl import get_pipe_make_image, make_image
                get_pipe_make_image = functools.partial(get_pipe_make_image,
                                                        base_model='stabilityai/stable-diffusion-3-medium-diffusers',
                                                        refiner_model=None)
                make_image = functools.partial(make_image,
                                               base_model='stabilityai/stable-diffusion-3-medium-diffusers',
                                               refiner_model=None)
            elif image_model_name == 'flux.1-dev':
                from src.vision.flux import get_pipe_make_image, make_image
            elif image_model_name == 'flux.1-schnell':
                from src.vision.flux import get_pipe_make_image_2 as get_pipe_make_image
                from src.vision.flux import make_image
            elif image_model_name == 'sdxl_change':
                from src.vision.sdxl_turbo import get_pipe_change_image as get_pipe_make_image, change_image
                make_image = change_image
            else:
                raise ValueError("Invalid image_model_name=%s" % image_model_name)
            pipe = get_pipe_make_image(gpu_id=image_gpu_ids[imagegen_index])
            image_dict[image_model_name] = dict(pipe=pipe, make_image=make_image)
    return image_dict

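# Example sketch (not part of the pipeline), assuming the sdxl_turbo dependencies are
# installed: pipelines are loaded once and reused via the returned dict.
def _example_get_image_model_dict():
    image_dict = get_image_model_dict(True, ['sdxl_turbo'], ['auto'])
    print(list(image_dict))  # -> ['sdxl_turbo']; each entry is dict(pipe=..., make_image=...)
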
def pdf_to_base64_pngs(pdf_path, quality=75, max_size=(1024, 1024), ext='png', pages=None):
    """
    Convert a PDF slide deck to a list of base64-encoded images. Images are resized
    as needed to keep them within Claude's size limits.
    """
    from PIL import Image
    import io
    import fitz

    doc = fitz.open(pdf_path)

    images = []
    if pages is None:
        pages = list(range(doc.page_count))
    else:
        assert isinstance(pages, (list, tuple, types.GeneratorType))

    for page_num in pages:
        page = doc.load_page(page_num)
        # Render at 300 DPI (PDF native resolution is 72 DPI).
        pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))

        output_path = f"{tempfile.mkdtemp()}/page_{page_num + 1}.{ext}"
        pix.save(output_path)
        images.append(output_path)

    doc.close()

    if ext == 'png':
        iformat = 'PNG'
    elif ext in ['jpeg', 'jpg']:
        iformat = 'JPEG'
    else:
        raise ValueError("No such ext=%s" % ext)

    images = [Image.open(image) for image in images]
    base64_encoded_pngs = []
    for image in images:
        # Downscale only if the page exceeds the size limit.
        if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
            image.thumbnail(max_size, Image.Resampling.LANCZOS)
        image_data = io.BytesIO()
        image.save(image_data, format=iformat, optimize=True, quality=quality)
        image_data.seek(0)
        base64_encoded = base64.b64encode(image_data.getvalue()).decode('utf-8')
        base64_encoded_pngs.append(base64_encoded)

    return base64_encoded_pngs

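# Example sketch (not part of the pipeline), assuming ./deck.pdf exists: encode only
# the first two pages, keeping each page within the default 1024x1024 bound.
def _example_pdf_to_base64_pngs():
    pngs = pdf_to_base64_pngs('deck.pdf', ext='png', pages=[0, 1])
    print(len(pngs), "pages encoded;", len(pngs[0]), "base64 chars in the first")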