File size: 3,398 Bytes
8df9fb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f03894
 
 
8df9fb2
 
 
9f03894
 
 
 
 
 
8df9fb2
 
 
 
 
 
 
 
 
f371603
8df9fb2
 
 
 
 
 
 
 
 
 
 
 
 
 
f371603
8df9fb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f371603
8df9fb2
 
 
 
f371603
8df9fb2
 
 
 
f371603
8df9fb2
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Hugging Face Hub identifiers used by this app
BASE_MODEL = "HuggingFaceTB/SmolLM2-1.7B-Instruct"  # checkpoint the adapter sits on
ADAPTER_MODEL = "Joash2024/Math-SmolLM2-1.7B"       # LoRA adapter loaded on top

# Keyword arguments passed unchanged to both from_pretrained calls below.
_DTYPE_AND_PLACEMENT = {"device_map": "auto", "torch_dtype": torch.float16}

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Padding reuses the end-of-sequence token.
tokenizer.pad_token = tokenizer.eos_token

print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    **_DTYPE_AND_PLACEMENT,
)

print("Loading LoRA adapter...")
# Rebind `model` to the adapter-wrapped model; everything below uses this one.
model = PeftModel.from_pretrained(model, ADAPTER_MODEL, **_DTYPE_AND_PLACEMENT)
# Switch to evaluation mode: the app only runs generation.
model.eval()

def format_prompt(function: str) -> str:
    """Return the instruction prompt asking for the derivative of *function*.

    The prompt is four lines: an instruction, a blank line, the function
    itself, and a lead-in that the model is expected to continue. It has no
    trailing newline.
    """
    prompt_lines = (
        "Given a mathematical function, find its derivative.",
        "",
        f"Function: {function}",
        "The derivative of this function is:",
    )
    return "\n".join(prompt_lines)

def generate_derivative(function: str, max_length: int = 100) -> str:
    """Generate the model's derivative for *function*.

    Args:
        function: The function to differentiate, as typed by the user.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
            In transformers this limit counts the prompt tokens as well as
            the generated ones, so long inputs leave less room for the answer.

    Returns:
        The text the model produced after the prompt, with surrounding
        whitespace removed.
    """
    prompt = format_prompt(function)

    # Tokenize and move the tensors to wherever the model lives.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            # Greedy decoding. No `temperature` is passed: it has no effect
            # when sampling is off and only makes transformers emit a warning.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Drop them by token
    # count instead of slicing the decoded text by len(prompt), which breaks
    # whenever decoding does not reproduce the prompt character-for-character.
    continuation_ids = outputs[0][prompt_token_count:]
    return tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()

def solve_derivative(function: str) -> str:
    """Differentiate *function* with the model and format the answer for the UI.

    Args:
        function: Raw text from the input box. Leading and trailing
            whitespace is ignored.

    Returns:
        ``"Please enter a function"`` when the input is empty, ``None`` or
        only whitespace. Otherwise a multi-line string containing the
        generated derivative embedded in a fixed three-step template. The
        template repeats the model's answer; it does not verify it.
    """
    # Normalize first so a whitespace-only submission is rejected here
    # instead of being sent to the model as an empty function.
    function = function.strip() if function else ""
    if not function:
        return "Please enter a function"

    print(f"\nGenerating derivative for: {function}")
    derivative = generate_derivative(function)

    # Fixed presentation template around the generated text.
    return f"""Generated derivative: {derivative}

Let's verify this step by step:
1. Starting with f(x) = {function}
2. Applying differentiation rules
3. We get f'(x) = {derivative}"""

# Create Gradio interface
# Layout, top to bottom: title and subtitle, a row with the function input
# and the submit button, a row with the result box, then clickable examples.
with gr.Blocks(title="Mathematics Derivative Solver") as demo:
    gr.Markdown("# Mathematics Derivative Solver")
    gr.Markdown("Using our fine-tuned model to solve derivatives")
    
    # Input row: free-text function plus the button that submits it.
    with gr.Row():
        with gr.Column():
            function_input = gr.Textbox(
                label="Enter a function",
                placeholder="Example: x^2, sin(x), e^x"
            )
            solve_btn = gr.Button("Find Derivative", variant="primary")
    
    # Output row: receives the string returned by solve_derivative.
    with gr.Row():
        output = gr.Textbox(
            label="Solution with Steps",
            lines=6
        )
    
    # Example functions (reduced)
    # The second example is LaTeX-style; the doubled backslashes are Python
    # escapes, so the textbox receives \sin{\left(x\right)}.
    # NOTE(review): with caching disabled and run_on_click not set, clicking
    # an example presumably only fills the input box and does not call fn;
    # confirm against the installed Gradio version.
    gr.Examples(
        examples=[
            ["x^2"],
            ["\\sin{\\left(x\\right)}"],
            ["e^x"]
        ],
        inputs=function_input,
        outputs=output,
        fn=solve_derivative,
        cache_examples=False  # Disable caching
    )
    
    # Connect the interface
    # Button click: pass the textbox value to solve_derivative and show the
    # returned text in the output box.
    solve_btn.click(
        fn=solve_derivative,
        inputs=[function_input],
        outputs=output
    )

# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()