multimodalart HF staff commited on
Commit
b30a088
·
verified ·
1 Parent(s): a19e0f7

Create wrapper.py

Browse files
Files changed (1) hide show
  1. wrapper.py +48 -0
wrapper.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import tempfile
3
+ import os
4
+ import subprocess
5
+ import sys
6
+
7
+ def main():
8
+ parser = argparse.ArgumentParser(description='Run YuE model with direct input')
9
+ parser.add_argument('--genre', type=str, required=True, help='Genre tags for the music')
10
+ parser.add_argument('--lyrics', type=str, required=True, help='Lyrics for the music')
11
+ parser.add_argument('--run_n_segments', type=int, default=2, help='Number of segments to process')
12
+ parser.add_argument('--stage2_batch_size', type=int, default=4, help='Batch size for stage 2')
13
+ parser.add_argument('--max_new_tokens', type=int, default=3000, help='Maximum number of new tokens')
14
+ parser.add_argument('--cuda_idx', type=int, default=0, help='CUDA device index')
15
+
16
+ args = parser.parse_args()
17
+
18
+ # Create temporary files for genre and lyrics
19
+ with tempfile.NamedTemporaryFile(mode='w', delete=False) as genre_file:
20
+ genre_file.write(args.genre)
21
+ genre_path = genre_file.name
22
+
23
+ with tempfile.NamedTemporaryFile(mode='w', delete=False) as lyrics_file:
24
+ lyrics_file.write(args.lyrics)
25
+ lyrics_path = lyrics_file.name
26
+
27
+ try:
28
+ # Run the inference script
29
+ subprocess.run([
30
+ 'python', 'infer.py',
31
+ '--stage1_model', 'm-a-p/YuE-s1-7B-anneal-en-cot',
32
+ '--stage2_model', 'm-a-p/YuE-s2-1B-general',
33
+ '--genre_txt', genre_path,
34
+ '--lyrics_txt', lyrics_path,
35
+ '--run_n_segments', str(args.run_n_segments),
36
+ '--stage2_batch_size', str(args.stage2_batch_size),
37
+ '--output_dir', '/app/output',
38
+ '--cuda_idx', str(args.cuda_idx),
39
+ '--max_new_tokens', str(args.max_new_tokens)
40
+ ], check=True)
41
+
42
+ finally:
43
+ # Clean up temporary files
44
+ os.unlink(genre_path)
45
+ os.unlink(lyrics_path)
46
+
47
+ if __name__ == '__main__':
48
+ main()