File size: 3,357 Bytes
4a07f1a
 
8db0b4f
 
8da94da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8db0b4f
 
8da94da
 
 
 
 
8db0b4f
 
 
 
 
 
4a07f1a
8db0b4f
8da94da
 
 
 
 
8db0b4f
 
 
 
4a07f1a
8db0b4f
 
 
 
 
 
4a07f1a
8db0b4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a07f1a
5a31216
8da94da
 
 
 
 
 
 
 
 
 
5a31216
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import os
import subprocess
from pathlib import Path
import time
import requests

def check_services(timeout=5):
    """Check if all required services are running.

    Each service is probed with a single HTTP GET, in the order listed
    below. The check stops at the first service that cannot be reached.

    Args:
        timeout: Seconds to wait for each service to answer before it is
            treated as not running. Without a timeout ``requests`` would
            wait indefinitely on a socket that accepts but never replies.

    Returns:
        A ``(ok, message)`` tuple: ``(True, "All services are running")``
        when every probe got a response, otherwise ``(False, "<name> is
        not running")`` for the first service that failed.
    """
    services = [
        ("Controller", "http://localhost:21001"),
        ("API Server", "http://localhost:8000"),
        ("Model Worker", "http://localhost:8080"),
    ]

    for service_name, url in services:
        try:
            # Any HTTP response counts as "running"; the status code is
            # deliberately not inspected.
            requests.get(url, timeout=timeout)
        except requests.exceptions.RequestException:
            # Covers refused connections as well as timeouts and other
            # transport-level failures.
            return False, f"{service_name} is not running"
        print(f"{service_name} is running")
    return True, "All services are running"

def check_training_status():
    """Return a one-line, human-readable training status.

    If a required service is down, the service-check message is returned
    instead. Otherwise progress is reported as the number of ``iter_*``
    entries found under ``/app/results``.
    """
    ok, service_message = check_services()
    if not ok:
        return service_message

    output_root = Path("/app/results")
    if output_root.exists():
        # Count matches lazily instead of materialising them in a list.
        iterations = sum(1 for _ in output_root.glob("iter_*"))
        return f"Completed {iterations} training iterations."

    return "Training hasn't started yet."

def start_training(model_path, instruct_count, max_iter):
    """Run ``run.sh`` in ``/app/qwen`` with the given settings and report the outcome.

    The settings are exported as the ``MODEL_PATH``, ``INSTRUCT_COUNT`` and
    ``MAX_ITER`` environment variables of this process, which the child
    script inherits. The call blocks until the script exits.

    Args:
        model_path: Value exported as ``MODEL_PATH``.
        instruct_count: Whole number exported as ``INSTRUCT_COUNT``.
        max_iter: Whole number exported as ``MAX_ITER``.

    Returns:
        A status string: the service-check message if a service is down, a
        validation message if a numeric setting is not a whole number, a
        success message, or an ``"Error during training: ..."`` message.
    """
    services_ok, message = check_services()
    if not services_ok:
        return message

    # A numeric UI field may deliver a float (5000.0) or None when cleared
    # (see the Gradio Number docs). Normalise to plain integer strings so the
    # script never receives "5000.0" or "None".
    numeric_env = {}
    for env_name, label, raw in (
        ("INSTRUCT_COUNT", "Instruction Count", instruct_count),
        ("MAX_ITER", "Max Iterations", max_iter),
    ):
        try:
            as_float = float(raw)
        except (TypeError, ValueError, OverflowError):
            return f"{label} must be a whole number, got {raw!r}"
        if not as_float.is_integer():  # also rejects NaN and infinities
            return f"{label} must be a whole number, got {raw!r}"
        numeric_env[env_name] = str(int(as_float))

    os.environ["MODEL_PATH"] = model_path
    os.environ.update(numeric_env)

    try:
        subprocess.run(["bash", "run.sh"], check=True, cwd="/app/qwen")
    except subprocess.CalledProcessError as e:
        return f"Error during training: {str(e)}"
    except OSError as e:
        # Raised when bash or the working directory is missing; report it
        # like any other training failure instead of raising into the UI.
        return f"Error during training: {str(e)}"
    return "Training completed successfully!"

# Create the interface
# Layout: one row with two columns. The left column holds the training
# settings and the start button; the right column holds a read-only status
# box and a refresh button. Components are placed in the order they are
# created inside each context manager, so statement order here is the layout.
with gr.Blocks() as iface:
    gr.Markdown("# Self-Lengthen Training Interface")
    
    with gr.Row():
        # Left column: inputs passed to start_training.
        with gr.Column():
            model_path = gr.Textbox(
                label="Model Path",
                value="/app/models/base_model",
                info="Path to the base model"
            )
            instruct_count = gr.Number(
                label="Instruction Count",
                value=5000,
                minimum=100,
                info="Number of instructions to generate"
            )
            max_iter = gr.Number(
                label="Max Iterations",
                value=3,
                minimum=1,
                info="Number of training iterations"
            )
            train_btn = gr.Button("Start Training")
        
        # Right column: status display shared by both buttons.
        with gr.Column():
            status_output = gr.Textbox(
                label="Status",
                value="Ready to start training...",
                interactive=False
            )
            refresh_btn = gr.Button("Refresh Status")
    
    # Clicking "Start Training" calls start_training with the three settings
    # and shows its returned message in the status box.
    train_btn.click(
        fn=start_training,
        inputs=[model_path, instruct_count, max_iter],
        outputs=status_output
    )
    
    # Clicking "Refresh Status" takes no inputs; it shows whatever
    # check_training_status returns in the status box.
    refresh_btn.click(
        fn=check_training_status,
        inputs=None,
        outputs=status_output
    )

if __name__ == "__main__":
    # Block startup until every backing service answers, polling every
    # five seconds and printing which service is still missing.
    print("Waiting for services to start...")
    services_ok, message = check_services()
    while not services_ok:
        print(message)
        time.sleep(5)
        services_ok, message = check_services()

    print("All services are running, starting web interface...")
    # Listen on all interfaces, port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)