File size: 6,490 Bytes
39fc977
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc5d8f
39fc977
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import base64
import time
import shutil
import logging
import uuid
import zipfile
from flask import Flask, request, render_template, send_file, jsonify
from flask_socketio import SocketIO
from huggingface_hub import HfApi, hf_hub_download
from flask_apscheduler import APScheduler
import subprocess

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

app = Flask(__name__)
# Read the secret key from the environment so a real secret can be supplied at
# deploy time instead of being committed to source. The historical literal is
# kept as the fallback so existing setups keep starting unchanged, but it
# should not be relied on outside local development.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'secret!')
socketio = SocketIO(app)
scheduler = APScheduler()
scheduler.init_app(app)
scheduler.start()

# Directory to store temporary files (per-task working directories and the
# resulting zip archives)
TEMP_DIR = '/tmp/piper_onnx'
os.makedirs(TEMP_DIR, exist_ok=True)

# Dictionary to store task information, keyed by task id. Each value is a dict
# with the keys 'status', 'log' and 'download_url'.
tasks = {}

def cleanup_old_files(directory=None, max_age_seconds=30 * 60):
    """Delete regular files in a directory that have not been modified recently.

    Only regular files directly inside the directory are considered;
    subdirectories are left untouched. A missing directory, or a file that
    disappears or cannot be removed while the sweep is running, is skipped so
    that one failure does not abort the rest of the sweep.

    Args:
        directory: Directory to sweep. Defaults to ``TEMP_DIR`` when ``None``.
        max_age_seconds: Files whose modification time lies more than this
            many seconds in the past are removed. Defaults to 30 minutes.
    """
    if directory is None:
        directory = TEMP_DIR

    cutoff = time.time() - max_age_seconds
    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        # Nothing to clean if the directory itself is gone.
        return

    for filename in entries:
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) and os.path.getmtime(file_path) < cutoff:
                os.remove(file_path)
        except OSError as exc:
            # The file may have been removed concurrently or be unremovable;
            # keep going so the remaining entries are still processed.
            logging.getLogger(__name__).warning(
                "Could not clean up %s: %s", file_path, exc
            )

# Run the stale-file sweep on a five-minute interval.
scheduler.add_job(
    func=cleanup_old_files,
    trigger="interval",
    minutes=5,
    id='cleanup_job',
)

def _is_safe_model_name(name):
    """Return True if *name* is usable as a single file-name component."""
    if not name or name in ('.', '..'):
        return False
    return not any(ch in name for ch in ('/', '\\', '\0'))


