import base64 |
import functools |
import os |
import tempfile |
import time |
import types |
import uuid |
from functools import partial |
from io import BytesIO |
import numpy as np |
from PIL.Image import Resampling |
from gradio_utils.grclient import check_job |
from src.enums import valid_imagegen_models, valid_imagechange_models, valid_imagestyle_models, docs_joiner_default, \ |
llava16_model_max_length, llava16_image_tokens, llava16_image_fudge, VIDEO_EXTENSIONS, IMAGE_EXTENSIONS |
from src.image_utils import fix_image_file |
from src.utils import is_gradio_version4, get_docs_tokens, get_limited_text, makedirs, call_subprocess_onetask, \ |
have_fiftyone, sanitize_filename |
def is_animated_gif(file_path):
    """
    Tell whether ``file_path`` names a GIF that contains more than one frame.

    :param file_path: Path to test. Only names ending in the lowercase suffix ``.gif`` are inspected.
    :return: True if the file opens as an image and a second frame can be reached, else False.
             Missing or unidentifiable files yield False.
    """
    if not file_path.endswith('.gif'):
        return False

    # imported lazily so non-GIF paths never need PIL
    from PIL import Image, UnidentifiedImageError

    try:
        gif = Image.open(file_path)
    except (FileNotFoundError, UnidentifiedImageError):
        return False

    # close the underlying file handle whichever way we leave
    with gif:
        try:
            # a single-frame image raises EOFError when asked for frame index 1
            gif.seek(1)
        except EOFError:
            return False
        return True
def gif_to_mp4(gif_path):
    """
    Convert an animated GIF to an MP4 video written next to the input file.

    The output name is the input name with its final extension replaced by ``.mp4``.
    Only the trailing extension is swapped, so directory names containing ``.gif`` are left alone
    and an input without a lowercase ``.gif`` suffix can no longer map onto itself.

    :param gif_path: Path to the input GIF file.
    :return: Path to the MP4 file that was written.
    """
    from moviepy.editor import VideoFileClip

    mp4_path = os.path.splitext(gif_path)[0] + '.mp4'
    clip = VideoFileClip(gif_path)
    try:
        clip.write_videofile(mp4_path, codec='libx264')
    finally:
        # release the reader resources even if encoding fails
        clip.close()
    return mp4_path
def is_video_file(file_path):
    """
    Decide whether a path refers to a playable video.

    The extension must be listed in VIDEO_EXTENSIONS; after that the file is opened with OpenCV
    and must report at least one frame and a positive frame rate.

    :param file_path: Path to the file.
    :return: True if the file is a video, False otherwise.
    """
    _, suffix = os.path.splitext(file_path)
    if suffix.lower() not in VIDEO_EXTENSIONS:
        return False

    import cv2
    capture = cv2.VideoCapture(file_path)
    n_frames = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    fps = capture.get(cv2.CAP_PROP_FPS)
    capture.release()

    return n_frames >= 1 and fps > 0
def img_to_base64(image_file, resolution=None, output_format=None, str_bytes=True):
    """
    Encode an image file as a base64 ``data:image/...`` URI.

    :param image_file: Path to the image; its suffix must be a key of IMAGE_EXTENSIONS.
    :param resolution: Optional (width, height) to resize to (bicubic) before encoding.
    :param output_format: Optional output format name (case-insensitive). ``jpg`` is accepted as an
        alias of ``jpeg``. When omitted, JPEG and PNG inputs keep their format and every other input
        is re-encoded as JPEG.
    :param str_bytes: If True, return ``str()`` of the bytes object (i.e. text of the form
        ``b'data:image/...'``, which base64_to_img understands); otherwise return a plain
        ``data:image/...`` string.
    :return: The encoded image as a string.
    :raises AssertionError: if the file suffix is not in IMAGE_EXTENSIONS.
    """
    from PIL import Image
    from pathlib import Path

    ext = Path(image_file).suffix
    iformat = IMAGE_EXTENSIONS.get(ext)
    assert iformat is not None, "Invalid file extension %s for file %s" % (ext, image_file)

    if output_format:
        oformat = output_format.upper()
        if oformat == 'JPG':
            # Pillow registers the writer under "JPEG" only
            oformat = 'JPEG'
    elif iformat not in ['JPEG', 'PNG']:
        oformat = 'JPEG'
    else:
        oformat = iformat

    buffered = BytesIO()
    # the context manager closes the source file handle once encoding is done
    with Image.open(image_file) as source:
        image = source
        if resolution:
            image = image.resize(resolution, resample=Resampling.BICUBIC)
        if oformat == 'JPEG' and image.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
            # palette / alpha modes (typical for GIF, PNG, WEBP) cannot be written as JPEG
            image = image.convert('RGB')
        image.save(buffered, format=oformat)

    img_str = base64.b64encode(buffered.getvalue())
    if str_bytes:
        img_str = str(bytes("data:image/%s;base64," % oformat.lower(), encoding='utf-8') + img_str)
    else:
        img_str = f"data:image/{oformat.lower()};base64,{img_str.decode('utf-8')}"
    return img_str
