geethareddy commited on
Commit
6a44015
·
verified ·
1 Parent(s): 903520a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -54
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from simple_salesforce import Salesforce
5
  import os
6
  from dotenv import load_dotenv
@@ -8,28 +8,32 @@ from dotenv import load_dotenv
8
  # Load environment variables
9
  load_dotenv()
10
 
11
- # Check for required environment variables
12
  required_env_vars = ['SF_USERNAME', 'SF_PASSWORD', 'SF_SECURITY_TOKEN']
13
  missing_vars = [var for var in required_env_vars if not os.getenv(var)]
14
  if missing_vars:
15
  raise EnvironmentError(f"Missing required environment variables: {missing_vars}")
16
 
17
- # Configurable defaults
18
  KPI_FLAG_DEFAULT = os.getenv('KPI_FLAG', 'True') == 'True'
19
  ENGAGEMENT_SCORE_DEFAULT = float(os.getenv('ENGAGEMENT_SCORE', '85.0'))
20
 
21
- # Use instruction-tuned model
22
- model_name = "mistralai/Mistral-7B-Instruct-v0.1"
23
- tokenizer = AutoTokenizer.from_pretrained(model_name)
24
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
25
 
26
- # Prompt template
27
- PROMPT_TEMPLATE = """
28
- You are an AI coach for construction supervisors. Based on the information below, generate:
 
29
 
30
- 1. A Daily Checklist based on role and milestones.
31
- 2. Focus Suggestions based on reflection.
32
- 3. A Motivational Quote.
 
 
 
33
 
34
  Inputs:
35
  Role: {role}
@@ -39,15 +43,15 @@ Reflection Log: {reflection}
39
 
40
  Output Format:
41
  Checklist:
42
- -
43
 
44
  Suggestions:
45
- -
46
 
47
  Quote:
48
  """
49
 
50
- # Salesforce functions
51
  def get_roles_from_salesforce():
52
  try:
53
  sf = Salesforce(
@@ -70,6 +74,7 @@ def get_supervisor_name_by_role(role):
70
  security_token=os.getenv('SF_SECURITY_TOKEN'),
71
  domain=os.getenv('SF_DOMAIN', 'login')
72
  )
 
73
  result = sf.query(f"SELECT Name FROM Supervisor__c WHERE Role__c = '{role}'")
74
  return [record['Name'] for record in result['records']]
75
  except Exception as e:
@@ -84,22 +89,23 @@ def get_projects_for_supervisor(supervisor_name):
84
  security_token=os.getenv('SF_SECURITY_TOKEN'),
85
  domain=os.getenv('SF_DOMAIN', 'login')
86
  )
87
- result = sf.query(f"SELECT Id FROM Supervisor__c WHERE Name = '{supervisor_name}' LIMIT 1")
88
- if result['totalSize'] == 0:
 
89
  return ""
90
- supervisor_id = result['records'][0]['Id']
91
- project = sf.query(f"SELECT Name FROM Project__c WHERE Supervisor_ID__c = '{supervisor_id}' LIMIT 1")
92
- return project['records'][0]['Name'] if project['totalSize'] else ""
93
  except Exception as e:
94
  print(f"⚠️ Error fetching project: {e}")
95
  return ""
96
 
97
  def field_exists(sf, object_name, field_name):
98
  try:
99
- desc = getattr(sf, object_name).describe()
100
- return field_name in [f['name'] for f in desc['fields']]
101
  except Exception as e:
102
- print(f"⚠️ Error checking field: {e}")
103
  return False
104
 
105
  def save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name):
@@ -110,37 +116,39 @@ def save_to_salesforce(role, project_id, milestones, reflection, checklist, sugg
110
  security_token=os.getenv('SF_SECURITY_TOKEN'),
111
  domain=os.getenv('SF_DOMAIN', 'login')
112
  )
113
- supervisor = sf.query(f"SELECT Id FROM Supervisor__c WHERE Name = '{supervisor_name}' LIMIT 1")
114
- if not supervisor['totalSize']:
 
