File size: 5,336 Bytes
496bf8a
efcdb1c
ab3a30c
 
 
 
 
 
 
 
496bf8a
ab3a30c
e85fa31
72c65b6
d536921
72c65b6
 
ab3a30c
efcdb1c
ab3a30c
 
 
 
72c65b6
 
 
 
 
 
d86bc7f
290deb7
bddd843
 
ab3a30c
 
 
 
 
290deb7
496bf8a
 
 
 
82022e9
496bf8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab3a30c
 
 
 
6a42799
d86bc7f
ab3a30c
82022e9
d86bc7f
 
 
 
 
 
 
 
 
8a4d97a
d86bc7f
72c65b6
82022e9
72c65b6
93acf16
c8a6713
 
82022e9
ee4aecd
240b319
 
 
82022e9
 
840333c
5086b00
 
82022e9
 
 
 
 
 
 
 
 
 
840333c
fcf0aa2
c8a6713
ab3a30c
a7c5b39
ab3a30c
 
d86bc7f
 
ab3a30c
 
73ce57b
c3db1ad
 
 
 
73ce57b
d86bc7f
 
 
 
 
 
d6fc925
82022e9
 
ab3a30c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import io
import math
from typing import Optional

import numpy as np
import spaces
import gradio as gr
import torch

from parler_tts import ParlerTTSForConditionalGeneration
from pydub import AudioSegment
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
from huggingface_hub import InferenceClient
import nltk
import random
nltk.download('punkt')
# NLTK >= 3.9 loads the tokenizer behind ``nltk.sent_tokenize`` from the
# "punkt_tab" resource rather than the pickled "punkt" one; fetch both so
# sentence splitting works on old and new NLTK releases alike.
nltk.download('punkt_tab')


# Prefer CUDA, then Apple MPS, falling back to CPU; half precision is only
# used when running on an accelerator.
device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch_dtype = torch.float16 if device != "cpu" else torch.float32

# Base Parler-TTS checkpoint: supplies the tokenizers and feature extractor below.
repo_id = "parler-tts/parler_tts_mini_v0.1"

# Checkpoint whose weights are loaded for speech generation.
jenny_repo_id = "ylacombe/parler-tts-mini-jenny-30H"

model = ParlerTTSForConditionalGeneration.from_pretrained(
    jenny_repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
).to(device)

# Hugging Face inference client used to generate the story text.
client = InferenceClient()

# Separate tokenizers for the voice description and for the text to speak;
# the prompt tokenizer pads batches on the left.
description_tokenizer = AutoTokenizer.from_pretrained(repo_id)
prompt_tokenizer = AutoTokenizer.from_pretrained(repo_id, padding_side="left")
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)

# NOTE(review): SAMPLE_RATE and SEED do not appear to be read anywhere in this
# file — confirm before relying on them.
SAMPLE_RATE = feature_extractor.sampling_rate
SEED = 42


def numpy_to_mp3(audio_array, sampling_rate):
    """Encode a mono waveform as MP3 bytes.

    Floating-point input is peak-normalised to the signed 16-bit range and
    converted to ``int16`` first. Integer input is handed to pydub unchanged,
    with its sample width taken from the array's dtype.

    Args:
        audio_array: numpy array of mono audio samples.
        sampling_rate: Sample rate of ``audio_array`` in Hz.

    Returns:
        The audio encoded as MP3 at 320 kbps, as ``bytes``.
    """
    if np.issubdtype(audio_array.dtype, np.floating):
        # Scale in float32: doing the arithmetic directly on float16 input
        # (e.g. half-precision model output) would quantise the samples well
        # below 16-bit resolution.
        audio_array = audio_array.astype(np.float32)
        max_val = np.max(np.abs(audio_array)) if audio_array.size else 0.0
        # Only normalise when there is a non-zero peak. Otherwise an all-silent
        # (or empty) clip would divide by zero and produce NaNs, which have no
        # defined int16 value.
        if max_val > 0:
            audio_array = (audio_array / max_val) * 32767  # Normalize to 16-bit range
        audio_array = audio_array.astype(np.int16)

    # Create an audio segment from the raw PCM bytes.
    audio_segment = AudioSegment(
        audio_array.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_array.dtype.itemsize,
        channels=1,
    )

    # Export at a high bitrate to maximise quality. The context manager
    # releases the in-memory buffer even if encoding fails.
    with io.BytesIO() as mp3_io:
        audio_segment.export(mp3_io, format="mp3", bitrate="320k")
        return mp3_io.getvalue()

# Audio parameters read once from the model's audio-encoder configuration.
_audio_encoder_config = model.audio_encoder.config
# Sample rate used when encoding generated clips and reporting their length.
sampling_rate = _audio_encoder_config.sampling_rate
# Frame rate reported by the same audio-encoder configuration.
frame_rate = _audio_encoder_config.frame_rate