def base64_to_img(img_str, output_path):
    """
    Decode a base64 data URI and write the payload to disk.

    The file extension is taken from the media type in the URI header
    (``data:image/png;base64,...`` gives ``.png``).

    :param img_str: The data URI, either plain or as the text of a bytes repr (``b'...'``).
    :param output_path: Destination path without extension.
    :return: The path of the file that was written.
    """
    # unwrap the textual form of a bytes object: drop the leading b' and the trailing quote
    payload = img_str[2:-1] if img_str.startswith("b'") else img_str

    header, encoded = payload.split(",", 1)
    media_type = header.split(';')[0]
    img_format = media_type.split('/')[-1]

    raw = base64.b64decode(encoded)
    output_file = f"{output_path}.{img_format}"
    with open(output_file, "wb") as sink:
        sink.write(raw)

    print(f"Image saved to {output_file} with format {img_format}")
    return output_file
def video_to_base64frames(video_path):
    """
    Read every frame of a video and return each one as a base64-encoded JPEG string.

    :param video_path: Path handed to ``cv2.VideoCapture``.
    :return: List of base64 strings (no data-URI prefix), one per frame that was read.
    """
    import cv2

    capture = cv2.VideoCapture(video_path)
    encoded_frames = []
    while capture.isOpened():
        ok, frame = capture.read()
        if not ok:
            # end of stream or read failure
            break
        _, jpeg_buffer = cv2.imencode(".jpg", frame)
        encoded_frames.append(base64.b64encode(jpeg_buffer).decode("utf-8"))
    capture.release()

    print(len(encoded_frames), "frames read.")
    return encoded_frames
