Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,100 +1,106 @@
-import
 import spaces
-
 import gradio as gr
-from
 
-def
-    """
-#
-#
-        'prithivMLmods/Llama-3B-Mono-Cooper', torch_dtype=torch.bfloat16
-    ).cuda()
-        **input_ids,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        top_p=top_p,
-        repetition_penalty=1.1,
-        num_return_sequences=1,
-        eos_token_id=128258,
-    )
-    y_np = y_tensor.detach().cpu().numpy()
-    return (24000, y_np)
-#
-    tokens_slider = gr.Slider(minimum=100, maximum=2000, step=50, value=1200, label="Max New Tokens")
-    output_audio = gr.Audio(type="numpy", label="Generated Audio")
-    generate_button = gr.Button("Generate Audio")
-    generate_button.click(
-        fn=generate_audio,
-        inputs=[text_input, temp_slider, top_p_slider, tokens_slider],
-        outputs=output_audio
-    )
-demo.launch()
+import argparse
 import spaces
+import torch
 import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="prithivMLmods/Pocket-Llama-3.2-3B-Instruct")
+    parser.add_argument("--max_length", type=int, default=512)
+    parser.add_argument("--do_sample", action="store_true")
+    # This allows ignoring unrecognized arguments, e.g., from Jupyter
+    return parser.parse_known_args()
 
+def load_model(model_name):
+    """Load model and tokenizer from Hugging Face."""
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16,
+        device_map="auto"
+    )
+    return model, tokenizer
+
+# On ZeroGPU Spaces, @spaces.GPU goes on the function that performs inference
+@spaces.GPU
+def generate_reply(model, tokenizer, prompt, max_length, do_sample):
+    """Generate text from the model given a prompt."""
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    # We're returning just the final string; no streaming here
+    output_tokens = model.generate(
+        **inputs,
+        max_length=max_length,
+        do_sample=do_sample
+    )
+    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
+
+def main():
+    args, _ = get_args()
+    model, tokenizer = load_model(args.model)
+
+    def respond(user_message, chat_history):
+        """
+        Gradio expects a function that takes the last user message and the
+        conversation history, then returns the updated history.
+
+        chat_history is a list of (user_message, bot_reply) pairs.
+        """
+        # Build a single text prompt from the conversation so far
+        prompt = ""
+        for (old_user_msg, old_bot_msg) in chat_history:
+            prompt += f"User: {old_user_msg}\nBot: {old_bot_msg}\n"
+        # Add the new user query
+        prompt += f"User: {user_message}\nBot:"
+
+        # Generate the response
+        bot_message = generate_reply(
+            model=model,
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_length=args.max_length,
+            do_sample=args.do_sample
+        )
+
+        # The model output will usually contain the entire prompt again,
+        # so strip that prefix off to keep only the newly generated reply
+        if bot_message.startswith(prompt):
+            bot_message = bot_message[len(prompt):]
+
+        # Append the new user message and bot response to the history
+        chat_history.append((user_message, bot_message))
+        return chat_history, chat_history
+
+    # Define the Gradio interface
+    with gr.Blocks() as demo:
+        gr.Markdown("<h2 style='text-align: center;'>Chat with Your Model</h2>")
+
+        # A Chatbot component that will display the conversation
+        chatbot = gr.Chatbot(label="Chat")
+
+        # A text box for user input
+        user_input = gr.Textbox(
+            show_label=False,
+            placeholder="Type your message here and press Enter"
+        )
+
+        # A button to clear the conversation
+        clear_button = gr.Button("Clear")
+
+        # When the user hits Enter in the textbox, call 'respond'
+        # - Inputs: [user_input, chatbot] (the last user message and history)
+        # - Outputs: [chatbot, chatbot] (updates the chatbot display and history)
+        user_input.submit(respond, [user_input, chatbot], [chatbot, chatbot])
+
+        # Define a helper function for clearing
+        def clear_conversation():
+            return [], []
+
+        # When "Clear" is clicked, reset the conversation
+        clear_button.click(fn=clear_conversation, outputs=[chatbot, chatbot])
+
+    # Launch the Gradio app
+    demo.launch()
+
+if __name__ == "__main__":
+    main()
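
For quick reference, here is a minimal standalone sketch (not part of the commit) of the prompt string that respond() assembles from the tuple-based history; the conversation turns below are invented for illustration:

# Standalone illustration of the prompt format built inside respond().
# The chat_history and user_message values are made up for the example.
chat_history = [("Hi there", "Hello! How can I help?")]
user_message = "Tell me a joke"

prompt = ""
for (old_user_msg, old_bot_msg) in chat_history:
    prompt += f"User: {old_user_msg}\nBot: {old_bot_msg}\n"
prompt += f"User: {user_message}\nBot:"

print(prompt)
# User: Hi there
# Bot: Hello! How can I help?
# User: Tell me a joke
# Bot:

Because generate_reply() decodes the whole output sequence, the generated text normally begins with this same prompt, which is why respond() strips the prefix before appending the reply to the history. The argparse interface also makes the script easy to exercise locally, e.g. python app.py --do_sample --max_length 256, using the flags defined in get_args().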