@spaces.GPU
def generate_base(subject, setting):
    """Write a bedtime story with an LLM and synthesise one audio clip per sentence.

    Args:
        subject: Story protagonist chosen in the UI (e.g. "Princess").
        setting: Story location chosen in the UI (e.g. "Forest").

    Returns:
        A 3-tuple ``(None, None, state)``. The two ``None`` values clear the
        story textbox and audio player. ``state`` is a dict whose ``"audio"``
        entry is a list of numpy waveforms, one per sentence, and whose
        ``"text"`` entry is the matching list of sentences.

    Raises:
        gr.Error: If the generated story contains no sentences.
    """
    messages = [
        {
            "role": "system",
            "content": (
                # Trailing spaces matter: adjacent literals are concatenated verbatim.
                "You are an award-winning children's bedtime story author lauded for your inventive stories. "
                "You want to write a bed time story for your child. They will give you the subject and setting "
                "and you will write the entire story. It should be targetted at children 5 and younger and take about "
                "a minute to read"
            ),
        },
        {"role": "user", "content": f"Please tell me a story about a {subject} in {setting}"},
    ]
    gr.Info("Generating story", duration=3)
    response = client.chat_completion(messages, max_tokens=2048, seed=random.randint(1, 5000))
    gr.Info("Story Generated", duration=3)
    story = response.choices[0].message.content

    # Each sentence becomes its own TTS prompt: flatten newlines, then split.
    model_input = story.replace("\n", " ").strip()
    sentences = nltk.sent_tokenize(model_input)
    if not sentences:
        # An empty batch cannot be synthesised; surface a message to the user instead.
        raise gr.Error("The story generator returned an empty story. Please try again.")

    gr.Info("Generating Audio")
    description = "Jenny speaks at an average pace with a calm delivery in a very confined sounding environment with clear audio quality."
    story_tokens = prompt_tokenizer(sentences, return_tensors="pt", padding=True).to(device)
    # The same voice description is paired with every sentence in the batch.
    description_tokens = description_tokenizer([description] * len(sentences), return_tensors="pt").to(device)
    # NOTE(review): batched generation presumably pads every clip to the same
    # length — confirm whether trailing padding should be trimmed per sentence.
    speech_output = model.generate(
        input_ids=description_tokens.input_ids,
        prompt_input_ids=story_tokens.input_ids,
        attention_mask=description_tokens.attention_mask,
        prompt_attention_mask=story_tokens.attention_mask,
    )
    speech_output = [output.cpu().numpy() for output in speech_output]
    gr.Info("Generated Audio")
    return None, None, {"audio": speech_output, "text": sentences}

import time


def stream_audio(state):
    """Reveal the story progressively, pairing each sentence with its audio.

    Args:
        state: Dict with parallel lists under ``"audio"`` (waveforms) and
            ``"text"`` (sentences), as stored by the story generator.

    Yields:
        ``(story_so_far, mp3_bytes)`` tuples. The first item is every sentence
        read so far, one per line. The second is the MP3 encoding of the newest
        sentence's waveform.
    """
    clips = state["audio"]
    lines = state["text"]

    gr.Info("Reading Story")

    transcript = ""
    for line, clip in zip(lines, clips):
        seconds = round(clip.shape[0] / sampling_rate, 2)
        print(f"Sample of length: {seconds} seconds")
        transcript = f"{transcript}{line}\n"
        mp3_chunk = numpy_to_mp3(clip, sampling_rate=sampling_rate)
        yield transcript, mp3_chunk
        # Fixed pause before emitting the next sentence.
        time.sleep(5)


# Page header rendered above the controls.
_HEADER_HTML = """
        <h1> Bedtime Story Reader 😴🔊 </h1>
        <p> Powered by <a href="https://github.com/huggingface/parler-tts"> Parler-TTS</a>
        """

with gr.Blocks() as block:
    gr.HTML(_HEADER_HTML)

    # Story controls: pick a protagonist and a location, then generate.
    with gr.Group():
        with gr.Row():
            subject = gr.Dropdown(value="Princess", choices=["Prince", "Princess", "Dog", "Cat"])
            setting = gr.Dropdown(value="Forest", choices=["Forest", "Kingdom", "Jungle", "Underwater"])
        with gr.Row():
            run_button = gr.Button("Generate Story", variant="primary")

    # Playback area: streamed narration plus the text read so far.
    with gr.Row(), gr.Group():
        audio_out = gr.Audio(label="Bed time story", streaming=True, autoplay=True)
        story = gr.Textbox(label="Story")

    inputs = [subject, setting]
    outputs = [story, audio_out]
    state = gr.State()

    # Step 1 writes the story and its audio into `state` and clears the
    # previous outputs. Step 2 runs only if step 1 succeeded, and streams the
    # story sentence by sentence.
    generate_event = run_button.click(fn=generate_base, inputs=inputs, outputs=[story, audio_out, state])
    generate_event.success(stream_audio, inputs=state, outputs=outputs)

block.queue()
block.launch(share=True)