Spaces:
Running
Running
File size: 1,397 Bytes
8d64162 7934a8e 70d9de4 8d64162 0d77bb1 d197e7f 96bca50 0d77bb1 96bca50 d197e7f 96bca50 d197e7f 70d9de4 af715dd 70d9de4 8d64162 7934a8e 8d64162 7934a8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import base64
import io
from typing import Any, Union

import jsonlines
import torch
import wandb
from PIL import Image
def get_wandb_artifact(
    artifact_name: str,
    artifact_type: str,
    get_metadata: bool = False,
) -> Union[str, tuple[str, dict[str, Any]]]:
    """Download a W&B artifact and return its local directory.

    If a run is active (``wandb.run`` is truthy), the artifact is fetched
    with ``wandb.use_artifact``. Otherwise it is fetched through the public
    ``wandb.Api``.

    Args:
        artifact_name: Artifact reference passed straight to W&B. It is
            presumably of the form ``"entity/project/name:alias"``; this is
            not validated here.
        artifact_type: Artifact type. It is only forwarded on the active-run
            path; the ``wandb.Api`` path ignores it.
        get_metadata: If true, also return ``artifact.metadata``.

    Returns:
        The local download directory, or ``(directory, metadata)`` when
        *get_metadata* is true.
    """
    if wandb.run:
        # Inside a run, W&B documents use_artifact as also linking the
        # artifact to the run as an input.
        artifact = wandb.use_artifact(artifact_name, type=artifact_type)
    else:
        artifact = wandb.Api().artifact(artifact_name)

    # Both lookup paths yield an artifact object; download it in one place.
    artifact_dir = artifact.download()
    if get_metadata:
        return artifact_dir, artifact.metadata
    return artifact_dir
def get_torch_backend():
    """Return the name of the preferred torch device backend.

    The preference order is ``"cuda"`` (available and compiled in), then
    ``"mps"`` (available and compiled in), then ``"cpu"`` as the fallback.
    """
    # `and` short-circuits, so is_built() is only queried when the backend
    # reports itself available.
    cuda_usable = torch.cuda.is_available() and torch.backends.cuda.is_built()
    if cuda_usable:
        return "cuda"

    mps_usable = torch.backends.mps.is_available() and torch.backends.mps.is_built()
    return "mps" if mps_usable else "cpu"
def base64_encode_image(image: Image.Image, mimetype: str) -> str:
    """Encode a PIL image as a base64 ``data:`` URI.

    The image is written in the format named by *mimetype*, so the media
    type declared in the URI matches the encoded bytes.

    - ``image/png`` and ``image/jpeg`` (plus the common alias ``image/jpg``)
      are written natively.
    - Any other value falls back to PNG, and the URI is labelled
      ``image/png``.

    Args:
        image: Image to encode. ``load()`` is called on it, but any mode
            conversion produces a new image rather than modifying *image*.
        mimetype: Requested media type of the payload (case-insensitive).

    Returns:
        A string of the form ``"data:<mimetype>;base64,<payload>"``.
    """
    # Maps a requested media type to (Pillow writer format, media type
    # to declare in the URI).
    writers = {
        "image/png": ("PNG", "image/png"),
        "image/jpeg": ("JPEG", "image/jpeg"),
        "image/jpg": ("JPEG", "image/jpeg"),
    }
    requested = (mimetype or "").strip().lower()
    pil_format, uri_mimetype = writers.get(requested, ("PNG", "image/png"))

    image.load()
    # PNG can store RGB or RGBA. JPEG has no alpha channel, so RGBA must be
    # flattened to RGB (convert("RGB") discards the alpha band).
    keep_modes = ("RGB",) if pil_format == "JPEG" else ("RGB", "RGBA")
    if image.mode not in keep_modes:
        image = image.convert("RGB")

    buffer = io.BytesIO()
    image.save(buffer, format=pil_format)
    payload = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:{uri_mimetype};base64,{payload}"
def read_jsonl_file(file_path: str) -> list[dict[str, Any]]:
    """Read every record from a JSON Lines file.

    Args:
        file_path: Path to a JSON Lines file (one JSON value per line).

    Returns:
        All decoded records in file order, or an empty list for an empty
        file. The annotation assumes each line is a JSON object; this is not
        validated here.
    """
    with jsonlines.open(file_path) as reader:
        # Materialize all records before the reader is closed on exit.
        return list(reader)
|