geethareddy commited on
Commit
12aab4e
·
verified ·
1 Parent(s): ff01f98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -135
app.py CHANGED
@@ -8,44 +8,50 @@ from dotenv import load_dotenv
8
  # Load environment variables
9
  load_dotenv()
10
 
11
- # Check for required environment variables
12
  required_env_vars = ['SF_USERNAME', 'SF_PASSWORD', 'SF_SECURITY_TOKEN']
13
  missing_vars = [var for var in required_env_vars if not os.getenv(var)]
14
  if missing_vars:
15
  raise EnvironmentError(f"Missing required environment variables: {missing_vars}")
16
 
17
- # Configurable defaults
18
  KPI_FLAG_DEFAULT = os.getenv('KPI_FLAG', 'True') == 'True'
19
  ENGAGEMENT_SCORE_DEFAULT = float(os.getenv('ENGAGEMENT_SCORE', '85.0'))
20
 
21
- # Initialize model and tokenizer with optimized settings
22
  model_name = "distilgpt2"
23
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) # Use fast tokenizer
24
- model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True) # Optimize memory usage
25
 
26
- # Set pad token
27
  if tokenizer.pad_token is None:
28
- tokenizer.pad_token = tokenizer.eos_token or "[PAD]"
29
  tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
30
  model.config.pad_token_id = tokenizer.pad_token_id
31
 
32
- # Prompt template
33
- PROMPT_TEMPLATE = """You are an AI coach for construction supervisors. Based on the following inputs, generate a daily checklist, focus suggestions, and a motivational quote.
 
 
 
 
 
34
  Inputs:
35
  Role: {role}
36
- Project: {project_id}
37
  Milestones: {milestones}
38
- Reflection: {reflection}
39
- Format:
 
40
  Checklist:
41
- - {milestones_list}
 
42
  Suggestions:
43
- - {suggestions_list}
 
44
  Quote:
45
- - Your motivational quote here
46
  """
47
 
48
- # Fetch roles from Salesforce
49
  def get_roles_from_salesforce():
50
  try:
51
  sf = Salesforce(
@@ -55,13 +61,11 @@ def get_roles_from_salesforce():
55
  domain=os.getenv('SF_DOMAIN', 'login')
56
  )
57
  result = sf.query("SELECT Role__c FROM Supervisor__c WHERE Role__c != NULL")
58
- roles = list(set(record['Role__c'] for record in result.get('records', [])))
59
- return roles
60
  except Exception as e:
61
  print(f"⚠️ Error fetching roles: {e}")
62
  return ["Site Manager", "Safety Officer", "Project Lead"]
63
 
64
- # Fetch supervisor names by role
65
  def get_supervisor_name_by_role(role):
66
  try:
67
  sf = Salesforce(
@@ -72,14 +76,11 @@ def get_supervisor_name_by_role(role):
72
  )
73
  role = role.replace("'", "\\'")
74
  result = sf.query(f"SELECT Name FROM Supervisor__c WHERE Role__c = '{role}'")
75
- if result['totalSize'] == 0:
76
- return []
77
  return [record['Name'] for record in result['records']]
78
  except Exception as e:
79
  print(f"⚠️ Error fetching supervisor names: {e}")
80
  return []
81
 
82
- # Fetch project for supervisor
83
  def get_projects_for_supervisor(supervisor_name):
84
  try:
