Spaces:
Runtime error
Runtime error
Create wrapper.py
Browse files- wrapper.py +48 -0
wrapper.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import tempfile
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import sys
|
6 |
+
|
7 |
+
def main():
|
8 |
+
parser = argparse.ArgumentParser(description='Run YuE model with direct input')
|
9 |
+
parser.add_argument('--genre', type=str, required=True, help='Genre tags for the music')
|
10 |
+
parser.add_argument('--lyrics', type=str, required=True, help='Lyrics for the music')
|
11 |
+
parser.add_argument('--run_n_segments', type=int, default=2, help='Number of segments to process')
|
12 |
+
parser.add_argument('--stage2_batch_size', type=int, default=4, help='Batch size for stage 2')
|
13 |
+
parser.add_argument('--max_new_tokens', type=int, default=3000, help='Maximum number of new tokens')
|
14 |
+
parser.add_argument('--cuda_idx', type=int, default=0, help='CUDA device index')
|
15 |
+
|
16 |
+
args = parser.parse_args()
|
17 |
+
|
18 |
+
# Create temporary files for genre and lyrics
|
19 |
+
with tempfile.NamedTemporaryFile(mode='w', delete=False) as genre_file:
|
20 |
+
genre_file.write(args.genre)
|
21 |
+
genre_path = genre_file.name
|
22 |
+
|
23 |
+
with tempfile.NamedTemporaryFile(mode='w', delete=False) as lyrics_file:
|
24 |
+
lyrics_file.write(args.lyrics)
|
25 |
+
lyrics_path = lyrics_file.name
|
26 |
+
|
27 |
+
try:
|
28 |
+
# Run the inference script
|
29 |
+
subprocess.run([
|
30 |
+
'python', 'infer.py',
|
31 |
+
'--stage1_model', 'm-a-p/YuE-s1-7B-anneal-en-cot',
|
32 |
+
'--stage2_model', 'm-a-p/YuE-s2-1B-general',
|
33 |
+
'--genre_txt', genre_path,
|
34 |
+
'--lyrics_txt', lyrics_path,
|
35 |
+
'--run_n_segments', str(args.run_n_segments),
|
36 |
+
'--stage2_batch_size', str(args.stage2_batch_size),
|
37 |
+
'--output_dir', '/app/output',
|
38 |
+
'--cuda_idx', str(args.cuda_idx),
|
39 |
+
'--max_new_tokens', str(args.max_new_tokens)
|
40 |
+
], check=True)
|
41 |
+
|
42 |
+
finally:
|
43 |
+
# Clean up temporary files
|
44 |
+
os.unlink(genre_path)
|
45 |
+
os.unlink(lyrics_path)
|
46 |
+
|
47 |
+
if __name__ == '__main__':
|
48 |
+
main()
|