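# CiSiMi-v0.1 demo: a local speech-to-speech pipeline. Audio (or text) comes in,
# Whisper transcribes it, a GGUF OuteTTS model generates a reply as interleaved
# text and audio tokens, and the audio codec renders those tokens back to speech.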
import torch
import gradio as gr
import whisper
import outetts
from huggingface_hub import hf_hub_download
from outetts.wav_tokenizer.audio_codec import AudioCodec
from outetts.version.v2.prompt_processor import PromptProcessor
from outetts.version.playback import ModelOutput

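# Fetch the quantized GGUF weights from the Hugging Face Hub.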
model_path = hf_hub_download(
    repo_id="KandirResearch/CiSiMi-v0.1",
    filename="unsloth.Q8_0.gguf", # unsloth.Q4_K_M.gguf
)

model_config = outetts.GGUFModelConfig_v2(
    model_path=model_path,
    tokenizer_path="KandirResearch/CiSiMi-v0.1",
)

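# Set up the OuteTTS interface along with the codec and prompt helpers used to
# decode audio tokens, and load Whisper for transcribing spoken input.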
interface = outetts.InterfaceGGUF(model_version="0.3", cfg=model_config)
audio_codec = AudioCodec()
prompt_processor = PromptProcessor("KandirResearch/CiSiMi-v0.1")
whisper_model = whisper.load_model("base.en")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gguf_model = interface.get_model()

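# Decode a sequence of generated audio tokens into a playable waveform.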
def get_audio(tokens):
    outputs = prompt_processor.extract_audio_from_tokens(tokens)
    if not outputs:
        return None
    audio_tensor = audio_codec.decode(torch.tensor([[outputs]], dtype=torch.int64).to(device))
    return ModelOutput(audio_tensor, audio_codec.sr)

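# Strip the audio-token markup from the model output, keeping only the words:
# lines that carry tokens contribute the text before their first "<|" marker.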
def extract_text_from_tts_output(tts_output):
    text = ""
    for line in tts_output.strip().split('\n'):
        if '<|audio_end|>' in line or '<|im_end|>' in line:
            continue
        if '<|' in line:
            word = line.split('<|')[0].strip()
            if word:
                text += word + " "
        else:
            text += line.strip() + " "
    return text.strip()

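# Route a request to the right handler; audio takes precedence over text.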
def process_input(audio_input, text_input):
    if audio_input is None and (text_input is None or text_input.strip() == ""):
        return "Please provide either audio or text input.", None
    
    if audio_input is not None:
        return process_audio(audio_input)
    else:
        return process_text(text_input)

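# Transcribe spoken input with Whisper, then answer the transcription.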
def process_audio(audio):
    result = whisper_model.transcribe(audio)
    instruction = result["text"]
    return generate_response(instruction)

def process_text(text):
    return generate_response(text)

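# Core pipeline: wrap the instruction in the model's prompt template, sample a
# completion, then split it into the reply text and the audio-token span.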
def generate_response(instruction):
    prompt = f"<|im_start|>\nInstructions:\n{instruction}\n<|im_end|>\nAnswer:\n"
    gen_cfg = outetts.GenerationConfig(
        text=prompt, 
        temperature=0.6, 
        repetition_penalty=1.1, 
        max_length=4096, 
        speaker=None
    )
    
    input_ids = prompt_processor.tokenizer.encode(prompt)
    tokens = gguf_model.generate(input_ids, gen_cfg)
    
    output_text = prompt_processor.tokenizer.decode(tokens, skip_special_tokens=False)
    
    if "<|audio_end|>" in output_text:
        first_part, _, _ = output_text.partition("<|audio_end|>")
        
        if "<|audio_end|>\n<|im_end|>\n" not in first_part:
            first_part += "<|audio_end|>\n<|im_end|>\n"
            
        extracted_text = extract_text_from_tts_output(first_part)
        
        audio_start_pos = first_part.find("<|audio_start|>\n") + len("<|audio_start|>\n")
        audio_end_pos = first_part.find("<|audio_end|>\n<|im_end|>\n") + len("<|audio_end|>\n<|im_end|>\n")
        
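        # find() returns -1 when a marker is missing, which drags the computed
        # start position below len("<|audio_start|>\n"); the bounds check below
        # rejects that case.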
        if audio_start_pos >= len("<|audio_start|>\n") and audio_end_pos > audio_start_pos:
            audio_tokens_text = first_part[audio_start_pos:audio_end_pos]
            
            audio_tokens = prompt_processor.tokenizer.encode(audio_tokens_text)
            
            #print(f"Decoding audio with {len(audio_tokens)} tokens")
            #print(f"audio_tokens: {audio_tokens_text}")
            
            audio_output = get_audio(audio_tokens)
            
            if audio_output is not None and hasattr(audio_output, 'audio') and audio_output.audio is not None:
                audio_numpy = audio_output.audio.cpu().numpy()
                if audio_numpy.ndim > 1:
                    audio_numpy = audio_numpy.squeeze()
                
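                # Gradio's numpy audio output expects a (sample_rate, waveform) tuple.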
                return extracted_text, (audio_output.sr, audio_numpy)
    
    return output_text, None

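# Minimal Gradio UI: optional audio or text in, reply text plus synthesized speech out.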
iface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Audio(type="filepath", label="Audio Input (Optional)"),
        gr.Textbox(label="Text Input (Optional)")
    ],
    outputs=[
        gr.Textbox(label="Response Text"),
        gr.Audio(type="numpy", label="Generated Speech")
    ],
    title="CiSiMi-v0.1 @ Home Demo",
    description="Me: Mom can we have CSM locally! Mom: we have CSM locally. CSM locally:",
    examples=[
        [None, "Hello, what are you capable of?"],
        [None, "Explain to me how gravity works!"]
    ]
)

iface.launch(debug=True)