115
  return
116
- supervisor_id = supervisor['records'][0]['Id']
117
-
118
- project = sf.query(f"SELECT Id FROM Project__c WHERE Name = '{project_id}' LIMIT 1")
119
- if not project['totalSize']:
120
  return
121
- project_id_val = project['records'][0]['Id']
122
 
123
  data = {
124
  'Supervisor_ID__c': supervisor_id,
125
- 'Project_ID__c': project_id_val,
126
  'Daily_Checklist__c': checklist[:255],
127
  'Suggested_Tips__c': suggestions[:255],
128
  'Reflection_Log__c': reflection[:255],
129
  'KPI_Flag__c': KPI_FLAG_DEFAULT,
130
  'Engagement_Score__c': ENGAGEMENT_SCORE_DEFAULT
131
  }
 
132
  if field_exists(sf, 'Supervisor_AI_Coaching__c', 'Milestones_KPIs__c'):
133
  data['Milestones_KPIs__c'] = milestones[:255]
134
 
135
  sf.Supervisor_AI_Coaching__c.create(data)
136
- print("✅ Record saved.")
137
  except Exception as e:
138
  print(f"❌ Error saving to Salesforce: {e}")
139
 
140
- # Generation logic
141
  def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
142
  if not all([role, supervisor_name, project_id, milestones, reflection]):
143
- return "Please fill all fields.", "", ""
144
 
145
  prompt = PROMPT_TEMPLATE.format(
146
  role=role,
@@ -149,49 +157,48 @@ def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
149
  reflection=reflection
150
  )
151
 
152
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
153
  try:
154
  with torch.no_grad():
155
- output_ids = model.generate(
156
  inputs['input_ids'],
157
- max_new_tokens=256,
 
158
  do_sample=True,
159
  top_p=0.9,
160
- temperature=0.7,
161
  pad_token_id=tokenizer.pad_token_id
162
  )
163
- result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
164
  except Exception as e:
165
- return f"Error: {e}", "", ""
 
166
 
167
- def extract(text, start, end=None):
168
  s = text.find(start)
169
- if s == -1:
170
- return ""
171
- s += len(start)
172
  e = text.find(end, s) if end else len(text)
173
- return text[s:e].strip()
174
 
175
- checklist = extract(result, "Checklist:\n", "Suggestions:")
176
- suggestions = extract(result, "Suggestions:\n", "Quote:")
177
- quote = extract(result, "Quote:\n")
178
 
179
  save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name)
180
  return checklist, suggestions, quote
181
 
182
- # Gradio UI
183
  def create_interface():
184
  roles = get_roles_from_salesforce()
185
  with gr.Blocks(theme="soft") as demo:
186
- gr.Markdown("## 👷‍♂️ AI Coaching Assistant for Construction Supervisors")
187
 
188
  with gr.Row():
189
  role = gr.Dropdown(choices=roles, label="Role")
190
  supervisor_name = gr.Dropdown(choices=[], label="Supervisor Name")
191
  project_id = gr.Textbox(label="Project ID", interactive=False)
192
 
193
- milestones = gr.Textbox(label="Milestones (comma-separated)", placeholder="e.g., Safety check, Equipment inspection")
194
- reflection = gr.Textbox(label="Reflection Log", lines=4, placeholder="Any thoughts, challenges, delays...")
195
 
196
  with gr.Row():
197
  generate = gr.Button("Generate")
@@ -209,7 +216,9 @@ def create_interface():
209
  inputs=[role, supervisor_name, project_id, milestones, reflection],
210
  outputs=[checklist_output, suggestions_output, quote_output])
211
 
212
- clear.click(fn=lambda: ("", "", "", "", ""), outputs=[role, supervisor_name, project_id, milestones, reflection])
 
 
213
  refresh.click(fn=lambda: gr.update(choices=get_roles_from_salesforce()), outputs=role)
214
 
215
  return demo
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from simple_salesforce import Salesforce
5
  import os
