asigalov61 commited on
Commit
47be7bd
·
verified ·
1 Parent(s): ac9be57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -44
app.py CHANGED
@@ -62,54 +62,45 @@ SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
62
 
63
  #==================================================================================
64
 
65
- def load_model():
 
66
 
67
- print('=' * 70)
68
- print('Instantiating model...')
69
-
70
- device_type = 'cuda'
71
- dtype = 'bfloat16'
72
-
73
- ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
74
- ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
75
-
76
- SEQ_LEN = 2048
77
 
78
- if model_selector == 'with velocity - 3 epochs':
79
- PAD_IDX = 512
80
 
81
- else:
82
- PAD_IDX = 384
83
-
84
- model = TransformerWrapper(
85
- num_tokens = PAD_IDX+1,
86
- max_seq_len = SEQ_LEN,
87
- attn_layers = Decoder(dim = 2048,
88
- depth = 4,
89
- heads = 32,
90
- rotary_pos_emb = True,
91
- attn_flash = True
92
- )
93
- )
94
-
95
- model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
96
-
97
- print('=' * 70)
98
- print('Loading model checkpoint...')
99
-
100
- model_checkpoint = hf_hub_download(repo_id='asigalov61/Guided-Accompaniment-Transformer', filename=MODEL_CHECKPOINTS[model_selector])
101
-
102
- model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))
103
-
104
- model = torch.compile(model, mode='max-autotune')
105
-
106
- print('=' * 70)
107
- print('Done!')
108
- print('=' * 70)
109
- print('Model will use', dtype, 'precision...')
110
- print('=' * 70)
111
 
112
- return [model, ctx]
 
 
 
 
 
 
 
 
 
 
113
 
114
  #==================================================================================
115
 
 
62
 
63
  #==================================================================================
64
 
65
+ print('=' * 70)
66
+ print('Instantiating model...')
67
 
68
+ device_type = 'cuda'
69
+ dtype = 'bfloat16'
 
 
 
 
 
 
 
 
70
 
71
+ ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
72
+ ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
73
 
74
+ SEQ_LEN = 4096
75
+ PAD_IDX = 1794
76
+
77
+ model = TransformerWrapper(
78
+ num_tokens = PAD_IDX+1,
79
+ max_seq_len = SEQ_LEN,
80
+ attn_layers = Decoder(dim = 2048,
81
+ depth = 4,
82
+ heads = 32,
83
+ rotary_pos_emb = True,
84
+ attn_flash = True
85
+ )
86
+ )
87
+
88
+ model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
89
+
90
+ print('=' * 70)
91
+ print('Loading model checkpoint...')
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ model_checkpoint = hf_hub_download(repo_id='asigalov61/Guided-Accompaniment-Transformer', filename=MODEL_CHECKPOINTS[model_selector])
94
+
95
+ model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))
96
+
97
+ model = torch.compile(model, mode='max-autotune')
98
+
99
+ print('=' * 70)
100
+ print('Done!')
101
+ print('=' * 70)
102
+ print('Model will use', dtype, 'precision...')
103
+ print('=' * 70)
104
 
105
  #==================================================================================
106