File size: 3,894 Bytes
9810ea7
6ee09b1
9810ea7
9fbf2d1
9810ea7
08d30fe
9810ea7
 
 
 
 
 
 
96784fc
9810ea7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6693a45
9810ea7
 
 
6693a45
 
9810ea7
 
 
 
9fbf2d1
9810ea7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364cb51
9810ea7
 
 
 
 
364cb51
9810ea7
 
 
364cb51
9810ea7
 
 
 
 
 
 
 
 
 
 
9fbf2d1
9810ea7
 
 
364cb51
9810ea7
 
 
 
 
 
 
 
 
 
 
a2a8e37
9810ea7
96784fc
9810ea7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import spaces
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="prithivMLmods/Pocket-Llama-3.2-3B-Instruct")
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--do_sample", action="store_true")
    # This allows ignoring unrecognized arguments, e.g., from Jupyter
    return parser.parse_known_args()

def load_model(model_name):
    """Load model and tokenizer from Hugging Face."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    return model, tokenizer

def generate_reply(model, tokenizer, prompt, max_length, do_sample):
    """Generate text from the model given a prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # We’re returning just the final string; no streaming here
    output_tokens = model.generate(
        **inputs,
        max_length=max_length,
        do_sample=do_sample
    )
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)


def main():
    args, _ = get_args()
    model, tokenizer = load_model(args.model)
    
    @spaces.GPU
    def respond(user_message, chat_history):
        """
        Gradio expects a function that takes the last user message and the 
        conversation history, then returns the updated history.
        
        chat_history is a list of (user_message, bot_reply) pairs.
        """
        # Build a single text prompt from the conversation so far
        prompt = ""
        for (old_user_msg, old_bot_msg) in chat_history:
            prompt += f"User: {old_user_msg}\nBot: {old_bot_msg}\n"
        # Add the new user query
        prompt += f"User: {user_message}\nBot:"

        # Generate the response
        bot_message = generate_reply(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_length=args.max_length,
            do_sample=args.do_sample
        )

        # In many cases, the model output will contain the entire prompt again, 
        # so we can strip that off or just let it show. If you see repeated 
        # text, you can try to remove the prompt prefix from bot_message.
        if bot_message.startswith(prompt):
            bot_message = bot_message[len(prompt):]

        # Append the new user-message and bot-response to the history
        chat_history.append((user_message, bot_message))
        return chat_history, chat_history

    # Define the Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("<h2 style='text-align: center;'>Chat with Your Model</h2>")
        
        # A Chatbot component that will display the conversation
        chatbot = gr.Chatbot(label="Chat")
        
        # A text box for user input
        user_input = gr.Textbox(
            show_label=False, 
            placeholder="Type your message here and press Enter"
        )
        
        # A button to clear the conversation
        clear_button = gr.Button("Clear")

        # When the user hits Enter in the textbox, call 'respond'
        #   - Inputs: [user_input, chatbot]  (the last user message and history)
        #   - Outputs: [chatbot, chatbot]    (updates the chatbot display and history)
        user_input.submit(respond, [user_input, chatbot], [chatbot, chatbot])
        
        # Define a helper function for clearing
        def clear_conversation():
            return [], []
        
        # When "Clear" is clicked, reset the conversation
        clear_button.click(fn=clear_conversation, outputs=[chatbot, chatbot])

    # Launch the Gradio app
    demo.launch()

if __name__ == "__main__":
    main()