85
  sf = Salesforce(
@@ -94,76 +95,11 @@ def get_projects_for_supervisor(supervisor_name):
94
  return ""
95
  supervisor_id = supervisor_result['records'][0]['Id']
96
  project_result = sf.query(f"SELECT Name FROM Project__c WHERE Supervisor_ID__c = '{supervisor_id}' LIMIT 1")
97
- if project_result['totalSize'] == 0:
98
- return ""
99
- return project_result['records'][0]['Name']
100
  except Exception as e:
101
  print(f"⚠️ Error fetching project: {e}")
102
  return ""
103
 
104
- # Generate AI outputs
105
- def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
106
- if not all([role, supervisor_name, project_id, milestones, reflection]):
107
- return "Error: All fields are required.", "", ""
108
-
109
- milestones_list = "\n- ".join([m.strip() for m in milestones.split(",")])
110
- suggestions_list = ""
111
- if "delays" in reflection.lower():
112
- suggestions_list = "- Adjust timelines for delays.\n- Inform stakeholders."
113
- elif "weather" in reflection.lower():
114
- suggestions_list = "- Ensure rain gear availability.\n- Monitor weather updates."
115
- elif "equipment" in reflection.lower():
116
- suggestions_list = "- Inspect equipment.\n- Schedule maintenance."
117
-
118
- prompt = PROMPT_TEMPLATE.format(
119
- role=role,
120
- project_id=project_id,
121
- milestones=milestones,
122
- reflection=reflection,
123
- milestones_list=milestones_list,
124
- suggestions_list=suggestions_list
125
- )
126
-
127
- # Tokenize with optimized settings
128
- inputs = tokenizer(prompt, return_tensors="pt", max_length=256, truncation=True, padding=True)
129
-
130
- # Faster generation with adjusted parameters
131
- try:
132
- with torch.no_grad():
133
- outputs = model.generate(
134
- inputs['input_ids'],
135
- max_length=512, # Reduced for speed
136
- num_return_sequences=1,
137
- no_repeat_ngram_size=2,
138
- do_sample=True,
139
- top_p=0.85, # Slightly tighter for faster convergence
140
- temperature=0.7, # Lower for more deterministic output
141
- pad_token_id=tokenizer.pad_token_id
142
- )
143
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
144
- except Exception as e:
145
- print(f"⚠️ Error during generation: {e}")
146
- return "Error: Failed to generate outputs.", "", ""
147
-
148
- # Extract sections
149
- def extract_section(text, start_marker, end_marker):
150
- start = text.find(start_marker)
151
- if start == -1:
152
- return "Not found"
153
- start += len(start_marker)
154
- end = text.find(end_marker, start) if end_marker else len(text)
155
- return text[start:end].strip()
156
-
157
- checklist = extract_section(generated_text, "Checklist:\n", "Suggestions:")
158
- suggestions = extract_section(generated_text, "Suggestions:\n", "Quote:")
159
- quote = extract_section(generated_text, "Quote:\n", None)
160
-
161
- # Save to Salesforce
162
- save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name)
163
-
164
- return checklist, suggestions, quote
165
-
166
- # Check if Salesforce field exists
167
  def field_exists(sf, object_name, field_name):
168
  try:
169
  obj_desc = getattr(sf, object_name).describe()
@@ -172,7 +108,6 @@ def field_exists(sf, object_name, field_name):
172
  print(f"⚠️ Error checking field {field_name}: {e}")
173
  return False
174
 
175
- # Save to Salesforce
176
  def save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name):
177
  try:
