|
import os |
|
import random |
|
import subprocess |
|
import sys |
|
from enum import Enum, unique |
|
|
|
from . import launcher |
|
from .api.app import run_api |
|
from .chat.chat_model import run_chat |
|
from .eval.evaluator import run_eval |
|
from .extras.logging import get_logger |
|
from .extras.misc import get_device_count |
|
from .train.tuner import export_model, run_exp |
|
from .webui.interface import run_web_demo, run_web_ui |
|
|
|
|
|
# Help text printed by `llamafactory-cli help`: one boxed line per sub-command,
# framed above and below by a 70-character horizontal rule.
USAGE = "\n".join(
    [
        "-" * 70,
        "| Usage: |",
        "| llamafactory-cli api -h: launch an OpenAI-style API server |",
        "| llamafactory-cli chat -h: launch a chat interface in CLI |",
        "| llamafactory-cli eval -h: evaluate models |",
        "| llamafactory-cli export -h: merge LoRA adapters and export model |",
        "| llamafactory-cli train -h: train models |",
        "| llamafactory-cli webchat -h: launch a chat interface in Web UI |",
        "| llamafactory-cli webui: launch LlamaBoard |",
        "| llamafactory-cli version: show version info |",
        "-" * 70,
    ]
)
|
|
|
VERSION = "0.7.2.dev0"

# Banner printed by `llamafactory-cli version`: a 58-character-wide box showing
# the package version and the project URL. The version line is left-justified to
# 57 columns so its closing "|" lines up with the 58-character rules; a version
# string too long to fit simply pushes the border right (ljust never truncates).
WELCOME = "\n".join(
    [
        "-" * 58,
        "| Welcome to LLaMA Factory, version {}".format(VERSION).ljust(57) + "|",
        "|" + " " * 56 + "|",
        "| Project page: https://github.com/hiyouga/LLaMA-Factory |",
        "-" * 58,
    ]
)
|
|
|
# Module-level logger from the project's logging helper; main() uses it to report
# the address and port of a distributed training launch.
logger = get_logger(__name__)
|
|
|
|
|
@unique
class Command(str, Enum):
    """Sub-commands accepted as the first argument of ``llamafactory-cli``.

    Mixing in ``str`` lets the raw string taken from ``sys.argv`` compare equal
    to a member (e.g. ``"train" == Command.TRAIN``), which ``main()`` relies on
    for dispatch. ``@unique`` makes enum creation fail if two members share a value.
    """

    API = "api"  # run_api()
    CHAT = "chat"  # run_chat()
    EVAL = "eval"  # run_eval()
    EXPORT = "export"  # export_model()
    TRAIN = "train"  # run_exp(), or a torchrun launch when several devices are present
    WEBDEMO = "webchat"  # run_web_demo(); note the CLI spelling differs from the member name
    WEBUI = "webui"  # run_web_ui()
    VER = "version"  # prints the WELCOME banner
    HELP = "help"  # prints the USAGE text
|
|
|
|
|
def _launch_distributed_training(device_count):
    """Re-launch training under ``torchrun`` and return the child's exit code.

    Topology and rendezvous settings are read from the environment, falling back
    to single-node defaults: ``NNODES`` ("1"), ``RANK`` ("0"), ``NPROC_PER_NODE``
    (``device_count``), ``MASTER_ADDR`` ("127.0.0.1") and ``MASTER_PORT`` (a random
    port in 20001-29999 inclusive). The remaining command-line arguments (the
    sub-command has already been popped from ``sys.argv``) are forwarded to the
    ``launcher`` module run by each worker.

    Args:
        device_count: number of local devices, used as the default
            ``--nproc_per_node``.

    Returns:
        The ``returncode`` of the ``torchrun`` process.
    """
    nnodes = os.environ.get("NNODES", "1")
    node_rank = os.environ.get("RANK", "0")
    nproc_per_node = os.environ.get("NPROC_PER_NODE", str(device_count))
    master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
    master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
    logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
    process = subprocess.run(
        [
            "torchrun",
            "--nnodes",
            nnodes,
            "--node_rank",
            node_rank,
            "--nproc_per_node",
            nproc_per_node,
            "--master_addr",
            master_addr,
            "--master_port",
            master_port,
            launcher.__file__,
            *sys.argv[1:],
        ]
    )
    return process.returncode


def main():
    """Entry point of ``llamafactory-cli``: dispatch on the first CLI argument.

    The sub-command is removed from ``sys.argv`` so the selected runner only sees
    its own arguments. When no sub-command is given, the usage text is printed.

    Raises:
        NotImplementedError: if the sub-command is not a known ``Command``.
        SystemExit: if a distributed ``torchrun`` training run exits with a
            non-zero status (the status is propagated).
    """
    # Guard the bare `llamafactory-cli` invocation, which used to crash with IndexError.
    command = sys.argv.pop(1) if len(sys.argv) > 1 else Command.HELP

    if command == Command.API:
        run_api()
    elif command == Command.CHAT:
        run_chat()
    elif command == Command.EVAL:
        run_eval()
    elif command == Command.EXPORT:
        export_model()
    elif command == Command.TRAIN:
        device_count = get_device_count()
        if device_count > 1:
            returncode = _launch_distributed_training(device_count)
            if returncode != 0:
                # Surface the failure to shells / job schedulers instead of exiting 0.
                sys.exit(returncode)
        else:
            run_exp()
    elif command == Command.WEBDEMO:
        run_web_demo()
    elif command == Command.WEBUI:
        run_web_ui()
    elif command == Command.VER:
        print(WELCOME)
    elif command == Command.HELP:
        print(USAGE)
    else:
        raise NotImplementedError("Unknown command: {}".format(command))
|
|