File size: 964 Bytes
bafb458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi, login
import os
import yaml
from pathlib import Path
from loguru import logger
from PIL import Image
from .get_best_mp4 import get_mp4_paths

# Hugging Face Hub repository ids and repo types. None of these are
# referenced in this file; presumably they are imported by other modules.
# The Space and the model currently share the same repo id.
SPACE_REPO: str = "c-gohlke/litrl"
SPACE_REPO_TYPE: str = "space"
MODEL_REPO: str = "c-gohlke/litrl"
MODEL_REPO_TYPE: str = "model"

# Not referenced in this file; presumably a directory-nesting depth used when
# locating per-environment result files — TODO confirm against its users.
ENV_RESULTS_FILE_DEPTH: int = 3


class HugingFaceClient:
    def __init__(self) -> None:
        login(  # type: ignore[no-untyped-call]
            token=os.environ.get("HUGGINGFACE_TOKEN"),
            add_to_git_credential=True,
            new_session=False,
        )
        self.hf_api = HfApi()
        self.mp4_paths = get_mp4_paths()
        

    def api_predict(self, env_id: str)-> bytes|None:
        if env_id not in self.mp4_paths:
            logger.error(f"Environment {env_id} not found in {self.mp4_paths}")
            return None
        with open(self.mp4_paths[env_id], "rb") as f:
            return f.read()