Joash2024 committed
Commit 2d708a8 · 1 Parent(s): d2da9d1

feat: load models on demand with better memory management

Files changed (1)
  1. app.py +94 -86
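This commit stops keeping both models resident in GPU memory for the whole session; each request now loads only the selected model, runs generation, and frees the weights afterwards. Below is a minimal sketch of that load/generate/free pattern, using the same model IDs that appear in app.py; the run_once helper name and the max_new_tokens value are illustrative, not part of the commit.

import gc
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_ID = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
ADAPTER_ID = "Joash2024/Math-SmolLM2-1.7B"

tokenizer = AutoTokenizer.from_pretrained(BASE_ID)
tokenizer.pad_token = tokenizer.eos_token

def run_once(prompt: str, use_adapter: bool) -> str:
    # Hypothetical helper: load only the model needed for this request
    model = AutoModelForCausalLM.from_pretrained(
        BASE_ID, device_map="auto", torch_dtype=torch.float16
    )
    if use_adapter:
        model = PeftModel.from_pretrained(model, ADAPTER_ID)
    model.eval()

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=64, pad_token_id=tokenizer.eos_token_id
        )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Free the weights before the next request
    del model
    gc.collect()
    torch.cuda.empty_cache()
    return text[len(prompt):].strip()

The trade-off is idle memory versus per-request latency: nothing stays loaded between calls, but every call pays the cost of reloading the ~1.7B model from the local cache onto the GPU.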
app.py CHANGED
@@ -5,30 +5,24 @@ from peft import PeftModel
 from monitoring import PerformanceMonitor, measure_time
 
 # Model configurations
-BASE_MODEL = "HuggingFaceTB/SmolLM2-1.7B-Instruct"  # Base model
-ADAPTER_MODEL = "Joash2024/Math-SmolLM2-1.7B"  # Our LoRA adapter
+MODEL_OPTIONS = {
+    "Base Model": {
+        "id": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+        "is_base": True
+    },
+    "Fine-tuned Model": {
+        "id": "Joash2024/Math-SmolLM2-1.7B",
+        "is_base": False
+    }
+}
 
 # Initialize performance monitor
 monitor = PerformanceMonitor()
 
 print("Loading tokenizer...")
-tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
 tokenizer.pad_token = tokenizer.eos_token
 
-print("Loading base model...")
-base_model = AutoModelForCausalLM.from_pretrained(
-    BASE_MODEL,
-    device_map="auto",
-    torch_dtype=torch.float16
-)
-
-print("Loading fine-tuned model...")
-finetuned_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
-
-# Set models to eval mode
-base_model.eval()
-finetuned_model.eval()
-
 def format_prompt(problem: str, problem_type: str) -> str:
     """Format input prompt for the model"""
     if problem_type == "Derivative":
@@ -48,63 +42,80 @@ Function: {problem}
 The derivative is:"""
 
 @measure_time
-def get_model_response(problem: str, problem_type: str, model) -> str:
-    """Generate response from a specific model"""
-    # Format prompt
-    prompt = format_prompt(problem, problem_type)
-
-    # Tokenize
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-    # Generate
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_length=100,
-            num_return_sequences=1,
-            temperature=0.1,
-            do_sample=True,
-            pad_token_id=tokenizer.eos_token_id
-        )
-
-    # Decode and extract response
-    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    response = generated[len(prompt):].strip()
-
-    return response
+def get_model_response(problem: str, problem_type: str, model_info) -> str:
+    """Get response from a specific model"""
+    try:
+        # Load model
+        if model_info["is_base"]:
+            print(f"Loading {model_info['id']}...")
+            model = AutoModelForCausalLM.from_pretrained(
+                model_info["id"],
+                device_map="auto",
+                torch_dtype=torch.float16
+            )
+        else:
+            print("Loading base model for fine-tuned...")
+            base = AutoModelForCausalLM.from_pretrained(
+                "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+                device_map="auto",
+                torch_dtype=torch.float16
+            )
+            print(f"Loading {model_info['id']}...")
+            model = PeftModel.from_pretrained(base, model_info["id"])
+
+        model.eval()
+
+        # Format prompt and generate
+        prompt = format_prompt(problem, problem_type)
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_length=100,
+                num_return_sequences=1,
+                temperature=0.1,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id
+            )
+
+        # Decode and extract response
+        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        response = generated[len(prompt):].strip()
+
+        # Clean up
+        del model
+        if not model_info["is_base"]:
+            del base
+        torch.cuda.empty_cache()
+
+        return response
+    except Exception as e:
+        return f"Error: {str(e)}"
 
-def solve_problem(problem: str, problem_type: str) -> tuple:
-    """Solve a math problem using both models"""
+def solve_problem(problem: str, problem_type: str, model_type: str) -> tuple:
+    """Solve a math problem using selected model"""
     if not problem:
-        return "Please enter a problem", "Please enter a problem", None
+        return "Please enter a problem", None
 
     # Record problem type
     monitor.record_problem_type(problem_type)
 
-    # Get responses from both models with timing
-    base_response, base_time = get_model_response(problem, problem_type, base_model)
-    finetuned_response, finetuned_time = get_model_response(problem, problem_type, finetuned_model)
+    # Get response from selected model
+    model_info = MODEL_OPTIONS[model_type]
+    response, time_taken = get_model_response(problem, problem_type, model_info)
 
-    # Format responses with steps
-    base_output = f"""Solution: {base_response}
-
-Let's verify this step by step:
-1. Starting with f(x) = {problem}
-2. Applying differentiation rules
-3. We get f'(x) = {base_response}"""
-
-    finetuned_output = f"""Solution: {finetuned_response}
+    # Format response with steps
+    output = f"""Solution: {response}
 
 Let's verify this step by step:
 1. Starting with f(x) = {problem}
 2. Applying differentiation rules
-3. We get f'(x) = {finetuned_response}"""
+3. We get f'(x) = {response}"""
 
     # Record metrics
-    monitor.record_response_time("base", base_time)
-    monitor.record_response_time("finetuned", finetuned_time)
-    monitor.record_success("base", not base_response.startswith("Error"))
-    monitor.record_success("finetuned", not finetuned_response.startswith("Error"))
+    monitor.record_response_time(model_type, time_taken)
+    monitor.record_success(model_type, not response.startswith("Error"))
 
     # Get updated statistics
     stats = monitor.get_statistics()
@@ -114,24 +125,22 @@ Let's verify this step by step:
 ### Performance Metrics
 
 #### Response Times (seconds)
-- Base Model: {stats.get('base_avg_response_time', 0):.2f} avg
-- Fine-tuned Model: {stats.get('finetuned_avg_response_time', 0):.2f} avg
+- {model_type}: {stats.get(f'{model_type}_avg_response_time', 0):.2f} avg
 
 #### Success Rates
-- Base Model: {stats.get('base_success_rate', 0):.1f}%
-- Fine-tuned Model: {stats.get('finetuned_success_rate', 0):.1f}%
+- {model_type}: {stats.get(f'{model_type}_success_rate', 0):.1f}%
 
 #### Problem Types Used
 """
     for ptype, percentage in stats.get('problem_type_distribution', {}).items():
         stats_display += f"- {ptype}: {percentage:.1f}%\n"
 
-    return base_output, finetuned_output, stats_display
+    return output, stats_display
 
 # Create Gradio interface
 with gr.Blocks(title="Mathematics Problem Solver") as demo:
     gr.Markdown("# Mathematics Problem Solver")
-    gr.Markdown("Compare solutions between base and fine-tuned models")
+    gr.Markdown("Test our models on mathematical problems")
 
     with gr.Row():
         with gr.Column():
@@ -140,6 +149,11 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
                 value="Derivative",
                 label="Problem Type"
             )
+            model_type = gr.Dropdown(
+                choices=list(MODEL_OPTIONS.keys()),
+                value="Fine-tuned Model",
+                label="Model to Use"
+            )
             problem_input = gr.Textbox(
                 label="Enter your math problem",
                 placeholder="Example: x^2 + 3x"
@@ -147,13 +161,7 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
             solve_btn = gr.Button("Solve", variant="primary")
 
     with gr.Row():
-        with gr.Column():
-            gr.Markdown("### Base Model")
-            base_output = gr.Textbox(label="Base Model Solution", lines=5)
-
-        with gr.Column():
-            gr.Markdown("### Fine-tuned Model")
-            finetuned_output = gr.Textbox(label="Fine-tuned Model Solution", lines=5)
+        solution_output = gr.Textbox(label="Solution", lines=5)
 
     # Performance metrics display
     with gr.Row():
@@ -162,17 +170,17 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
     # Example problems
     gr.Examples(
         examples=[
-            ["x^2 + 3x", "Derivative"],
-            ["144", "Root Finding"],
-            ["235 + 567", "Addition"],
-            ["\\sin{\\left(x\\right)}", "Derivative"],
-            ["e^x", "Derivative"],
-            ["\\frac{1}{x}", "Derivative"],
-            ["x^3 + 2x", "Derivative"],
-            ["\\cos{\\left(x^2\\right)}", "Derivative"]
+            ["x^2 + 3x", "Derivative", "Fine-tuned Model"],
+            ["144", "Root Finding", "Fine-tuned Model"],
+            ["235 + 567", "Addition", "Fine-tuned Model"],
+            ["\\sin{\\left(x\\right)}", "Derivative", "Fine-tuned Model"],
+            ["e^x", "Derivative", "Fine-tuned Model"],
+            ["\\frac{1}{x}", "Derivative", "Fine-tuned Model"],
+            ["x^3 + 2x", "Derivative", "Fine-tuned Model"],
+            ["\\cos{\\left(x^2\\right)}", "Derivative", "Fine-tuned Model"]
         ],
-        inputs=[problem_input, problem_type],
-        outputs=[base_output, finetuned_output, metrics_display],
+        inputs=[problem_input, problem_type, model_type],
+        outputs=[solution_output, metrics_display],
         fn=solve_problem,
         cache_examples=True,
     )
@@ -180,8 +188,8 @@ with gr.Blocks(title="Mathematics Problem Solver") as demo:
     # Connect the interface
     solve_btn.click(
         fn=solve_problem,
-        inputs=[problem_input, problem_type],
-        outputs=[base_output, finetuned_output, metrics_display]
+        inputs=[problem_input, problem_type, model_type],
+        outputs=[solution_output, metrics_display]
     )
 
 if __name__ == "__main__":