6
  from dotenv import load_dotenv
 
8
  # Load environment variables
9
  load_dotenv()
10
 
11
+ # Required env vars check
12
  required_env_vars = ['SF_USERNAME', 'SF_PASSWORD', 'SF_SECURITY_TOKEN']
13
  missing_vars = [var for var in required_env_vars if not os.getenv(var)]
14
  if missing_vars:
15
  raise EnvironmentError(f"Missing required environment variables: {missing_vars}")
16
 
17
+ # Defaults
18
  KPI_FLAG_DEFAULT = os.getenv('KPI_FLAG', 'True') == 'True'
19
  ENGAGEMENT_SCORE_DEFAULT = float(os.getenv('ENGAGEMENT_SCORE', '85.0'))
20
 
21
+ # Load model and tokenizer
22
+ model_name = "distilgpt2"
23
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
24
+ model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
25
 
26
+ if tokenizer.pad_token is None:
27
+ tokenizer.pad_token = tokenizer.eos_token
28
+ tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
29
+ model.config.pad_token_id = tokenizer.pad_token_id
30
 
31
+ # Better Prompt
32
+ PROMPT_TEMPLATE = """You are an AI coach for construction supervisors. Given the role, project, milestones, and a reflection log, generate:
33
+
34
+ 1. A Daily Checklist based on the milestones and role.
35
+ 2. Focus Suggestions based on concerns or keywords in the reflection log.
36
+ 3. A Motivational Quote for the day.
37
 
38
  Inputs:
39
  Role: {role}
 
43
 
44
  Output Format:
45
  Checklist:
46
+ -
47
 
48
  Suggestions:
49
+ -
50
 
51
  Quote:
52
  """
53
 
54
+ # Salesforce Functions
55
  def get_roles_from_salesforce():
56
  try:
57
  sf = Salesforce(
 
74
  security_token=os.getenv('SF_SECURITY_TOKEN'),
75
  domain=os.getenv('SF_DOMAIN', 'login')
76
  )
77
+ role = role.replace("'", "\\'")
78
  result = sf.query(f"SELECT Name FROM Supervisor__c WHERE Role__c = '{role}'")
79
  return [record['Name'] for record in result['records']]
80
  except Exception as e:
 
89
  security_token=os.getenv('SF_SECURITY_TOKEN'),
90
  domain=os.getenv('SF_DOMAIN', 'login')
91
  )
92
+ supervisor_name = supervisor_name.replace("'", "\\'")
93
+ supervisor_result = sf.query(f"SELECT Id FROM Supervisor__c WHERE Name = '{supervisor_name}' LIMIT 1")
94
+ if supervisor_result['totalSize'] == 0:
95
  return ""
96
+ supervisor_id = supervisor_result['records'][0]['Id']
97
+ project_result = sf.query(f"SELECT Name FROM Project__c WHERE Supervisor_ID__c = '{supervisor_id}' LIMIT 1")
98
+ return project_result['records'][0]['Name'] if project_result['totalSize'] > 0 else ""
99
  except Exception as e:
100
  print(f"⚠️ Error fetching project: {e}")
101
  return ""
102
 
103
  def field_exists(sf, object_name, field_name):
104
  try:
105
+ obj_desc = getattr(sf, object_name).describe()
106
+ return field_name in [field['name'] for field in obj_desc['fields']]
107
  except Exception as e:
108
+ print(f"⚠️ Error checking field {field_name}: {e}")
109
  return False
110
 
111
  def save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name):
 
116
  security_token=os.getenv('SF_SECURITY_TOKEN'),
117
  domain=os.getenv('SF_DOMAIN', 'login')
118
  )
119
+ supervisor_result = sf.query(f"SELECT Id FROM Supervisor__c WHERE Name = '{supervisor_name}' LIMIT 1")
120
+ if supervisor_result['totalSize'] == 0:
121
+ print("❌ No supervisor found.")
122
  return
