Joash2024 committed on
Commit
98a6116
·
1 Parent(s): a444b7e

fix: simplify model loading and generation

Browse files
Files changed (1) hide show
  1. app.py +32 -19
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, pipeline
3
  import torch
4
  import numpy as np
5
  from monitoring import PerformanceMonitor, measure_time
@@ -13,33 +13,33 @@ monitor = PerformanceMonitor()
13
 
14
  def format_prompt(problem):
15
  """Format the input problem according to the model's expected format"""
16
- return f"<|im_start|>user\nCan you help me solve this math problem? {problem}<|im_end|>\n"
17
 
18
  @measure_time
19
  def get_model_response(problem, model_id):
20
  """Get response from a specific model"""
21
  try:
22
- # Initialize pipeline
23
  pipe = pipeline(
24
  "text-generation",
25
  model=model_id,
26
  torch_dtype=torch.float16,
27
  device_map="auto",
 
28
  )
29
 
30
  # Format prompt and generate response
31
  prompt = format_prompt(problem)
32
  response = pipe(
33
  prompt,
34
- max_new_tokens=100,
35
  temperature=0.1,
36
- top_p=0.95,
37
- repetition_penalty=1.15
 
38
  )[0]["generated_text"]
39
 
40
- # Extract assistant's response
41
- assistant_response = response.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]
42
- return assistant_response.strip()
43
  except Exception as e:
44
  return f"Error: {str(e)}"
45
 
@@ -59,11 +59,24 @@ def solve_problem(problem, problem_type):
59
  base_response, base_time = get_model_response(problem, BASE_MODEL_ID)
60
  finetuned_response, finetuned_time = get_model_response(problem, FINETUNED_MODEL_ID)
61
 
62
- # Record response times
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  monitor.record_response_time("base", base_time)
64
  monitor.record_response_time("finetuned", finetuned_time)
65
-
66
- # Record success (basic check - no error message)
67
  monitor.record_success("base", not base_response.startswith("Error"))
68
  monitor.record_success("finetuned", not finetuned_response.startswith("Error"))
69
 
@@ -82,12 +95,12 @@ def solve_problem(problem, problem_type):
82
  - Base Model: {stats.get('base_success_rate', 0):.1f}%
83
  - Fine-tuned Model: {stats.get('finetuned_success_rate', 0):.1f}%
84
 
85
- #### Problem Type Distribution
86
  """
87
  for ptype, percentage in stats.get('problem_type_distribution', {}).items():
88
  stats_display += f"- {ptype}: {percentage:.1f}%\n"
89
 
90
- return base_response, finetuned_response, stats_display
91
 
92
  # Create Gradio interface
93
  with gr.Blocks(title="Mathematics Problem Solver") as demo:
@@ -98,12 +111,12 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
98
  with gr.Column():
99
  problem_type = gr.Dropdown(
100
  choices=["Addition", "Root Finding", "Derivative", "Custom"],
101
- value="Custom",
102
  label="Problem Type"
103
  )
104
  problem_input = gr.Textbox(
105
  label="Enter your math problem",
106
- placeholder="Example: Find the derivative of x^2 + 3x"
107
  )
108
  solve_btn = gr.Button("Solve", variant="primary")
109
 
@@ -123,9 +136,9 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
123
  # Example problems
124
  gr.Examples(
125
  examples=[
126
- ["Find the derivative of x^2 + 3x", "Derivative"],
127
- ["What is the square root of 144?", "Root Finding"],
128
- ["Calculate 235 + 567", "Addition"],
129
  ["\\sin{\\left(x\\right)}", "Derivative"],
130
  ["e^x", "Derivative"],
131
  ["\\frac{1}{x}", "Derivative"],
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
  import torch
4
  import numpy as np
5
  from monitoring import PerformanceMonitor, measure_time
 
13
 
14
  def format_prompt(problem):
15
  """Format the input problem according to the model's expected format"""
16
+ return f"Given a mathematical function, find its derivative.\n\nFunction: {problem}\nThe derivative of this function is:"
17
 
18
  @measure_time
19
  def get_model_response(problem, model_id):
20
  """Get response from a specific model"""
21
  try:
22
+ # Initialize pipeline for each request
23
  pipe = pipeline(
24
  "text-generation",
25
  model=model_id,
26
  torch_dtype=torch.float16,
27
  device_map="auto",
28
+ model_kwargs={"low_cpu_mem_usage": True}
29
  )
30
 
31
  # Format prompt and generate response
32
  prompt = format_prompt(problem)
33
  response = pipe(
34
  prompt,
35
+ max_new_tokens=50, # Shorter response
36
  temperature=0.1,
37
+ do_sample=False, # Deterministic
38
+ num_return_sequences=1,
39
+ return_full_text=False # Only return new text
40
  )[0]["generated_text"]
41
 
42
+ return response.strip()
 
 
43
  except Exception as e:
44
  return f"Error: {str(e)}"
45
 
 
59
  base_response, base_time = get_model_response(problem, BASE_MODEL_ID)
60
  finetuned_response, finetuned_time = get_model_response(problem, FINETUNED_MODEL_ID)
61
 
62
+ # Format responses with steps
63
+ base_output = f"""Solution: {base_response}
64
+
65
+ Let's verify this step by step:
66
+ 1. Starting with f(x) = {problem}
67
+ 2. Applying differentiation rules
68
+ 3. We get f'(x) = {base_response}"""
69
+
70
+ finetuned_output = f"""Solution: {finetuned_response}
71
+
72
+ Let's verify this step by step:
73
+ 1. Starting with f(x) = {problem}
74
+ 2. Applying differentiation rules
75
+ 3. We get f'(x) = {finetuned_response}"""
76
+
77
+ # Record metrics
78
  monitor.record_response_time("base", base_time)
79
  monitor.record_response_time("finetuned", finetuned_time)
 
 
80
  monitor.record_success("base", not base_response.startswith("Error"))
81
  monitor.record_success("finetuned", not finetuned_response.startswith("Error"))
82
 
 
95
  - Base Model: {stats.get('base_success_rate', 0):.1f}%
96
  - Fine-tuned Model: {stats.get('finetuned_success_rate', 0):.1f}%
97
 
98
+ #### Problem Types Used
99
  """
100
  for ptype, percentage in stats.get('problem_type_distribution', {}).items():
101
  stats_display += f"- {ptype}: {percentage:.1f}%\n"
102
 
103
+ return base_output, finetuned_output, stats_display
104
 
105
  # Create Gradio interface
106
  with gr.Blocks(title="Mathematics Problem Solver") as demo:
 
111
  with gr.Column():
112
  problem_type = gr.Dropdown(
113
  choices=["Addition", "Root Finding", "Derivative", "Custom"],
114
+ value="Derivative",
115
  label="Problem Type"
116
  )
117
  problem_input = gr.Textbox(
118
  label="Enter your math problem",
119
+ placeholder="Example: x^2 + 3x"
120
  )
121
  solve_btn = gr.Button("Solve", variant="primary")
122
 
 
136
  # Example problems
137
  gr.Examples(
138
  examples=[
139
+ ["x^2 + 3x", "Derivative"],
140
+ ["144", "Root Finding"],
141
+ ["235 + 567", "Addition"],
142
  ["\\sin{\\left(x\\right)}", "Derivative"],
143
  ["e^x", "Derivative"],
144
  ["\\frac{1}{x}", "Derivative"],