geethareddy commited on
Commit
35d271b
·
verified ·
1 Parent(s): af2003c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -11
app.py CHANGED
@@ -19,7 +19,7 @@ KPI_FLAG_DEFAULT = os.getenv('KPI_FLAG', 'True') == 'True'
19
  ENGAGEMENT_SCORE_DEFAULT = float(os.getenv('ENGAGEMENT_SCORE', '85.0'))
20
 
21
  # Load model and tokenizer (Updated to use distilgpt2)
22
- model_name = "distilgpt2" # Using distilgpt2
23
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
24
  model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
25
 
@@ -145,14 +145,12 @@ def save_to_salesforce(role, project_id, milestones, reflection, checklist, sugg
145
  # Fallback generation for checklist and suggestions
146
  def generate_fallback_checklist(role, milestones):
147
  checklist_items = []
148
- # Base checklist on role
149
  if role.lower() == "safety officer":
150
  checklist_items.append("- Conduct morning safety briefing")
151
  checklist_items.append("- Review incident reports from previous day")
152
 
153
- # If milestones are provided, add them
154
- if milestones and milestones.strip() and milestones != "A milestone is a numbered marker placed on a route such as a road, railway line, canal or boundary.":
155
- kpis = [kpi.strip() for kpi in milestones.split(",")] # Fixed syntax error here
156
  for kpi in kpis:
157
  checklist_items.append(f"- Ensure progress on {kpi}")
158
  else:
@@ -162,11 +160,10 @@ def generate_fallback_checklist(role, milestones):
162
 
163
  def generate_fallback_suggestions(reflection):
164
  suggestions_items = []
165
- # Base suggestions on reflection log keywords
166
  reflection_lower = reflection.lower()
167
  if "student" in reflection_lower or "learning" in reflection_lower:
168
  suggestions_items.append("- Ensure students are logging incidents consistently")
169
- suggestions_items.append("- Provide guidance on timely incident recording") # Fixed typo "Provide Provide"
170
  if "incident" in reflection_lower:
171
  suggestions_items.append("- Follow up on reported incidents with corrective actions")
172
 
@@ -193,7 +190,7 @@ def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
193
  with torch.no_grad():
194
  outputs = model.generate(
195
  inputs['input_ids'],
196
- max_new_tokens=100,
197
  no_repeat_ngram_size=2,
198
  do_sample=True,
199
  top_p=0.9,
@@ -203,7 +200,6 @@ def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
203
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
204
  except Exception as e:
205
  print(f"⚠️ Generation error: {e}")
206
- # Use fallback if generation fails
207
  checklist = generate_fallback_checklist(role, milestones)
208
  suggestions = generate_fallback_suggestions(reflection)
209
  save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, supervisor_name)
@@ -217,7 +213,6 @@ def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
217
  checklist = extract_between(result, "Checklist:\n", "Suggestions:")
218
  suggestions = extract_between(result, "Suggestions:\n", None)
219
 
220
- # Use fallback if checklist or suggestions are empty
221
  if not checklist.strip():
222
  checklist = generate_fallback_checklist(role, milestones)
223
  if not suggestions.strip():
@@ -274,4 +269,3 @@ def create_interface():
274
  if __name__ == "__main__":
275
  app = create_interface()
276
  app.launch()
277
-
 
19
  ENGAGEMENT_SCORE_DEFAULT = float(os.getenv('ENGAGEMENT_SCORE', '85.0'))
20
 
21
  # Load model and tokenizer (Updated to use distilgpt2)
22
+ model_name = "distilgpt2" # Using distilgpt2 for faster response
23
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
24
  model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
25
 
 
145
  # Fallback generation for checklist and suggestions
146
  def generate_fallback_checklist(role, milestones):
147
  checklist_items = []
 
148
  if role.lower() == "safety officer":
149
  checklist_items.append("- Conduct morning safety briefing")
150
  checklist_items.append("- Review incident reports from previous day")
151
 
152
+ if milestones and milestones.strip():
153
+ kpis = [kpi.strip() for kpi in milestones.split(",")]
 
154
  for kpi in kpis:
155
  checklist_items.append(f"- Ensure progress on {kpi}")
156
  else:
 
160
 
161
  def generate_fallback_suggestions(reflection):
162
  suggestions_items = []
 
163
  reflection_lower = reflection.lower()
164
  if "student" in reflection_lower or "learning" in reflection_lower:
165
  suggestions_items.append("- Ensure students are logging incidents consistently")
166
+ suggestions_items.append("- Provide guidance on timely incident recording")
167
  if "incident" in reflection_lower:
168
  suggestions_items.append("- Follow up on reported incidents with corrective actions")
169
 
 
190
  with torch.no_grad():
191
  outputs = model.generate(
192
  inputs['input_ids'],
193
+ max_new_tokens=100, # Limit output length to avoid delays
194
  no_repeat_ngram_size=2,
195
  do_sample=True,
196
  top_p=0.9,
 
200
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
201
  except Exception as e:
202
  print(f"⚠️ Generation error: {e}")
 
203
  checklist = generate_fallback_checklist(role, milestones)
204
  suggestions = generate_fallback_suggestions(reflection)
205
  save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, supervisor_name)
 
213
  checklist = extract_between(result, "Checklist:\n", "Suggestions:")
214
  suggestions = extract_between(result, "Suggestions:\n", None)
215
 
 
216
  if not checklist.strip():
217
  checklist = generate_fallback_checklist(role, milestones)
218
  if not suggestions.strip():
 
269
  if __name__ == "__main__":
270
  app = create_interface()
271
  app.launch()