@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the submission form (GET) or queue a conversion task (POST).

    On POST the form fields ``repo_id``, ``token`` and ``model_name`` are
    read, a new task record is stored in ``tasks`` under a fresh UUID, and
    ``process_model`` is handed to the scheduler with that task id.

    Returns:
        The rendered ``index.html`` template for GET; for POST a JSON object
        ``{'task_id': ...}``, or a JSON error with status 400 when
        ``model_name`` is not a plain file-name component.
    """
    if request.method != 'POST':
        return render_template('index.html')

    repo_id = request.form['repo_id']
    token = request.form['token']
    model_name = request.form['model_name']

    # model_name is interpolated into file and directory names under TEMP_DIR
    # by the processing steps, so refuse anything that could point elsewhere
    # (path separators, '.', '..', empty).
    if not _is_safe_model_name(model_name):
        return jsonify({'error': 'Invalid model name'}), 400

    task_id = str(uuid.uuid4())
    tasks[task_id] = {
        'status': 'processing',
        'log': [],
        'download_url': None,
    }

    # No trigger is given, so the scheduler is asked to run the job once.
    scheduler.add_job(
        func=process_model,
        args=[task_id, repo_id, token, model_name],
        id=task_id,
    )

    return jsonify({'task_id': task_id})

@app.route('/status/<task_id>')
def task_status(task_id):
    """Return the stored record for *task_id* as JSON, or a 404 JSON error."""
    record = tasks.get(task_id)
    if not record:
        return jsonify({'error': 'Task not found'}), 404
    return jsonify(record)

@app.route('/download/<task_id>/<filename>')
def download_file(task_id, filename):
    """Send the archive produced by a completed task as an attachment.

    Args:
        task_id: Identifier of the task whose archive is requested.
        filename: Trailing URL segment; it is not used to locate the file,
            which is always taken from the task record.

    Returns:
        The archive as a file download, or a JSON error with status 404 when
        the task is unknown, not completed, or its archive no longer exists
        on disk.
    """
    not_found = (jsonify({'error': 'File not found or task not completed'}), 404)

    task = tasks.get(task_id)
    if not task or task['status'] != 'completed':
        return not_found

    archive_path = task['download_url']
    # The periodic cleanup job deletes old files from the temp directory, so a
    # completed task may outlive its archive; answer 404 instead of letting
    # send_file raise and produce a 500.
    if not archive_path or not os.path.isfile(archive_path):
        return not_found

    return send_file(archive_path, as_attachment=True)

def process_model(task_id, repo_id, token, model_name):
    """Run the full download -> ONNX conversion -> zip pipeline for one task.

    Progress messages are recorded through ``update_task``. On success the
    task's ``download_url`` is set to the archive path and its status to
    ``'completed'``; on any failure the status is set to ``'error'`` and the
    error message is logged to the task. Exceptions are not propagated.

    Args:
        task_id: Key of the task record in ``tasks``.
        repo_id: Hugging Face repository to download the checkpoint from.
        token: Access token passed to the Hugging Face calls.
        model_name: Base name used for the working directory and outputs.
    """
    unique_dir = os.path.join(TEMP_DIR, f"{task_id}_{model_name}")
    workdir_created = False
    try:
        update_task(task_id, "Starting model processing...")

        # model_name becomes part of paths under TEMP_DIR; refuse anything
        # that is not a single path component so neither the outputs nor the
        # cleanup below can touch a location outside TEMP_DIR.
        if model_name in ('', '.', '..') or os.path.basename(model_name) != model_name:
            raise ValueError(f"Invalid model name: {model_name!r}")

        os.makedirs(unique_dir, exist_ok=True)
        workdir_created = True
        update_task(task_id, f"Created unique directory: {unique_dir}")

        download_model(task_id, repo_id, token, unique_dir)
        convert_to_onnx(task_id, model_name, unique_dir)
        compressed_file = compress_files(task_id, model_name, unique_dir)

        download_url = f"/download/{task_id}/{os.path.basename(compressed_file)}"
        # Publish the archive path before flipping the status, so a request
        # that observes 'completed' always finds a usable path.
        tasks[task_id]['download_url'] = compressed_file
        tasks[task_id]['status'] = 'completed'
        update_task(task_id, f"Processing completed. Download URL: {download_url}")

    except Exception as e:
        logger.exception("An error occurred during processing")
        tasks[task_id]['status'] = 'error'
        update_task(task_id, f"An error occurred: {str(e)}")

    finally:
        # The archive is written directly into TEMP_DIR, so the working
        # directory (checkpoint, ONNX model, config) is no longer needed.
        # The periodic cleanup only deletes regular files, so without this
        # the directory would stay on disk indefinitely.
        if workdir_created:
            shutil.rmtree(unique_dir, ignore_errors=True)

def update_task(task_id, message):
    """Log *message*, append it to the task's log, and emit a 'task_update' event."""
    logger.info(message)
    record = tasks[task_id]
    record['log'].append(message)
    payload = {'task_id': task_id, 'message': message}
    socketio.emit('task_update', payload)

def _checkpoint_rank(filename):
    """Return an ``(epoch, step)`` sort key parsed from a checkpoint file name.

    Looks for ``epoch=<int>`` and ``step=<int>`` in the last path component
    (as in ``epoch=12-step=3400.ckpt``). A component that is absent is ranked
    as -1, so unconventionally named checkpoints sort below numbered ones
    instead of raising.
    """
    import re  # local import: `re` is not among this module's top-level imports

    base = filename.rsplit('/', 1)[-1]
    rank = []
    for field in ('epoch', 'step'):
        match = re.search(rf'{field}=(\d+)', base)
        rank.append(int(match.group(1)) if match else -1)
    return tuple(rank)


def download_model(task_id, repo_id, token, directory):
    """Download the newest checkpoint and ``config.json`` from a repository.

    The repository's ``.ckpt`` files are ranked by the epoch (then step)
    number in their names; the highest-ranked one is downloaded into
    *directory* and renamed to ``model.ckpt``. ``config.json`` is downloaded
    next to it.

    Args:
        task_id: Task whose log receives progress messages.
        repo_id: Hugging Face repository id.
        token: Access token passed to the Hugging Face calls.
        directory: Local directory that receives the files.

    Raises:
        Exception: If the repository contains no ``.ckpt`` file.
    """
    update_task(task_id, f"Downloading model from repo: {repo_id}")
    api = HfApi()
    files = api.list_repo_files(repo_id=repo_id, token=token)

    ckpt_files = [f for f in files if f.endswith('.ckpt')]
    if not ckpt_files:
        raise Exception("No .ckpt files found in the repository.")

    # Rank by (epoch, step) parsed from the file name rather than by naive
    # splitting, which raised on names such as "last.ckpt" or on paths with a
    # '-' in a directory component.
    latest_ckpt = max(ckpt_files, key=_checkpoint_rank)
    update_task(task_id, f"Latest checkpoint file: {latest_ckpt}")

    ckpt_path = hf_hub_download(repo_id=repo_id, filename=latest_ckpt, token=token, local_dir=directory)
    model_path = os.path.join(directory, "model.ckpt")
    os.rename(ckpt_path, model_path)
    update_task(task_id, f"Downloaded and renamed checkpoint to: {model_path}")

    config_path = hf_hub_download(repo_id=repo_id, filename="config.json", token=token, local_dir=directory)
    update_task(task_id, f"Downloaded config.json to: {config_path}")