123
+ supervisor_id = supervisor_result['records'][0]['Id']
124
+ project_result = sf.query(f"SELECT Id FROM Project__c WHERE Name = '{project_id}' LIMIT 1")
125
+ if project_result['totalSize'] == 0:
126
+ print("❌ No project found.")
127
  return
128
+ project_record_id = project_result['records'][0]['Id']
129
 
130
  data = {
131
  'Supervisor_ID__c': supervisor_id,
132
+ 'Project_ID__c': project_record_id,
133
  'Daily_Checklist__c': checklist[:255],
134
  'Suggested_Tips__c': suggestions[:255],
135
  'Reflection_Log__c': reflection[:255],
136
  'KPI_Flag__c': KPI_FLAG_DEFAULT,
137
  'Engagement_Score__c': ENGAGEMENT_SCORE_DEFAULT
138
  }
139
+
140
  if field_exists(sf, 'Supervisor_AI_Coaching__c', 'Milestones_KPIs__c'):
141
  data['Milestones_KPIs__c'] = milestones[:255]
142
 
143
  sf.Supervisor_AI_Coaching__c.create(data)
144
+ print("✅ Record created in Salesforce.")
145
  except Exception as e:
146
  print(f"❌ Error saving to Salesforce: {e}")
147
 
148
+ # Generate Function
149
  def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
150
  if not all([role, supervisor_name, project_id, milestones, reflection]):
151
+ return "Please fill all fields.", "", ""
152
 
153
  prompt = PROMPT_TEMPLATE.format(
154
  role=role,
 
157
  reflection=reflection
158
  )
159
 
160
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
161
  try:
162
  with torch.no_grad():
163
+ outputs = model.generate(
164
  inputs['input_ids'],
165
+ max_new_tokens=150,
166
+ no_repeat_ngram_size=2,
167
  do_sample=True,
168
  top_p=0.9,
169
+ temperature=0.8,
170
  pad_token_id=tokenizer.pad_token_id
171
  )
172
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
173
  except Exception as e:
174
+ print(f"⚠️ Generation error: {e}")
175
+ return "Error during text generation", "", ""
176
 
177
+ def extract_between(text, start, end):
178
  s = text.find(start)
 
 
 
179
  e = text.find(end, s) if end else len(text)
180
+ return text[s + len(start):e].strip() if s != -1 else "Not found"
181
 
182
+ checklist = extract_between(result, "Checklist:\n", "Suggestions:")
183
+ suggestions = extract_between(result, "Suggestions:\n", "Quote:")
184
+ quote = extract_between(result, "Quote:\n", None)
185
 
186
  save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name)
187
  return checklist, suggestions, quote
188
 
189
+ # Interface
190
  def create_interface():
191
  roles = get_roles_from_salesforce()
192
  with gr.Blocks(theme="soft") as demo:
193
+ gr.Markdown("## 🧠 AI-Powered Supervisor Assistant")
194
 
195
  with gr.Row():
196
  role = gr.Dropdown(choices=roles, label="Role")
197
  supervisor_name = gr.Dropdown(choices=[], label="Supervisor Name")
198
  project_id = gr.Textbox(label="Project ID", interactive=False)
199
 
200
+ milestones = gr.Textbox(label="Milestones (comma-separated KPIs)", placeholder="E.g. Safety training, daily inspection")
201
+ reflection = gr.Textbox(label="Reflection Log", lines=4, placeholder="Any concerns, delays, updates...")
202
 
203
  with gr.Row():
204
  generate = gr.Button("Generate")
 
216
  inputs=[role, supervisor_name, project_id, milestones, reflection],
217
  outputs=[checklist_output, suggestions_output, quote_output])
218
 
219
+ clear.click(fn=lambda: ("", "", "", "", ""), inputs=None,
220
+ outputs=[role, supervisor_name, project_id, milestones, reflection])
221
+
222
  refresh.click(fn=lambda: gr.update(choices=get_roles_from_salesforce()), outputs=role)
223
 
224
  return demo