178
  sf = Salesforce(
@@ -181,89 +116,113 @@ def save_to_salesforce(role, project_id, milestones, reflection, checklist, sugg
181
  security_token=os.getenv('SF_SECURITY_TOKEN'),
182
  domain=os.getenv('SF_DOMAIN', 'login')
183
  )
184
- supervisor_name = supervisor_name.replace("'", "\\'")
185
- project_id = project_id.replace("'", "\\'")
186
  supervisor_result = sf.query(f"SELECT Id FROM Supervisor__c WHERE Name = '{supervisor_name}' LIMIT 1")
187
  if supervisor_result['totalSize'] == 0:
188
- print(f"❌ No supervisor found: {supervisor_name}")
189
  return
190
  supervisor_id = supervisor_result['records'][0]['Id']
191
  project_result = sf.query(f"SELECT Id FROM Project__c WHERE Name = '{project_id}' LIMIT 1")
192
  if project_result['totalSize'] == 0:
193
- print(f"❌ No project found: {project_id}")
194
  return
195
  project_record_id = project_result['records'][0]['Id']
196
 
197
- MAX_TEXT_LENGTH = 255
198
- checklist = checklist[:MAX_TEXT_LENGTH] if checklist else ""
199
- suggestions = suggestions[:MAX_TEXT_LENGTH] if suggestions else ""
200
- reflection = reflection[:MAX_TEXT_LENGTH] if reflection else ""
201
-
202
  data = {
203
  'Supervisor_ID__c': supervisor_id,
204
  'Project_ID__c': project_record_id,
205
- 'Daily_Checklist__c': checklist,
206
- 'Suggested_Tips__c': suggestions,
207
- 'Reflection_Log__c': reflection,
208
  'KPI_Flag__c': KPI_FLAG_DEFAULT,
209
  'Engagement_Score__c': ENGAGEMENT_SCORE_DEFAULT
210
  }
211
 
212
  if field_exists(sf, 'Supervisor_AI_Coaching__c', 'Milestones_KPIs__c'):
213
- milestones = milestones[:MAX_TEXT_LENGTH] if milestones else ""
214
- data['Milestones_KPIs__c'] = milestones
215
 
216
  sf.Supervisor_AI_Coaching__c.create(data)
217
  print("✅ Record created in Salesforce.")
218
  except Exception as e:
219
  print(f"❌ Error saving to Salesforce: {e}")
220
 
221
- # Create Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  def create_interface():
223
  roles = get_roles_from_salesforce()
224
  with gr.Blocks(theme="soft") as demo:
225
- gr.Markdown("# 🏗️ Construction Supervisor AI Coach")
226
- gr.Markdown("Enter details for daily checklist, suggestions, and quote.")
227
  with gr.Row():
228
  role = gr.Dropdown(choices=roles, label="Role")
229
  supervisor_name = gr.Dropdown(choices=[], label="Supervisor Name")
230
  project_id = gr.Textbox(label="Project ID", interactive=False)
231
- milestones = gr.Textbox(label="Milestones (comma-separated KPIs)")
232
- reflection = gr.Textbox(label="Reflection Log", lines=4)
 
 
233
  with gr.Row():
234
- submit = gr.Button("Generate", variant="primary")
235
  clear = gr.Button("Clear")
236
- refresh_btn = gr.Button("🔄 Refresh Roles")
 
237
  checklist_output = gr.Textbox(label="✅ Daily Checklist")
238
  suggestions_output = gr.Textbox(label="💡 Focus Suggestions")
239
  quote_output = gr.Textbox(label="✨ Motivational Quote")
240
 
241
- role.change(
242
- fn=lambda r: gr.update(choices=get_supervisor_name_by_role(r)),
243
- inputs=[role],
244
- outputs=[supervisor_name]
245
- )
246
- supervisor_name.change(
247
- fn=get_projects_for_supervisor,
248
- inputs=[supervisor_name],
249
- outputs=[project_id]
250
- )
251
- submit.click(
252
- fn=generate_outputs,
253
- inputs=[role, supervisor_name, project_id, milestones, reflection],
254
- outputs=[checklist_output, suggestions_output, quote_output]
255
- )
256
- clear.click(
257
- fn=lambda: ("", "", "", "", ""),
258
- inputs=None,
259
- outputs=[role, supervisor_name, project_id, milestones, reflection]
260
- )
261
- refresh_btn.click(
262
- fn=lambda: gr.update(choices=get_roles_from_salesforce()),
263
- outputs=role
264
- )
265
  return demo
266
 
267
  if __name__ == "__main__":
268
  app = create_interface()
269
- app.launch()
 
8
  # Load environment variables
9
  load_dotenv()
10
 
11
+ # Required env vars check
12
  required_env_vars = ['SF_USERNAME', 'SF_PASSWORD', 'SF_SECURITY_TOKEN']
13
  missing_vars = [var for var in required_env_vars if not os.getenv(var)]
14
  if missing_vars:
15
  raise EnvironmentError(f"Missing required environment variables: {missing_vars}")
16
 
17
+ # Defaults
18
  KPI_FLAG_DEFAULT = os.getenv('KPI_FLAG', 'True') == 'True'
19
  ENGAGEMENT_SCORE_DEFAULT = float(os.getenv('ENGAGEMENT_SCORE', '85.0'))
20
 
21
+ # Load model and tokenizer
22
  model_name = "distilgpt2"
23
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
24
+ model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
25
 
 
26
  if tokenizer.pad_token is None:
27
+ tokenizer.pad_token = tokenizer.eos_token
28
  tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
29
  model.config.pad_token_id = tokenizer.pad_token_id
30
 
31
+ # Better Prompt
32
+ PROMPT_TEMPLATE = """You are an AI coach for construction supervisors. Given the role, project, milestones, and a reflection log, generate:
33
+
34
+ 1. A Daily Checklist based on the milestones and role.
35
+ 2. Focus Suggestions based on concerns or keywords in the reflection log.
36
+ 3. A Motivational Quote for the day.
37
+
38
  Inputs:
39
  Role: {role}
40
+ Project ID: {project_id}
41
  Milestones: {milestones}
42
+ Reflection Log: {reflection}
43
+
44
+ Output Format:
45
  Checklist:
46
+ -
47
+
48
  Suggestions:
49
+ -
50
+
51
  Quote:
 
52
  """
53
 
54
+ # Salesforce Functions
55
  def get_roles_from_salesforce():
56
  try:
57
  sf = Salesforce(
 
61
  domain=os.getenv('SF_DOMAIN', 'login')
62
  )
63
  result = sf.query("SELECT Role__c FROM Supervisor__c WHERE Role__c != NULL")
64
+ return list(set(record['Role__c'] for record in result['records']))
 
65
  except Exception as e:
66
  print(f"⚠️ Error fetching roles: {e}")
67
  return ["Site Manager", "Safety Officer", "Project Lead"]
68
 
 
69
  def get_supervisor_name_by_role(role):
70
  try:
71
  sf = Salesforce(
 
76
  )
77
  role = role.replace("'", "\\'")
78
  result = sf.query(f"SELECT Name FROM Supervisor__c WHERE Role__c = '{role}'")
 
 
79
  return [record['Name'] for record in result['records']]
80
  except Exception as e:
81
  print(f"⚠️ Error fetching supervisor names: {e}")
82
  return []
83
 
 
84
  def get_projects_for_supervisor(supervisor_name):
85
  try:
86
  sf = Salesforce(
 
95
  return ""
96
  supervisor_id = supervisor_result['records'][0]['Id']
97
  project_result = sf.query(f"SELECT Name FROM Project__c WHERE Supervisor_ID__c = '{supervisor_id}' LIMIT 1")
98
+ return project_result['records'][0]['Name'] if project_result['totalSize'] > 0 else ""
 
 
99
  except Exception as e:
100
  print(f"⚠️ Error fetching project: {e}")
101
  return ""
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def field_exists(sf, object_name, field_name):
104
  try:
105
  obj_desc = getattr(sf, object_name).describe()
 
108
  print(f"⚠️ Error checking field {field_name}: {e}")
109
  return False
110
 
 
111
  def save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name):
112
  try:
113
  sf = Salesforce(
 
116
  security_token=os.getenv('SF_SECURITY_TOKEN'),
117
  domain=os.getenv('SF_DOMAIN', 'login')
118
  )
 
 
119
  supervisor_result = sf.query(f"SELECT Id FROM Supervisor__c WHERE Name = '{supervisor_name}' LIMIT 1")
120
  if supervisor_result['totalSize'] == 0:
121
+ print("❌ No supervisor found.")
122
  return
123
  supervisor_id = supervisor_result['records'][0]['Id']
124
  project_result = sf.query(f"SELECT Id FROM Project__c WHERE Name = '{project_id}' LIMIT 1")
125
  if project_result['totalSize'] == 0:
126
+ print("❌ No project found.")
127
  return
128
  project_record_id = project_result['records'][0]['Id']
129
 
 
 
 
 
 
130
  data = {
131
  'Supervisor_ID__c': supervisor_id,
132
  'Project_ID__c': project_record_id,
133
+ 'Daily_Checklist__c': checklist[:255],
134
+ 'Suggested_Tips__c': suggestions[:255],
135
+ 'Reflection_Log__c': reflection[:255],
136
  'KPI_Flag__c': KPI_FLAG_DEFAULT,
137
  'Engagement_Score__c': ENGAGEMENT_SCORE_DEFAULT
138
  }
139
 
140
  if field_exists(sf, 'Supervisor_AI_Coaching__c', 'Milestones_KPIs__c'):
141
+ data['Milestones_KPIs__c'] = milestones[:255]
 
142
 
143
  sf.Supervisor_AI_Coaching__c.create(data)
144
  print("✅ Record created in Salesforce.")
145
  except Exception as e:
146
  print(f"❌ Error saving to Salesforce: {e}")
147
 
148
+ # Generate Function
149
+ def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
150
+ if not all([role, supervisor_name, project_id, milestones, reflection]):
151
+ return "❗ Please fill all fields.", "", ""
152
+
153
+ prompt = PROMPT_TEMPLATE.format(
154
+ role=role,
155
+ project_id=project_id,
156
+ milestones=milestones,
157
+ reflection=reflection
158
+ )
159
+
160
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
161
+ try:
162
+ with torch.no_grad():
163
+ outputs = model.generate(
164
+ inputs['input_ids'],
165
+ max_new_tokens=150,
166
+ no_repeat_ngram_size=2,
167
+ do_sample=True,
168
+ top_p=0.9,
169
+ temperature=0.8,
170
+ pad_token_id=tokenizer.pad_token_id
171
+ )
172
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
173
+ except Exception as e:
174
+ print(f"⚠️ Generation error: {e}")
175
+ return "Error during text generation", "", ""
176
+
177
+ def extract_between(text, start, end):
178
+ s = text.find(start)
179
+ e = text.find(end, s) if end else len(text)
180
+ return text[s + len(start):e].strip() if s != -1 else "Not found"
181
+
182
+ checklist = extract_between(result, "Checklist:\n", "Suggestions:")
183
+ suggestions = extract_between(result, "Suggestions:\n", "Quote:")
184
+ quote = extract_between(result, "Quote:\n", None)
185
+
186
+ save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name)
187
+ return checklist, suggestions, quote
188
+
189
+ # Interface
190
  def create_interface():
