File size: 5,215 Bytes
2283b14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Ultralytics YOLO πŸš€, GPL-3.0 license

import os
import shutil

import psutil
import requests
from IPython import display  # to display images and clear console output

from ultralytics.hub.auth import Auth
from ultralytics.hub.session import HubTrainingSession
from ultralytics.hub.utils import PREFIX, split_key
from ultralytics.yolo.utils import LOGGER, emojis, is_colab
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics.yolo.v8.detect import DetectionTrainer


def checks(verbose=True):
    if is_colab():
        shutil.rmtree('sample_data', ignore_errors=True)  # remove colab /sample_data directory

    if verbose:
        # System info
        gib = 1 << 30  # bytes per GiB
        ram = psutil.virtual_memory().total
        total, used, free = shutil.disk_usage("/")
        display.clear_output()
        s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
    else:
        s = ''

    select_device(newline=False)
    LOGGER.info(f'Setup complete βœ… {s}')


def start(key=''):
    # Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
    def request_api_key(attempts=0):
        """Prompt the user to input their API key"""
        import getpass

        max_attempts = 3
        tries = f"Attempt {str(attempts + 1)} of {max_attempts}" if attempts > 0 else ""
        LOGGER.info(f"{PREFIX}Login. {tries}")
        input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
        auth.api_key, model_id = split_key(input_key)
        if not auth.authenticate():
            attempts += 1
            LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
            if attempts < max_attempts:
                return request_api_key(attempts)
            raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
        else:
            return model_id

    try:
        api_key, model_id = split_key(key)
        auth = Auth(api_key)  # attempts cookie login if no api key is present
        attempts = 1 if len(key) else 0
        if not auth.get_state():
            if len(key):
                LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
            model_id = request_api_key(attempts)
        LOGGER.info(f"{PREFIX}Authenticated βœ…")
        if not model_id:
            raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))
        session = HubTrainingSession(model_id=model_id, auth=auth)
        session.check_disk_space()

        # TODO: refactor, hardcoded for v8
        args = session.model.copy()
        args.pop("id")
        args.pop("status")
        args.pop("weights")
        args["data"] = "coco128.yaml"
        args["model"] = "yolov8n.yaml"
        args["batch_size"] = 16
        args["imgsz"] = 64

        trainer = DetectionTrainer(overrides=args)
        session.register_callbacks(trainer)
        setattr(trainer, 'hub_session', session)
        trainer.train()
    except Exception as e:
        LOGGER.warning(f"{PREFIX}{e}")


def reset_model(key=''):
    # Reset a trained model to an untrained state
    api_key, model_id = split_key(key)
    r = requests.post('https://api.ultralytics.com/model-reset', json={"apiKey": api_key, "modelId": model_id})

    if r.status_code == 200:
        LOGGER.info(f"{PREFIX}model reset successfully")
        return
    LOGGER.warning(f"{PREFIX}model reset failure {r.status_code} {r.reason}")


def export_model(key='', format='torchscript'):
    # Export a model to all formats
    api_key, model_id = split_key(key)
    formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs',
               'ultralytics_tflite', 'ultralytics_coreml')
    assert format in formats, f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}"

    r = requests.post('https://api.ultralytics.com/export',
                      json={
                          "apiKey": api_key,
                          "modelId": model_id,
                          "format": format})
    assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
    LOGGER.info(f"{PREFIX}{format} export started βœ…")


def get_export(key='', format='torchscript'):
    # Get an exported model dictionary with download URL
    api_key, model_id = split_key(key)
    formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs',
               'ultralytics_tflite', 'ultralytics_coreml')
    assert format in formats, f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}"

    r = requests.post('https://api.ultralytics.com/get-export',
                      json={
                          "apiKey": api_key,
                          "modelId": model_id,
                          "format": format})
    assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
    return r.json()


# temp. For checking
if __name__ == "__main__":
    start(key="b3fba421be84a20dbe68644e14436d1cce1b0a0aaa_HeMfHgvHsseMPhdq7Ylz")