@functools.lru_cache(maxsize=10000, typed=False)
def video_to_frames(video_path, output_dir, resolution=None, image_format="jpg", video_frame_period=None,
                    extract_frames=None,
                    verbose=False):
    """
    Convert video to frames, save them as image files in the specified format, and return the list of file names.

    Results are memoised with ``lru_cache``: all arguments must be hashable (pass ``resolution`` as a tuple)
    and the same list object is returned for repeated identical calls, so callers should not mutate it.

    :param video_path: Path to the input video file (or, when fiftyone is available, a URL).
    :param output_dir: Directory where the output frames will be saved.
           If None, a directory under the system temp dir derived from ``video_path`` is used.
    :param resolution: Tuple specifying the desired resolution (width, height) or None to keep the original resolution.
    :param image_format: String specifying the desired image format (e.g., "jpg", "png"). A leading dot is ignored.
    :param video_frame_period: Save every N-th frame. If None or < 1 the period is chosen automatically as
           ``total_frames // min(20, extract_frames or 20)`` (at least 1).
           e.g. if pass non-real-time video, can set to 1 to save all frames, to mimic passing actual frames separately otherwise
    :param extract_frames: Target number of frames for the automatic period (capped at 20); also forwarded to
           ``extract_unique_frames`` when that path is taken.
    :param verbose: Boolean to control whether to print progress messages.
    :return: List of file names for the saved frames.

    Example usage:
    file_names = video_to_frames("input_video.mp4", "output_frames", resolution=(640, 480), image_format="png", verbose=True)
    print(file_names)
    """
    import cv2

    if output_dir is None:
        output_dir = os.path.join(tempfile.gettempdir(), 'image_path_%s' % sanitize_filename(video_path))

    is_local_file = os.path.isfile(video_path)
    wants_auto_period = video_frame_period is not None and video_frame_period < 1

    enable_fiftyone = True
    if enable_fiftyone and have_fiftyone and (wants_auto_period or not is_local_file):
        # unique-frame extraction handles both local files and remote URLs
        from src.vision.extract_movie import extract_unique_frames
        export_dir = extract_unique_frames(urls=None if is_local_file else [video_path],
                                           file=video_path if is_local_file else None,
                                           download_dir=None,
                                           export_dir=output_dir,
                                           extract_frames=extract_frames)
        return [os.path.join(export_dir, x) for x in os.listdir(export_dir)]

    if video_frame_period is None or video_frame_period < 1:
        total_frames = count_frames(video_path)
        extract_frames = min(20, extract_frames or 20)
        # floor division yields 0 for short (or unreadable) videos; a period of 0 would
        # make the modulo below raise ZeroDivisionError, so clamp to "every frame"
        video_frame_period = max(1, total_frames // extract_frames)

    video = cv2.VideoCapture(video_path)
    makedirs(output_dir)
    # default without a dot: the dot is added when the file name is built
    image_format = (image_format or 'jpg').lstrip('.')

    frame_count = 0
    file_names = []
    try:
        while True:
            success, frame = video.read()
            if not success:
                break
            if frame_count % video_frame_period == 0:
                if resolution:
                    frame = cv2.resize(frame, resolution)
                frame_filename = os.path.join(output_dir, f"frame_{frame_count:04d}.{image_format}")
                cv2.imwrite(frame_filename, frame)
                file_names.append(frame_filename)
            frame_count += 1
    finally:
        video.release()

    if verbose:
        # report what was written, not how many frames were decoded
        print(f"{len(file_names)} frames saved to {output_dir}.")
    return file_names
def count_frames(video_path):
    """
    Report the frame count OpenCV reads from the container metadata.

    :param video_path: Path handed to ``cv2.VideoCapture``.
    :return: Number of frames as an int, or -1 (after printing an error) if the video cannot be opened.
    """
    import cv2

    capture = cv2.VideoCapture(video_path)
    if capture.isOpened():
        n_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        capture.release()
        return n_frames

    print("Error: Could not open video.")
    return -1
def process_file_list(file_list, output_dir, resolution=None, image_format="jpg",
                      rotate_align_resize_image=True,
                      video_frame_period=None,
                      extract_frames=None,
                      verbose=False):
    """
    Process a list of files, converting any videos to frames and updating the list to only contain image files.

    :param file_list: List of file paths to be processed. None is treated as an empty list.
           Entries that are not existing files are treated as possible videos (passed to video_to_frames).
    :param output_dir: Accepted for interface compatibility; frames are currently written to the per-video
           default directory chosen by video_to_frames (None is passed through).
    :param resolution: Tuple specifying the desired resolution (width, height) or None to keep the original resolution.
           Does not affect images as inputs, handled elsewhere when converting to base64 for LLM
    :param image_format: String specifying the desired image format (e.g., "jpg", "png").
    :param rotate_align_resize_image: Whether to apply rotation, alignment, resize before giving to LLM
    :param video_frame_period: Period to save frames, if <1 then automatic
    :param extract_frames: how many frames to extract if automatic period mode
    :param verbose: Boolean to control whether to print progress messages.
    :return: Updated list of file names containing only image files.
    """
    if file_list is None:
        file_list = []
    if image_format is None:
        image_format = 'jpg'

    image_files = []
    for file in file_list:
        is_local_file = os.path.isfile(file)
        # probe once; the answer is needed both for routing and for the overrides below
        animated_gif = is_animated_gif(file)
        is_maybe_video = (is_local_file and is_video_file(file)) or not is_local_file or animated_gif

        # Per-file copies: the GIF overrides must not leak into later entries of file_list.
        file_extract_frames = extract_frames
        file_frame_period = video_frame_period
        if animated_gif:
            file_extract_frames = None
            if file_frame_period is not None and file_frame_period < 1:
                file_frame_period = None

        if is_maybe_video:
            if verbose:
                print(f"Processing video file: {file}")
            frame_files = video_to_frames(file, None, resolution, image_format, file_frame_period,
                                          file_extract_frames, verbose)
            image_files.extend(frame_files)
        else:
            if rotate_align_resize_image:
                file_fixed = fix_image_file(file, do_align=True, do_rotate=True, do_pad=False, relaxed_resize=True)
            else:
                file_fixed = file
            image_files.append(file_fixed)
    return image_files
def fix_llava_prompt(file,
                     prompt=None,
                     allow_prompt_auto=True,
                     ):
    """
    Resolve the prompt to send to the vision model.

    :param file: The file argument; only compared against '' and None.
    :param prompt: User prompt; 'auto' or None request the automatic prompt.
    :param allow_prompt_auto: Whether the automatic prompt may be substituted.
    :return: The prompt string. With auto-prompting enabled and requested this is the stock
             description prompt, or '' when there is no file. A prompt that is still None
             becomes ''.
    :raises ValueError: if the prompt is still None and the HARD_ASSERTS environment variable is set.
    """
    auto_requested = prompt in ['auto', None]
    if auto_requested and allow_prompt_auto:
        no_file = file in ['', None]
        prompt = '' if no_file else "Describe the image and what does the image say?"

    if prompt is not None:
        return prompt

    if os.environ.get('HARD_ASSERTS'):
        raise ValueError('prompt is None')
    return ''
def llava_prep(file_list,
               llava_model,
               image_model='llava-v1.6-vicuna-13b',
               client=None):
    """
    Run _llava_prep over every entry of ``file_list``, reusing one client throughout.

    :param file_list: Files to prepare; must have exactly one entry when no client is supplied.
    :param llava_model: Server spec forwarded to _llava_prep.
    :param image_model: Default image model name forwarded to _llava_prep.
    :param client: Existing client, or None to have the first _llava_prep call create one.
    :return: Tuple (image model reported for the first file, client, list of prepared files).
    """
    assert client is not None or len(file_list) == 1

    prepared_files = []
    resolved_models = []
    for entry in file_list:
        # the client returned by one call is fed into the next, so it is created at most once
        resolved_model, client, prepared = _llava_prep(entry,
                                                       llava_model,
                                                       image_model=image_model,
                                                       client=client)
        prepared_files.append(prepared)
        resolved_models.append(resolved_model)

    assert len(resolved_models) >= 1
    assert len(prepared_files) >= 1
    first_model = resolved_models[0]
    return first_model, client, prepared_files
def _llava_prep(file, |
llava_model, |
image_model='llava-v1.6-vicuna-13b', |
client=None): |
prefix = '' |
if llava_model.startswith('http://'): |
prefix = 'http://' |
if llava_model.startswith('https://'): |
prefix = 'https://' |
llava_model = llava_model[len(prefix):] |
llava_model_split = llava_model.split(':') |
assert len(llava_model_split) >= 2 |
if len(llava_model_split) >= 2: |
pass |
if len(llava_model_split) >= 3: |
image_model = llava_model_split[2] |
llava_model = ':'.join(llava_model_split[:2]) |
llava_model = prefix + llava_model |
if client is None: |
from gradio_utils.grclient import GradioClient |
client = GradioClient(llava_model, check_hash=False, serialize=is_gradio_version4) |
client.setup() |
if not is_gradio_version4 and file and os.path.isfile(file): |
file = img_to_base64(file) |
assert image_model, "No image model specified" |
if isinstance(file, np.ndarray): |
from PIL import Image |
im = Image.fromarray(file) |
file = "%s.jpeg" % str(uuid.uuid4()) |
im.save(file) |
return image_model, client, file |
def get_prompt_with_texts(texts, prompt, max_new_tokens, min_max_new_tokens, tokenizer):
    """
    Build the "reduce" prompt that asks the model to merge several per-image answers into one.

    Each text is labelled ``Answer #<index>``; as many leading answers as get_docs_tokens allows
    for the token budget are kept, wrapped in triple quotes, and followed by the question.
    Finally every ``image``/``Image`` in the result is replaced by ``document``/``Document``.

    :param texts: Per-image answers.
    :param prompt: The user's question.
    :param max_new_tokens: Generation budget; swapped for ``min_max_new_tokens`` when question,
        answers and this budget together reach the model's context length.
    :param min_max_new_tokens: Fallback generation budget.
    :param tokenizer: Object providing ``encode`` and, optionally, ``model_max_length``.
    :return: The combined prompt string.
    :raises RuntimeError: if ``tokenizer`` is None.
    """
    if tokenizer is None:
        raise RuntimeError("Not setup for multi-image without tokenizer")

    context_limit = (tokenizer.model_max_length
                     if hasattr(tokenizer, 'model_max_length')
                     else llava16_model_max_length)

    question = '\n\nReduce the above information into single correct answer to the following question: ' + prompt
    question_tokens = len(tokenizer.encode(question))

    answers = []
    for idx, text in enumerate(texts):
        answers.append('Answer #%s:\n\n%s' % (idx, text))
    answers_tokens = len(tokenizer.encode(docs_joiner_default.join(answers)))

    generation_budget = max_new_tokens
    if question_tokens + answers_tokens + generation_budget >= context_limit:
        generation_budget = min_max_new_tokens

    max_input_tokens = context_limit - generation_budget - llava16_image_fudge
    top_k_docs, _one_doc_size, _num_doc_tokens = \
        get_docs_tokens(tokenizer, text_context_list=answers, max_input_tokens=max_input_tokens)

    kept_answers = docs_joiner_default.join(answers[:top_k_docs])
    combined = '\n"""\n' + kept_answers + '\n"""\n' + question
    return combined.replace('image', 'document').replace('Image', 'Document')
def get_llava_response(file=None,
                       llava_model=None,
                       prompt=None,
                       chat_conversation=[],
                       allow_prompt_auto=False,
                       image_model='llava-v1.6-vicuna-13b', temperature=0.2,
                       top_p=0.7, max_new_tokens=512,
                       min_max_new_tokens=512,
                       tokenizer=None,
                       image_process_mode="Default",
                       include_image=False,
                       client=None,
                       max_time=None,
                       force_stream=True,
                       verbose=False,
                       ):
    """
    Ask the LLaVA server about one or more images and return the final answer.

    With ``force_stream`` (or whenever more than one file is given) the work is delegated to
    get_llava_stream and the last streamed text is returned together with the prompt exactly as
    passed in. Otherwise a blocking ``client.predict`` is issued per file and the prompt returned
    is the processed one (after fix_llava_prompt and get_limited_text).

    :param file: A path string, a list of paths, or anything else (treated as "no image").
    :param llava_model: Server spec, see _llava_prep.
    :param prompt: User prompt.
    :param chat_conversation: Prior conversation forwarded to the server (not mutated here).
    :param allow_prompt_auto: Whether fix_llava_prompt may substitute the automatic prompt.
    :param image_model: Image model name; only its basename is used.
    :param temperature: Forwarded to the server.
    :param top_p: Forwarded to the server.
    :param max_new_tokens: Generation budget, capped at 1024.
    :param min_max_new_tokens: Reduced budget used per file when more than 4 files are given.
    :param tokenizer: Optional tokenizer used for length limits.
    :param image_process_mode: Forwarded to the server.
    :param include_image: Forwarded to the server.
    :param client: Optional existing client.
    :param max_time: Forwarded to get_llava_stream in streaming mode.
    :param force_stream: Use the streaming implementation.
    :param verbose: Forwarded to get_llava_stream in streaming mode.
    :return: Tuple (response text, prompt).
    """
    max_new_tokens = min(max_new_tokens, 1024)

    # snapshot of the call arguments (with the capped token budget) for the streaming path;
    # taken before force_stream is widened below
    stream_kwargs = dict(file=file,
                         llava_model=llava_model,
                         prompt=prompt,
                         chat_conversation=chat_conversation,
                         allow_prompt_auto=allow_prompt_auto,
                         image_model=image_model,
                         temperature=temperature,
                         top_p=top_p,
                         max_new_tokens=max_new_tokens,
                         min_max_new_tokens=min_max_new_tokens,
                         tokenizer=tokenizer,
                         image_process_mode=image_process_mode,
                         include_image=include_image,
                         client=client,
                         max_time=max_time,
                         force_stream=force_stream,
                         verbose=verbose)

    # several files always go through the streaming implementation
    several_files = isinstance(file, list) and len(file) > 1
    force_stream = force_stream | several_files

    if isinstance(file, str):
        file_list = [file]
    elif isinstance(file, list) and len(file) > 0:
        file_list = file
    else:
        file_list = [None]

    if force_stream:
        text = ''
        for text in get_llava_stream(**stream_kwargs):
            pass
        return text, prompt

    image_model = os.path.basename(image_model)
    prompt = fix_llava_prompt(file_list, prompt, allow_prompt_auto=allow_prompt_auto)

    if len(file_list) <= 4:
        max_new_tokens1 = max_new_tokens
    else:
        max_new_tokens1 = min(max_new_tokens, min_max_new_tokens)
    model_max_length = tokenizer.model_max_length if tokenizer else llava16_model_max_length
    has_image = len(file_list) >= 1 and file_list[0] is not None
    image_tokens = llava16_image_tokens if has_image else 0
    hard_limit_tokens = model_max_length - max_new_tokens1 - llava16_image_fudge - image_tokens
    prompt = get_limited_text(hard_limit_tokens, prompt, tokenizer, verbose=False)

    image_model, client, file_list = llava_prep(file_list, llava_model,
                                                image_model=image_model,
                                                client=client)

    single_file = len(file_list) == 1
    reses = [client.predict(prompt,
                            chat_conversation if single_file else [],
                            one_file,
                            image_process_mode,
                            include_image,
                            image_model,
                            temperature,
                            top_p,
                            max_new_tokens1,
                            api_name='/textbox_api_submit')
             for one_file in file_list]

    if len(reses) <= 1:
        return reses[0], prompt

    # NOTE(review): server_error_msg is not defined or imported in the visible part of this
    # file -- confirm it exists at module level. This branch looks unreachable today because
    # multiple files force the streaming path above.
    reses = [x for x in reses if server_error_msg not in x]
    prompt_with_texts = get_prompt_with_texts(reses, prompt, max_new_tokens, min_max_new_tokens, tokenizer)
    res = client.predict(prompt_with_texts,
                         chat_conversation,
                         None,
                         image_process_mode,
                         include_image,
                         image_model,
                         temperature,
                         top_p,
                         max_new_tokens,
                         api_name='/textbox_api_submit')
    return res, prompt
def get_llava_stream(file, llava_model,
                     prompt=None,
                     chat_conversation=[],
                     allow_prompt_auto=False,
                     image_model='llava-v1.6-vicuna-13b', temperature=0.2,
                     top_p=0.7, max_new_tokens=512,
                     min_max_new_tokens=512,
                     tokenizer=None,
                     image_process_mode="Default",
                     include_image=False,
                     client=None,
                     verbose_level=0,
                     max_time=None,
                     force_stream=True,
                     verbose=False,
                     ):
    """
    Generator that submits one job per file to the LLaVA server and streams the answer.

    With a single file, every new non-empty output of the job is yielded as it arrives.
    With several files nothing is yielded while the per-file jobs run; afterwards their texts
    are merged via get_prompt_with_texts and this function is invoked again without a file,
    yielding that second pass's stream.

    :param file: A path string, a list of paths, or anything else (treated as "no image").
    :param llava_model: Server spec, see _llava_prep.
    :param prompt: User prompt.
    :param chat_conversation: Prior conversation forwarded to the server (not mutated here).
    :param allow_prompt_auto: Whether fix_llava_prompt may substitute the automatic prompt.
    :param image_model: Image model name; only its basename is used.
    :param temperature: Forwarded to the server.
    :param top_p: Forwarded to the server.
    :param max_new_tokens: Generation budget, capped at 1024.
    :param min_max_new_tokens: Reduced budget used per file when more than 4 files are given.
    :param tokenizer: Optional tokenizer used for length limits (required for several files).
    :param image_process_mode: Forwarded to the server.
    :param include_image: Forwarded to the server.
    :param client: Optional existing client.
    :param verbose_level: 0 silent, 1 prints output counters, 2 also prints streamed content.
    :param max_time: Optional number of seconds after which polling stops.
    :param force_stream: Only passed on to the recursive call.
    :param verbose: Only passed on to the recursive call.
    :return: The final text (available as the generator's StopIteration value).
    """
    max_new_tokens = min(max_new_tokens, 1024)

    # normalise ``file`` into a non-empty list; [None] stands for "no image"
    if isinstance(file, str):
        file_list = [file]
    elif isinstance(file, list) and len(file) > 0:
        file_list = file
    else:
        file_list = [None]

    image_model = os.path.basename(image_model)
    prompt = fix_llava_prompt(file_list, prompt, allow_prompt_auto=allow_prompt_auto)

    if len(file_list) <= 4:
        max_new_tokens1 = max_new_tokens
    else:
        max_new_tokens1 = min(max_new_tokens, min_max_new_tokens)
    model_max_length = tokenizer.model_max_length if tokenizer else llava16_model_max_length
    has_image = len(file_list) >= 1 and file_list[0] is not None
    image_tokens = llava16_image_tokens if has_image else 0
    hard_limit_tokens = model_max_length - max_new_tokens1 - llava16_image_fudge - image_tokens
    prompt = get_limited_text(hard_limit_tokens, prompt, tokenizer)

    image_model, client, file_list = llava_prep(file_list, llava_model,
                                                image_model=image_model,
                                                client=client)

    jobs = [client.submit(prompt,
                          chat_conversation,
                          one_file,
                          image_process_mode,
                          include_image,
                          image_model,
                          temperature,
                          top_p,
                          max_new_tokens1,
                          api_name='/textbox_api_submit')
            for one_file in file_list]

    t0 = time.time()
    n_jobs = len(jobs)
    job_outputs_nums = [0] * n_jobs  # how many outputs of each job were already consumed
    texts = [''] * n_jobs  # latest non-empty output per job
    reses = [''] * n_jobs  # latest output per job, possibly empty
    done_all = False

    # poll all jobs until every one is done or the time budget is exhausted
    while True:
        for ji, job in enumerate(jobs):
            if verbose_level == 2:
                print("Inside: %s" % llava_model, time.time() - t0, flush=True)
            e = check_job(job, timeout=0, raise_exception=False)
            if e is not None:
                # job reported an exception; skip it this round
                continue
            if max_time is not None and time.time() - t0 > max_time:
                done_all = True
                break
            seen = job_outputs_nums[ji]
            fresh_outputs = job.outputs().copy()[seen:]
            for num, output in enumerate(fresh_outputs):
                reses[ji] = output
                if verbose_level == 2:
                    print('Stream %d: %s' % (num, reses[ji]), flush=True)
                elif verbose_level == 1:
                    print('Stream %d' % (seen + num), flush=True)
                if reses[ji]:
                    texts[ji] = reses[ji]
                    if n_jobs == 1:
                        yield texts[ji]
            job_outputs_nums[ji] += len(fresh_outputs)
            time.sleep(0.005)
        if done_all or all([job.done() for job in jobs]):
            break

    # final sweep: pick up outputs that arrived between the last poll and completion
    for ji, job in enumerate(jobs):
        e = check_job(job, timeout=0, raise_exception=False)
        if e is not None:
            continue
        seen = job_outputs_nums[ji]
        fresh_outputs = job.outputs().copy()[seen:]
        for num, output in enumerate(fresh_outputs):
            reses[ji] = output
            if verbose_level == 2:
                print('Final Stream %d: %s' % (num, reses[ji]), flush=True)
            elif verbose_level == 1:
                print('Final Stream %d' % (seen + num), flush=True)
            if reses[ji]:
                texts[ji] = reses[ji]
                if n_jobs == 1:
                    yield texts[ji]
        job_outputs_nums[ji] += len(fresh_outputs)
        if verbose_level == 1:
            print("total job_outputs_num=%d" % job_outputs_nums[ji], flush=True)

    if n_jobs <= 1:
        assert len(texts) == 1
        return texts[0]

    # Several jobs: drop failed answers, then reduce the rest to a single answer.
    # NOTE(review): server_error_msg is not defined or imported in the visible part of this
    # file -- confirm it exists at module level.
    ntexts_before = len(texts)
    texts = [x for x in texts if server_error_msg not in x]
    ntexts_after = len(texts)
    if ntexts_after != ntexts_before:
        print("texts: %s -> %s" % (ntexts_before, ntexts_after))
    prompt_with_texts = get_prompt_with_texts(texts, prompt, max_new_tokens, min_max_new_tokens, tokenizer)

    text = ''
    if n_jobs <= 4:
        max_new_tokens = min(max_new_tokens, min_max_new_tokens)
    for text in get_llava_stream(None,
                                 llava_model,
                                 prompt=prompt_with_texts,
                                 chat_conversation=chat_conversation,
                                 allow_prompt_auto=allow_prompt_auto,
                                 image_model=image_model,
                                 temperature=temperature,
                                 top_p=top_p,
                                 max_new_tokens=max_new_tokens,
                                 min_max_new_tokens=min_max_new_tokens,
                                 tokenizer=tokenizer,
                                 image_process_mode=image_process_mode,
                                 include_image=include_image,
                                 client=client,
                                 verbose_level=verbose_level,
                                 max_time=max_time,
                                 force_stream=force_stream,
                                 verbose=verbose,
                                 ):
        yield text
    return text
def get_image_model_dict(enable_image,
                         image_models,
                         image_gpu_ids,
                         ):
    """
    Load the requested image models and return them keyed by model name.

    :param enable_image: When falsy nothing is loaded and an empty dict is returned.
    :param image_models: Names of the models to load; only names also present in the
        valid_image*_models lists are considered.
    :param image_gpu_ids: GPU id per entry of ``image_models`` (same positions).
        None or empty means 'auto' for every model.
    :return: Dict mapping model name to ``dict(pipe=..., make_image=...)``.
    :raises ValueError: for a valid model name this function has no loader for.
    """
    if not enable_image:
        return {}

    # None and empty both fall back to 'auto' for every model
    if not image_gpu_ids:
        image_gpu_ids = ['auto'] * len(image_models)

    loaded = {}
    known_names = valid_imagegen_models + valid_imagechange_models + valid_imagestyle_models
    for image_model_name in known_names:
        if image_model_name not in image_models:
            continue
        position = image_models.index(image_model_name)

        # model modules are imported only for the models actually requested
        if image_model_name == 'sdxl_turbo':
            from src.vision.sdxl_turbo import get_pipe_make_image, make_image
        elif image_model_name == 'playv2':
            from src.vision.playv2 import get_pipe_make_image, make_image
        elif image_model_name == 'sdxl':
            from src.vision.stable_diffusion_xl import get_pipe_make_image, make_image
        elif image_model_name == 'sd3':
            from src.vision.stable_diffusion_xl import get_pipe_make_image, make_image
            sd3_kwargs = dict(base_model='stabilityai/stable-diffusion-3-medium-diffusers',
                              refiner_model=None)
            get_pipe_make_image = functools.partial(get_pipe_make_image, **sd3_kwargs)
            make_image = functools.partial(make_image, **sd3_kwargs)
        elif image_model_name == 'flux.1-dev':
            from src.vision.flux import get_pipe_make_image, make_image
        elif image_model_name == 'flux.1-schnell':
            from src.vision.flux import get_pipe_make_image_2 as get_pipe_make_image
            from src.vision.flux import make_image
        elif image_model_name == 'sdxl_change':
            from src.vision.sdxl_turbo import get_pipe_change_image as get_pipe_make_image, change_image
            make_image = change_image
        else:
            raise ValueError("Invalid image_model_name=%s" % image_model_name)

        pipe = get_pipe_make_image(gpu_id=image_gpu_ids[position])
        loaded[image_model_name] = dict(pipe=pipe, make_image=make_image)
    return loaded
def pdf_to_base64_pngs(pdf_path, quality=75, max_size=(1024, 1024), ext='png', pages=None):
    """
    Render pages of a PDF (e.g. a slide deck) to images and return them base64-encoded.

    Each page is rendered at 300 dpi and scaled down, keeping its aspect ratio, when it exceeds
    ``max_size``, so the images stay within Claude's size limits.

    :param pdf_path: Path to the PDF file.
    :param quality: ``quality`` value passed to the image writer.
    :param max_size: Maximum (width, height) of each returned image.
    :param ext: Image type to produce: 'png', 'jpeg' or 'jpg'.
    :param pages: Page indices to render (list, tuple or generator); None renders all pages.
    :return: List of base64 strings (no data-URI prefix), one per rendered page.
    :raises ValueError: if ``ext`` is not one of the supported values.
    :raises AssertionError: if ``pages`` is given but is not a list, tuple or generator.
    """
    # validate cheap arguments before loading heavy dependencies or rendering anything
    formats_by_ext = {'png': 'PNG', 'jpeg': 'JPEG', 'jpg': 'JPEG'}
    iformat = formats_by_ext.get(ext)
    if iformat is None:
        raise ValueError("No such ext=%s" % ext)
    if pages is not None:
        assert isinstance(pages, (list, tuple, types.GeneratorType))

    import io
    import fitz
    from PIL import Image

    base64_encoded_pngs = []
    doc = fitz.open(pdf_path)
    try:
        if pages is None:
            pages = range(doc.page_count)
        # one scratch directory for the whole call, removed on exit
        with tempfile.TemporaryDirectory() as scratch_dir:
            for page_num in pages:
                page = doc.load_page(page_num)
                # PDF user space is 72 units per inch; scale to 300 dpi
                pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
                output_path = os.path.join(scratch_dir, f"page_{page_num + 1}.{ext}")
                pix.save(output_path)

                # re-encode while the file is open, then close it so the directory can be removed
                with Image.open(output_path) as image:
                    if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
                        image.thumbnail(max_size, Image.Resampling.LANCZOS)
                    image_data = io.BytesIO()
                    image.save(image_data, format=iformat, optimize=True, quality=quality)

                base64_encoded_pngs.append(base64.b64encode(image_data.getvalue()).decode('utf-8'))
    finally:
        doc.close()
    return base64_encoded_pngs