Spaces:
Sleeping
Sleeping
import base64 | |
import io | |
import jsonlines | |
import torch | |
from PIL import Image | |
import wandb | |
def get_wandb_artifact( | |
artifact_name: str, | |
artifact_type: str, | |
get_metadata: bool = False, | |
) -> str: | |
if wandb.run: | |
artifact = wandb.use_artifact(artifact_name, type=artifact_type) | |
artifact_dir = artifact.download() | |
else: | |
api = wandb.Api() | |
artifact = api.artifact(artifact_name) | |
artifact_dir = artifact.download() | |
if get_metadata: | |
return artifact_dir, artifact.metadata | |
return artifact_dir | |
def get_torch_backend(): | |
if torch.cuda.is_available(): | |
if torch.backends.cuda.is_built(): | |
return "cuda" | |
if torch.backends.mps.is_available(): | |
if torch.backends.mps.is_built(): | |
return "mps" | |
return "cpu" | |
return "cpu" | |
def base64_encode_image(image: Image.Image, mimetype: str) -> str: | |
image.load() | |
if image.mode not in ("RGB", "RGBA"): | |
image = image.convert("RGB") | |
byte_arr = io.BytesIO() | |
image.save(byte_arr, format="PNG") | |
encoded_string = base64.b64encode(byte_arr.getvalue()).decode("utf-8") | |
encoded_string = f"data:{mimetype};base64,{encoded_string}" | |
return str(encoded_string) | |
def read_jsonl_file(file_path: str) -> list[dict[str, any]]: | |
with jsonlines.open(file_path) as reader: | |
for obj in reader: | |
return obj | |