prithivMLmods committed on
Commit 9810ea7 · verified · 1 Parent(s): 3e00b8b

Update app.py

Files changed (1)
  1. app.py +94 -88
app.py CHANGED
@@ -1,100 +1,106 @@
- import torch
  import spaces
- from transformers import AutoTokenizer, AutoModelForCausalLM
  import gradio as gr
- from snac import SNAC
-
- def redistribute_codes(row):
-     """
-     Convert a sequence of token codes into an audio waveform using SNAC.
-     The model emits tokens in groups of 7; each group encodes one audio frame.
-     """
-     row_length = row.size(0)
-     new_length = (row_length // 7) * 7
-     trimmed_row = row[:new_length]
-     code_list = [t - 128266 for t in trimmed_row]
-
-     layer_1, layer_2, layer_3 = [], [], []
-
-     for i in range((len(code_list) + 1) // 7):
-         layer_1.append(code_list[7 * i][None])
-         layer_2.append(code_list[7 * i + 1][None] - 4096)
-         layer_3.append(code_list[7 * i + 2][None] - (2 * 4096))
-         layer_3.append(code_list[7 * i + 3][None] - (3 * 4096))
-         layer_2.append(code_list[7 * i + 4][None] - (4 * 4096))
-         layer_3.append(code_list[7 * i + 5][None] - (5 * 4096))
-         layer_3.append(code_list[7 * i + 6][None] - (6 * 4096))
-
-     with torch.no_grad():
-         codes = [
-             torch.concat(layer_1),
-             torch.concat(layer_2),
-             torch.concat(layer_3)
-         ]
-         for i in range(len(codes)):
-             codes[i][codes[i] < 0] = 0
-             codes[i] = codes[i][None]
-
-         audio_hat = snac_model.decode(codes)
-     return audio_hat.cpu()[0, 0]
-
- # Load the SNAC model for audio decoding
- snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to("cuda")
-
- # Load the single-speaker language model
- tokenizer = AutoTokenizer.from_pretrained('prithivMLmods/Llama-3B-Mono-Cooper')
- model = AutoModelForCausalLM.from_pretrained(
-     'prithivMLmods/Llama-3B-Mono-Cooper', torch_dtype=torch.bfloat16
- ).cuda()
-
- @spaces.GPU
- def generate_audio(text, temperature, top_p, max_new_tokens):
-     """
-     Given input text, generate speech audio.
-     """
-     speaker = "Cooper"
-     prompt = f'<custom_token_3><|begin_of_text|>{speaker}: {text}<|eot_id|><custom_token_4><custom_token_5><custom_token_1>'
-     input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').to('cuda')
-
-     with torch.no_grad():
-         generated_ids = model.generate(
-             **input_ids,
-             max_new_tokens=max_new_tokens,
-             do_sample=True,
-             temperature=temperature,
-             top_p=top_p,
-             repetition_penalty=1.1,
-             num_return_sequences=1,
-             eos_token_id=128258,
          )
-
-     row = generated_ids[0, input_ids['input_ids'].shape[1]:]
-     y_tensor = redistribute_codes(row)
-     y_np = y_tensor.detach().cpu().numpy()
-     return (24000, y_np)
-
- # Gradio Interface
- with gr.Blocks() as demo:
-     gr.Markdown("# Llama-3B-Mono-Cooper - Single Speaker Audio Generation")
-     gr.Markdown("Generate speech audio using the `prithivMLmods/Llama-3B-Mono-Cooper` model.")
-
-     with gr.Row():
-         text_input = gr.Textbox(lines=4, label="Input Text")
-
-     with gr.Row():
-         temp_slider = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.9, label="Temperature")
-         top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.8, label="Top-p")
-         tokens_slider = gr.Slider(minimum=100, maximum=2000, step=50, value=1200, label="Max New Tokens")
-
-     output_audio = gr.Audio(type="numpy", label="Generated Audio")
-     generate_button = gr.Button("Generate Audio")
-
-     generate_button.click(
-         fn=generate_audio,
-         inputs=[text_input, temp_slider, top_p_slider, tokens_slider],
-         outputs=output_audio
-     )
-
- if __name__ == "__main__":
      demo.launch()
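
Note on the removed decoder: `redistribute_codes` de-interleaves the model's flat token stream into SNAC's three codebook layers, which receive 1, 2, and 4 codes per 7-token frame respectively. A minimal sketch of that index arithmetic on a single hypothetical frame (values are made up; the 128266 token offset is assumed to be already subtracted, as in the function above):

# One 7-token frame after the 128266 offset is removed (hypothetical values).
frame = [10, 4100, 8200, 12300, 16400, 20500, 24600]
layer_1 = [frame[0]]                                  # 1 coarse code
layer_2 = [frame[1] - 4096, frame[4] - 4 * 4096]      # 2 mid codes
layer_3 = [frame[2] - 2 * 4096, frame[3] - 3 * 4096,  # 4 fine codes
           frame[5] - 5 * 4096, frame[6] - 6 * 4096]
print(layer_1, layer_2, layer_3)  # [10] [4, 16] [8, 12, 20, 24]

Each of the 7 slots carries its own offset of slot_index * 4096, so subtracting it recovers the raw codebook index; this is exactly the arithmetic in the loop above.
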
+ import argparse
  import spaces
+ import torch
  import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model", type=str, default="prithivMLmods/Pocket-Llama-3.2-3B-Instruct")
+     parser.add_argument("--max_length", type=int, default=512)
+     parser.add_argument("--do_sample", action="store_true")
+     # This allows ignoring unrecognized arguments, e.g., from Jupyter
+     return parser.parse_known_args()
+
+ def load_model(model_name):
+     """Load model and tokenizer from Hugging Face."""
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.bfloat16,
+         device_map="auto"
+     )
+     return model, tokenizer
+
+ def generate_reply(model, tokenizer, prompt, max_length, do_sample):
+     """Generate text from the model given a prompt."""
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     # We're returning just the final string; no streaming here
+     output_tokens = model.generate(
+         **inputs,
+         max_length=max_length,
+         do_sample=do_sample
+     )
+     return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
+
+ @spaces.GPU
+ def main():
+     args, _ = get_args()
+     model, tokenizer = load_model(args.model)
+
+     def respond(user_message, chat_history):
+         """
+         Gradio expects a function that takes the last user message and the
+         conversation history, then returns the updated history.
+
+         chat_history is a list of (user_message, bot_reply) pairs.
+         """
+         # Build a single text prompt from the conversation so far
+         prompt = ""
+         for (old_user_msg, old_bot_msg) in chat_history:
+             prompt += f"User: {old_user_msg}\nBot: {old_bot_msg}\n"
+         # Add the new user query
+         prompt += f"User: {user_message}\nBot:"
+
+         # Generate the response
+         bot_message = generate_reply(
+             model=model,
+             tokenizer=tokenizer,
+             prompt=prompt,
+             max_length=args.max_length,
+             do_sample=args.do_sample
+         )
+
+         # The decoded output usually contains the entire prompt again,
+         # so strip that prefix off before showing the reply.
+         if bot_message.startswith(prompt):
+             bot_message = bot_message[len(prompt):]
+
+         # Append the new user message and bot response to the history
+         chat_history.append((user_message, bot_message))
+         return chat_history, chat_history
+
+     # Define the Gradio interface
+     with gr.Blocks() as demo:
+         gr.Markdown("<h2 style='text-align: center;'>Chat with Your Model</h2>")
+
+         # A Chatbot component that will display the conversation
+         chatbot = gr.Chatbot(label="Chat")
+
+         # A text box for user input
+         user_input = gr.Textbox(
+             show_label=False,
+             placeholder="Type your message here and press Enter"
          )
+
+         # A button to clear the conversation
+         clear_button = gr.Button("Clear")
+
+         # When the user hits Enter in the textbox, call 'respond'
+         # - Inputs: [user_input, chatbot] (the last user message and history)
+         # - Outputs: [chatbot, chatbot] (updates the chatbot display and history)
+         user_input.submit(respond, [user_input, chatbot], [chatbot, chatbot])
+
+         # Define a helper function for clearing
+         def clear_conversation():
+             return [], []
+
+         # When "Clear" is clicked, reset the conversation
+         clear_button.click(fn=clear_conversation, outputs=[chatbot, chatbot])
+
+     # Launch the Gradio app
      demo.launch()
+
+ if __name__ == "__main__":
+     main()
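
A note on the new chat loop: respond flattens the whole conversation into a plain "User:/Bot:" transcript before every generation. A short sketch of the exact string handed to generate_reply, using a hypothetical one-turn history and no model call:

# Sketch of the prompt format respond() builds from (user, bot) history pairs.
chat_history = [("Hi", "Hello! How can I help?")]  # hypothetical prior turn
user_message = "Tell me a joke."                   # hypothetical new input
prompt = ""
for old_user_msg, old_bot_msg in chat_history:
    prompt += f"User: {old_user_msg}\nBot: {old_bot_msg}\n"
prompt += f"User: {user_message}\nBot:"
print(prompt)
# User: Hi
# Bot: Hello! How can I help?
# User: Tell me a joke.
# Bot:

Because get_args() uses parse_known_args, unrecognized flags (such as the ones Jupyter injects) are ignored rather than raising an error, so the same file can run as `python app.py` or inside a notebook; passing --do_sample turns sampling on, and without it model.generate falls back to greedy decoding.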