import gc
import os.path
from typing import Optional

import diffusers
import torch.cuda
import transformers
import librosa
# Lazily-initialized globals; populated by create_model() and cleared by delete_model().
model: Optional[diffusers.AudioLDMPipeline] = None
loaded = False
clap_model: Optional[transformers.ClapModel] = None
processor: Optional[transformers.ClapProcessor] = None
device: Optional[str] = None

models = ['cvssp/audioldm', 'cvssp/audioldm-s-full-v2', 'cvssp/audioldm-m-full', 'cvssp/audioldm-l-full']


def create_model(pretrained='cvssp/audioldm-m-full', map_device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the AudioLDM pipeline plus the CLAP model/processor used to rank waveforms."""
    if is_loaded():
        delete_model()
    global model, loaded, clap_model, processor, device
    try:
        cache_dir = os.path.join('data', 'models', 'audioldm')
        model = diffusers.AudioLDMPipeline.from_pretrained(pretrained, cache_dir=cache_dir).to(map_device)
        # CLAP scores text-audio similarity so the best of several generated waveforms can be picked.
        clap_model = transformers.ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full", cache_dir=cache_dir).to(map_device)
        processor = transformers.AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full", cache_dir=cache_dir)
        device = map_device
        loaded = True
    except Exception:
        # Don't leave the module in a half-loaded state if any download or load step fails.
        loaded = False


def delete_model():
    """Release the loaded models and free GPU memory."""
    global model, loaded, clap_model, processor, device
    model = None
    clap_model = None
    processor = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    loaded = False
    device = None


def is_loaded():
    return loaded


def score_waveforms(text, waveforms):
    """Return the waveform whose CLAP embedding best matches the prompt text."""
    inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
    inputs = {key: inputs[key].to(device) for key in inputs}
    with torch.no_grad():
        logits_per_text = clap_model(**inputs).logits_per_text  # audio-text similarity scores
        probs = logits_per_text.softmax(dim=-1)  # softmax over the candidate waveforms
        most_probable = torch.argmax(probs).item()  # index of the most likely waveform
    return waveforms[most_probable]


def generate(prompt='', negative_prompt='', steps=10, duration=5.0, cfg=2.5, seed=-1, wav_best_count=1, enhance=False, callback=None):
    if is_loaded():
        try:
            sample_rate = 16000  # AudioLDM generates 16 kHz audio
            seed = seed if seed >= 0 else torch.seed()  # a negative seed means "pick one at random"
            torch.manual_seed(seed)
            output = model(prompt, negative_prompt=negative_prompt if negative_prompt else None,
                           audio_length_in_s=duration, num_inference_steps=steps, guidance_scale=cfg,
                           num_waveforms_per_prompt=wav_best_count, callback=callback)
            waveforms = output.audios
            if waveforms.shape[0] > 1:
                # Several candidates were generated; keep the one CLAP ranks closest to the prompt.
                waveform = score_waveforms(prompt, waveforms)
            else:
                waveform = waveforms[0]
            if enhance:  # https://github.com/gitmylo/audio-webui/issues/36#issuecomment-1627380868
                # Upsample to 44.1 kHz and layer in a copy pitch-shifted up an octave to
                # approximate the high-frequency content missing from 16 kHz output.
                sample_rate = 44100
                audio_resampled = librosa.resample(waveform, orig_sr=16000, target_sr=sample_rate)
                waveform = audio_resampled + librosa.effects.pitch_shift(audio_resampled, sr=sample_rate, n_steps=12, res_type="soxr_vhq")

            return seed, (sample_rate, waveform)
        except Exception as e:
            return f'An exception occurred: {str(e)}'
    return 'No model loaded! Please load a model first.'
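

# Minimal usage sketch (illustrative; not part of the original module). Assumes the
# `soundfile` package is available and that writing 'output.wav' to the working
# directory is acceptable; the prompt below is an arbitrary example.
if __name__ == '__main__':
    import soundfile

    create_model('cvssp/audioldm-s-full-v2')
    result = generate(prompt='a hammer hitting a wooden surface', steps=10, duration=5.0)
    if isinstance(result, str):
        print(result)  # model failed to load, or generation raised an exception
    else:
        seed, (sample_rate, waveform) = result
        soundfile.write('output.wav', waveform, sample_rate)
        print(f'wrote output.wav (seed={seed})')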