Spaces:
Build error
Build error
File size: 5,215 Bytes
2283b14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Ultralytics YOLO 🚀, GPL-3.0 license
import os
import shutil
import psutil
import requests
from IPython import display # to display images and clear console output
from ultralytics.hub.auth import Auth
from ultralytics.hub.session import HubTrainingSession
from ultralytics.hub.utils import PREFIX, split_key
from ultralytics.yolo.utils import LOGGER, emojis, is_colab
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics.yolo.v8.detect import DetectionTrainer
def checks(verbose=True):
    """Prepare the notebook environment and log a one-line setup summary.

    On Colab, the default ``sample_data`` directory is deleted. When ``verbose`` is True, the
    notebook output is cleared and the CPU count, total RAM and disk usage of ``/`` are appended
    to the final "Setup complete" log line.

    Args:
        verbose (bool): include the system-resource summary in the log line.
    """
    if is_colab():
        shutil.rmtree('sample_data', ignore_errors=True)  # remove colab /sample_data directory

    if verbose:
        # System info
        gib = 1 << 30  # bytes per GiB
        ram = psutil.virtual_memory().total
        total, _, free = shutil.disk_usage("/")  # 'used' is not needed; usage is total - free
        display.clear_output()
        s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
    else:
        s = ''

    select_device(newline=False)  # called for its side effects; the return value is unused here
    LOGGER.info(f'Setup complete ✅ {s}')
def start(key=''):
    """Start training a model with Ultralytics HUB.

    Usage: from src.ultralytics import start; start('API_KEY')

    Args:
        key (str): combined API-key / model-id string, parsed by ``split_key``. If the resulting
            credentials do not authenticate, the user is prompted interactively (up to 3 tries).

    Any exception raised while authenticating, creating the session or training is caught here
    and logged as a warning instead of being propagated to the caller.
    """

    def request_api_key(attempts=0):
        """Prompt for an API key until it authenticates and return the model id parsed from it.

        ``attempts`` is the number of failures already counted. Raises ConnectionError once the
        total number of failures reaches ``max_attempts``.
        """
        import getpass

        max_attempts = 3
        while True:
            tries = f"Attempt {attempts + 1} of {max_attempts}" if attempts > 0 else ""
            LOGGER.info(f"{PREFIX}Login. {tries}")
            input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
            # 'auth' is the Auth instance created in the enclosing try-block before this is called.
            auth.api_key, model_id = split_key(input_key)
            if auth.authenticate():
                return model_id
            attempts += 1
            LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
            if attempts >= max_attempts:
                raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))

    try:
        api_key, model_id = split_key(key)
        auth = Auth(api_key)  # attempts cookie login if no api key is present
        # A non-empty key that failed already counts as the first attempt.
        attempts = 1 if key else 0
        if not auth.get_state():
            if key:
                LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
            model_id = request_api_key(attempts)
        LOGGER.info(f"{PREFIX}Authenticated ✅")
        if not model_id:
            raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))

        session = HubTrainingSession(model_id=model_id, auth=auth)
        session.check_disk_space()

        # TODO: refactor, hardcoded for v8
        args = session.model.copy()
        # Drop HUB bookkeeping fields before passing the rest as trainer overrides.
        args.pop("id")
        args.pop("status")
        args.pop("weights")
        args["data"] = "coco128.yaml"
        args["model"] = "yolov8n.yaml"
        args["batch_size"] = 16  # NOTE(review): confirm the trainer config key is 'batch_size'
        args["imgsz"] = 64  # NOTE(review): unusually small image size; looks like a debug value

        trainer = DetectionTrainer(overrides=args)
        session.register_callbacks(trainer)
        trainer.hub_session = session
        trainer.train()
    except Exception as e:
        # Top-level boundary for an interactive entry point: report the failure, don't raise.
        LOGGER.warning(f"{PREFIX}{e}")
def reset_model(key=''):
    """Ask Ultralytics HUB to reset a trained model back to an untrained state.

    Args:
        key (str): combined API-key / model-id string, parsed by ``split_key``.

    The outcome is only logged: an info line on HTTP 200, otherwise a warning with the status
    code and reason. Nothing is returned.
    """
    api_key, model_id = split_key(key)
    payload = {"apiKey": api_key, "modelId": model_id}
    response = requests.post('https://api.ultralytics.com/model-reset', json=payload)

    if response.status_code != 200:
        LOGGER.warning(f"{PREFIX}model reset failure {response.status_code} {response.reason}")
        return

    LOGGER.info(f"{PREFIX}model reset successfully")
def export_model(key='', format='torchscript'):
    """Ask Ultralytics HUB to start exporting a model to a single format.

    Args:
        key (str): combined API-key / model-id string, parsed by ``split_key``.
        format (str): target export format; must be one of the supported formats listed below.

    Raises:
        AssertionError: if ``format`` is unsupported or the HUB request does not return
            HTTP 200. Raised explicitly rather than via ``assert`` so the checks still run
            under ``python -O``; the exception type is kept for backward compatibility.
    """
    formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs',
               'ultralytics_tflite', 'ultralytics_coreml')
    # Validate the format before parsing the key so an invalid request fails immediately.
    if format not in formats:
        raise AssertionError(f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}")

    api_key, model_id = split_key(key)
    r = requests.post('https://api.ultralytics.com/export',
                      json={
                          "apiKey": api_key,
                          "modelId": model_id,
                          "format": format})
    if r.status_code != 200:
        raise AssertionError(f"{PREFIX}{format} export failure {r.status_code} {r.reason}")
    LOGGER.info(f"{PREFIX}{format} export started ✅")
def get_export(key='', format='torchscript'):
    """Fetch the export record (including its download URL) for a model and format from HUB.

    Args:
        key (str): combined API-key / model-id string, parsed by ``split_key``.
        format (str): export format to look up; must be one of the supported formats below.

    Returns:
        The decoded JSON body of the HUB ``get-export`` response.

    Raises:
        AssertionError: if ``format`` is unsupported or the HUB request does not return
            HTTP 200. Raised explicitly rather than via ``assert`` so the checks still run
            under ``python -O``; the exception type is kept for backward compatibility.
    """
    formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs',
               'ultralytics_tflite', 'ultralytics_coreml')
    # Validate the format before parsing the key so an invalid request fails immediately.
    if format not in formats:
        raise AssertionError(f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}")

    api_key, model_id = split_key(key)
    r = requests.post('https://api.ultralytics.com/get-export',
                      json={
                          "apiKey": api_key,
                          "modelId": model_id,
                          "format": format})
    if r.status_code != 200:
        raise AssertionError(f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}")
    return r.json()
# Manual smoke test: python <this file> [HUB_KEY]
if __name__ == "__main__":
    # NOTE(review): never hard-code API keys here. The key previously committed on this line
    # should be treated as leaked and revoked. Pass it on the command line instead; with no
    # argument an empty key is passed and start() falls back to its login/prompt flow.
    import sys

    start(key=sys.argv[1] if len(sys.argv) > 1 else '')
|