191
  roles = get_roles_from_salesforce()
192
  with gr.Blocks(theme="soft") as demo:
193
+ gr.Markdown("## 🧠 AI-Powered Supervisor Assistant")
194
+
195
  with gr.Row():
196
  role = gr.Dropdown(choices=roles, label="Role")
197
  supervisor_name = gr.Dropdown(choices=[], label="Supervisor Name")
198
  project_id = gr.Textbox(label="Project ID", interactive=False)
199
+
200
+ milestones = gr.Textbox(label="Milestones (comma-separated KPIs)", placeholder="E.g. Safety training, daily inspection")
201
+ reflection = gr.Textbox(label="Reflection Log", lines=4, placeholder="Any concerns, delays, updates...")
202
+
203
  with gr.Row():
204
+ generate = gr.Button("Generate")
205
  clear = gr.Button("Clear")
206
+ refresh = gr.Button("🔄 Refresh Roles")
207
+
208
  checklist_output = gr.Textbox(label="✅ Daily Checklist")
209
  suggestions_output = gr.Textbox(label="💡 Focus Suggestions")
210
  quote_output = gr.Textbox(label="✨ Motivational Quote")
211
 
212
+ role.change(fn=lambda r: gr.update(choices=get_supervisor_name_by_role(r)), inputs=role, outputs=supervisor_name)
213
+ supervisor_name.change(fn=get_projects_for_supervisor, inputs=supervisor_name, outputs=project_id)
214
+
215
+ generate.click(fn=generate_outputs,
216
+ inputs=[role, supervisor_name, project_id, milestones, reflection],
217
+ outputs=[checklist_output, suggestions_output, quote_output])
218
+
219
+ clear.click(fn=lambda: ("", "", "", "", ""), inputs=None,
220
+ outputs=[role, supervisor_name, project_id, milestones, reflection])
221
+
222
+ refresh.click(fn=lambda: gr.update(choices=get_roles_from_salesforce()), outputs=role)
223
+
 
 
 
 
 
 
 
 
 
 
 
 
224
  return demo
225
 
226
  if __name__ == "__main__":
227
  app = create_interface()
228
+ app.launch()