File size: 3,829 Bytes
1eb7acd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch
import spaces

# Load the model and tokenizer once, when the module is imported.
model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

# float16 with automatic device placement when a GPU is available;
# float32 with no device_map otherwise.
if device == "cuda":
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map=None,
    )
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tutor persona sent as the system message ahead of every conversation.
SYSTEM_INSTRUCTION = (
    "You are a math tutor providing hints and guidance. Do not reveal "
    "final answers. Offer step-by-step assistance only."
)

def apply_chat_template(messages):
    """
    Render a list of chat messages into one prompt string.

    Uses the tokenizer's chat template with ``tokenize=False`` (so text, not
    token ids, is returned) and ``add_generation_prompt=True`` (so the prompt
    ends ready for the assistant's turn).

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The formatted prompt text.
    """
    template_kwargs = dict(tokenize=False, add_generation_prompt=True)
    prompt = tokenizer.apply_chat_template(messages, **template_kwargs)
    return prompt

@spaces.GPU
def generate_response(history, user_input):
    """
    Generate the tutor's next reply and return the updated conversation.

    Args:
        history: The conversation so far, as a list of
            ``{"role": ..., "content": ...}`` dicts (the ``gr.State`` value).
            The list passed in is not modified; an extended copy is returned.
        user_input: The new message typed by the user. Blank or
            whitespace-only input is ignored and the history is returned
            unchanged.

    Returns:
        ``(formatted_history, history)``: the transcript rendered by
        ``format_chat_history`` for display, and the updated message list to
        store back into state.
    """
    # Work on a copy so that an exception during generation cannot leave the
    # caller's list holding a user turn with no matching assistant reply.
    history = list(history or [])

    # Do not prompt the model with an empty user turn.
    if not user_input or not user_input.strip():
        return format_chat_history(history), history

    history.append({"role": "user", "content": user_input})

    # The system prompt is prepended on every call and never stored in history.
    messages = [{"role": "system", "content": SYSTEM_INSTRUCTION}, *history]

    # Build the prompt text and tokenize it onto the model's device.
    text = apply_chat_template(messages)
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    generated_ids = model.generate(**model_inputs, max_new_tokens=512)
    # Each output row starts with the prompt tokens; keep only the new ones.
    new_token_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]

    history.append({"role": "assistant", "content": response})
    return format_chat_history(history), history

def format_chat_history(history):
    """
    Render the conversation as a transcript with bold speaker labels.

    Each ``user`` message becomes ``**User:** <content>`` and each
    ``assistant`` message becomes ``**MathTutor:** <content>``. Every entry is
    followed by a blank line. Messages with any other role (e.g. ``system``)
    are omitted.

    Args:
        history: Iterable of ``{"role": ..., "content": ...}`` dicts, or
            ``None``/empty when there is no conversation yet.

    Returns:
        The transcript string (``""`` when there is nothing to show).
    """
    speaker_labels = {"user": "User", "assistant": "MathTutor"}
    # Build every entry and join once, instead of repeated string +=.
    return "".join(
        f"**{speaker_labels[message['role']]}:** {message['content']}\n\n"
        for message in history or ()
        if message["role"] in speaker_labels
    )

# Gradio chat interface
def create_chat_interface():
    """
    Build the Gradio Blocks UI for the math-hint chat.

    The left column holds the query box and Send button. The right column
    shows the read-only transcript. A hidden ``gr.State`` carries the raw
    message list between calls to ``generate_response``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as chat_app:
        gr.Markdown("## Math Hint Chat")
        gr.Markdown(
            "This chat application helps with math problems by providing hints and guidance. "
            "It keeps a history of your conversation and ensures no direct answers are given."
        )

        with gr.Row():
            with gr.Column():
                user_input = gr.Textbox(
                    label="Your Math Query",
                    placeholder="Ask about a math problem (e.g., Solve for x: 4x + 5 = 6x + 7)",
                    lines=2
                )
                send_button = gr.Button("Send")
            with gr.Column():
                chat_history = gr.Textbox(
                    label="Chat History",
                    placeholder="Chat history will appear here.",
                    lines=20,
                    interactive=False
                )

        # Hidden state holding the list of role/content message dicts.
        history_state = gr.State([])

        # Send on button click or on Enter in the query box. After a reply is
        # generated successfully, clear the query box. If generation fails,
        # the user's text is left in place so it can be resent.
        for trigger in (send_button.click, user_input.submit):
            trigger(
                fn=generate_response,
                inputs=[history_state, user_input],
                outputs=[chat_history, history_state],
            ).success(fn=lambda: "", inputs=None, outputs=user_input)

    return chat_app


app = create_chat_interface()

if __name__ == "__main__":
    # Start the web server only when run as a script (e.g. `python app.py`),
    # so importing this module does not launch a blocking server.
    app.launch(debug=True)