def convert_to_onnx(task_id, model_name, directory):
    """Export ``model.ckpt`` in *directory* to ``<model_name>.onnx``.

    Runs ``python3 -m piper_train.export_onnx`` as a subprocess from the Piper
    source directory, then renames ``config.json`` in *directory* to
    ``<model_name>.onnx.json``.

    Args:
        task_id: Task whose log receives progress messages.
        model_name: Base name for the ONNX file and its JSON config.
        directory: Directory containing ``model.ckpt`` and ``config.json``.

    Raises:
        Exception: If the export command exits with a non-zero status; the
            original ``CalledProcessError`` is attached as the cause.
    """
    update_task(task_id, f"Converting model to ONNX format: {model_name}")
    ckpt_path = os.path.join(directory, "model.ckpt")
    onnx_path = os.path.join(directory, f"{model_name}.onnx")

    update_task(task_id, f"Checkpoint path: {ckpt_path}")
    update_task(task_id, f"ONNX output path: {onnx_path}")

    piper_src_dir = '/home/app/piper/src/python'
    command = [
        "python3", "-m", "piper_train.export_onnx",
        ckpt_path,
        onnx_path,
    ]
    update_task(task_id, f"Running command: {' '.join(command)}")
    update_task(task_id, f"Working directory for command: {piper_src_dir}")

    try:
        # Set the working directory for the child process only. os.chdir()
        # would change it for the whole server process, affecting every other
        # thread that is running at the same time.
        result = subprocess.run(
            command,
            check=True,
            capture_output=True,
            text=True,
            cwd=piper_src_dir,
        )
    except subprocess.CalledProcessError as e:
        update_task(task_id, f"Command failed with exit code {e.returncode}")
        update_task(task_id, f"Error output: {e.stderr}")
        raise Exception(f"ONNX conversion failed: {e.stderr}") from e
    update_task(task_id, f"Command output: {result.stdout}")

    os.rename(
        os.path.join(directory, "config.json"),
        os.path.join(directory, f"{model_name}.onnx.json"),
    )
    update_task(task_id, f"Renamed config.json to {model_name}.onnx.json")

def compress_files(task_id, model_name, directory):
    """Zip the ``.onnx`` and ``.onnx.json`` files found in *directory*.

    The archive is written to ``TEMP_DIR`` as
    ``<task_id>_<model_name>_onnx.zip``. The task id is part of the name so
    that two tasks using the same model name cannot overwrite each other's
    archive.

    Args:
        task_id: Task whose log receives progress messages; also used in the
            archive file name.
        model_name: Model name used in the archive file name.
        directory: Directory whose ONNX artifacts are archived.

    Returns:
        Path of the created zip file.
    """
    update_task(task_id, f"Compressing files for model: {model_name}")
    output_file = os.path.join(TEMP_DIR, f"{task_id}_{model_name}_onnx.zip")
    files_to_zip = sorted(
        f for f in os.listdir(directory)
        if f.endswith(('.onnx', '.onnx.json'))
    )
    # ZipFile defaults to ZIP_STORED (no compression); request deflate so the
    # archive is actually compressed.
    with zipfile.ZipFile(output_file, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        for file in files_to_zip:
            # Store each file under its bare name, without the directory.
            zipf.write(os.path.join(directory, file), file)
    update_task(task_id, f"Created compressed file: {output_file}")
    return output_file

if __name__ == '__main__':
    logger.info("Starting Flask application")
    # The server binds to all interfaces, and debug mode enables the Werkzeug
    # interactive debugger, which lets a client execute code on the server.
    # Debug mode is therefore opt-in via the FLASK_DEBUG environment variable
    # rather than always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    socketio.run(app, host='0.0.0.0', port=7860, debug=debug_enabled, allow_unsafe_werkzeug=True)