File size: 3,774 Bytes
7d79917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import contextvars
import time
import urllib.parse
import webbrowser
from pathlib import Path

from gradio_client import Client
from httpx import ReadTimeout
from huggingface_hub.errors import RepositoryNotFoundError

from trackio.deploy import deploy_as_space
from trackio.run import Run
from trackio.ui import demo
from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_PATH, block_except_in_notebook

# Package version, read from the version.txt file that sits next to this module.
# The encoding is explicit so the result does not depend on the platform locale.
__version__ = (
    Path(__file__).parent.joinpath("version.txt").read_text(encoding="utf-8").strip()
)


# State shared by init()/log()/finish(), held in ContextVars so each value is
# looked up for the current context.
# The Run created by the most recent init() call (None until init() runs).
current_run: contextvars.ContextVar[Run | None] = contextvars.ContextVar(
    "current_run", default=None
)
# Name of the project passed to the most recent init() call.
current_project: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_project", default=None
)
# Server init() connects to: either the URL of the locally launched dashboard
# or a Hugging Face Space id.
current_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_server", default=None
)

# Rebound by init() to the config of the newly created run.
config = {}


def init(
    project: str, name: str | None = None, space_id=None, config: dict | None = None
) -> Run:
    if not current_server.get():
        if space_id is None:
            _, url, _ = demo.launch(
                show_api=False, inline=False, quiet=True, prevent_thread_lock=True
            )
        else:
            url = space_id
        current_server.set(url)
    else:
        url = current_server.get()

    client = None
    try:
        client = Client(url, verbose=False)
    except RepositoryNotFoundError as e:
        # if we're passed a space_id, ignore the error here, otherwise re-raise it
        if space_id is None:
            raise e
    if client is None and space_id is not None:
        # try to create the space on demand
        deploy_as_space(space_id)
        # try to wait until the Client can initialize
        max_attempts = 30
        attempts = 0
        while client is None:
            try:
                client = Client(space_id, verbose=False)
            except ReadTimeout:
                print("Space is not yet ready. Waiting 5 seconds...")
                time.sleep(5)
            except ValueError as e:
                print(f"Space gave error {e}. Trying again in 5 seconds...")
                time.sleep(5)
            attempts += 1
            if attempts >= max_attempts:
                break

    if current_project.get() is None or current_project.get() != project:
        print(f"* Trackio project initialized: {project}")
        if space_id is None:
            print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
            print(
                f'\n* View dashboard by running in your terminal: trackio show --project "{project}"'
            )
            print(f'* or by running in Python: trackio.show(project="{project}")')
        else:
            print(
                f"* Trackio metrics logged to: https://huggingface.co/spaces/{space_id}"
            )
    current_project.set(project)

    run = Run(project=project, client=client, name=name, config=config)
    current_run.set(run)
    globals()["config"] = run.config
    return run


def log(metrics: dict) -> None:
    """Log a dict of metrics to the current run.

    Raises:
        RuntimeError: if trackio.init() has not been called in this context.
    """
    active_run = current_run.get()
    if active_run is None:
        raise RuntimeError("Call trackio.init() before log().")
    active_run.log(metrics)


def finish():
    """Finish the current run by calling its ``finish()`` method.

    Raises:
        RuntimeError: if trackio.init() has not been called in this context.
    """
    active_run = current_run.get()
    if active_run is None:
        raise RuntimeError("Call trackio.init() before finish().")
    active_run.finish()


def show(project: str | None = None):
    _, url, share_url = demo.launch(
        show_api=False,
        quiet=True,
        inline=False,
        prevent_thread_lock=True,
        favicon_path=TRACKIO_LOGO_PATH,
        allowed_paths=[TRACKIO_LOGO_PATH],
    )
    base_url = share_url + "/" if share_url else url
    dashboard_url = base_url + f"?project={project}" if project else base_url
    print(f"* Trackio UI launched at: {dashboard_url}")
    webbrowser.open(dashboard_url)
    block_except_in_notebook()