geethaAICoach2 / app.py
geethareddy's picture
Update app.py
77c2912 verified
raw
history blame
9.91 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from simple_salesforce import Salesforce
import os
from dotenv import load_dotenv
# Load environment variables from a local .env file (if present) before reading config.
load_dotenv()

# Fail fast at import time: every Salesforce helper in this module reads these credentials.
required_env_vars = ['SF_USERNAME', 'SF_PASSWORD', 'SF_SECURITY_TOKEN']
missing_vars = [var for var in required_env_vars if not os.getenv(var)]
if missing_vars:
    raise EnvironmentError(f"Missing required environment variables: {missing_vars}")

# Defaults.
# KPI_FLAG accepts common truthy spellings ("true", "1", "yes") case-insensitively;
# an exact == 'True' comparison would silently treat "true" as False.
KPI_FLAG_DEFAULT = os.getenv('KPI_FLAG', 'True').strip().lower() in ('true', '1', 'yes')
# Raises ValueError at import if ENGAGEMENT_SCORE is set to a non-numeric value.
ENGAGEMENT_SCORE_DEFAULT = float(os.getenv('ENGAGEMENT_SCORE', '85.0'))

# Load model and tokenizer; distilgpt2 is used for faster responses.
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)

# The tokenizer may define no pad token; reuse EOS so padding=True works.
# Assigning pad_token also updates tokenizer.pad_token_id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
# Prompt asking the model for day-by-day tasks based on the milestones, plus focus
# suggestions. The placeholders {role}, {project_id}, {milestones} and {reflection}
# are filled with str.format, so any literal braces added here must be doubled ({{ }}).
# NOTE(review): the template ends with an open "Suggestions:" list ("Suggestions:\n-\n"),
# so a causal LM continues inside that section; the checklist header shown in the
# output format is "Checklist (Day-by-Day):".
PROMPT_TEMPLATE = """You are an AI assistant for construction supervisors. Given the role, project, milestones, and a reflection log, generate:
1. A Daily Checklist with clear and concise tasks based on the role and milestones.
Split the checklist into day-by-day tasks for a specified time period (e.g., one week).
2. Focus Suggestions based on concerns or keywords in the reflection log. Provide at least 2 suggestions.
Inputs:
Role: {role}
Project ID: {project_id}
Milestones: {milestones}
Reflection Log: {reflection}
Output Format:
Checklist (Day-by-Day):
- Day 1:
- Task 1
- Task 2
- Day 2:
- Task 1
- Task 2
...
Suggestions:
-
"""
# Salesforce Functions
def _connect_salesforce():
    """Open a Salesforce session from the SF_* environment variables.

    SF_DOMAIN defaults to 'login'. Connection errors propagate to the caller.
    """
    return Salesforce(
        username=os.getenv('SF_USERNAME'),
        password=os.getenv('SF_PASSWORD'),
        security_token=os.getenv('SF_SECURITY_TOKEN'),
        domain=os.getenv('SF_DOMAIN', 'login'),
    )


def _soql_quote(value):
    """Escape *value* for use inside a single-quoted SOQL string literal.

    Backslashes are escaped before quotes; escaping only quotes would let a
    backslash in the input cancel the added escape and close the literal early.
    """
    return str(value).replace("\\", "\\\\").replace("'", "\\'")


def get_roles_from_salesforce():
    """Return the distinct, sorted Role__c values of Supervisor__c records.

    Falls back to a fixed list of roles if the Salesforce call fails.
    """
    try:
        sf = _connect_salesforce()
        # query_all follows pagination; query() would return only the first batch.
        result = sf.query_all("SELECT Role__c FROM Supervisor__c WHERE Role__c != NULL")
        # sorted() gives the dropdown a stable order (a bare set's order is arbitrary).
        return sorted({record['Role__c'] for record in result['records']})
    except Exception as e:
        print(f"⚠️ Error fetching roles: {e}")
        return ["Site Manager", "Safety Officer", "Project Lead"]


def get_supervisor_name_by_role(role):
    """Return the Names of Supervisor__c records whose Role__c equals *role*.

    Returns [] for an empty/None role (e.g. after the form is cleared) or on error.
    """
    if not role:
        return []
    try:
        sf = _connect_salesforce()
        result = sf.query_all(
            f"SELECT Name FROM Supervisor__c WHERE Role__c = '{_soql_quote(role)}'"
        )
        return [record['Name'] for record in result['records']]
    except Exception as e:
        print(f"⚠️ Error fetching supervisor names: {e}")
        return []


def get_projects_for_supervisor(supervisor_name):
    """Return the Name of one Project__c linked to the named supervisor, or "".

    Looks up the first Supervisor__c with that Name, then the first Project__c
    whose Supervisor_ID__c equals its Id. Returns "" when the name is empty/None,
    nothing matches, or any error occurs.
    """
    if not supervisor_name:
        return ""
    try:
        sf = _connect_salesforce()
        supervisor_result = sf.query(
            f"SELECT Id FROM Supervisor__c WHERE Name = '{_soql_quote(supervisor_name)}' LIMIT 1"
        )
        if supervisor_result['totalSize'] == 0:
            return ""
        supervisor_id = supervisor_result['records'][0]['Id']
        project_result = sf.query(
            f"SELECT Name FROM Project__c WHERE Supervisor_ID__c = '{_soql_quote(supervisor_id)}' LIMIT 1"
        )
        return project_result['records'][0]['Name'] if project_result['totalSize'] > 0 else ""
    except Exception as e:
        print(f"⚠️ Error fetching project: {e}")
        return ""
def field_exists(sf, object_name, field_name):
    """Tell whether the Salesforce object *object_name* declares *field_name*.

    The object is looked up as an attribute of *sf* and its describe() metadata
    is scanned. Any failure is printed and reported as "not found" (False).
    """
    try:
        metadata = getattr(sf, object_name).describe()
        names = [entry['name'] for entry in metadata['fields']]
        found = field_name in names
    except Exception as err:
        print(f"⚠️ Error checking field {field_name}: {err}")
        found = False
    return found
def generate_salesforce_dashboard_url(supervisor_name, project_id):
    """Build the Visualforce DashboardPage URL for a supervisor/project pair.

    Both values are passed as the ``supervisorName`` and ``projectId`` query
    parameters and are percent-encoded, so values containing spaces, '&', '#'
    or '=' cannot corrupt the query string.
    """
    from urllib.parse import quote, urlencode

    base_url = "https://aicoachforsitesupervisors-dev-ed--c.develop.vf.force.com/apex/DashboardPage"
    query = urlencode(
        {"supervisorName": supervisor_name, "projectId": project_id},
        quote_via=quote,  # encode spaces as %20 rather than '+'
    )
    return f"{base_url}?{query}"
# Dashboard button function
def open_dashboard(role, supervisor_name, project_id):
    """Return an HTML link that opens the Salesforce dashboard in a new tab.

    ``role`` is accepted because the Dashboard button handler passes it, but it
    is not used. The URL is HTML-escaped before being placed in the href
    attribute, so characters such as '"' or '&' cannot break the markup.
    """
    import html

    dashboard_url = generate_salesforce_dashboard_url(supervisor_name, project_id)
    safe_url = html.escape(dashboard_url, quote=True)
    return (
        f'<a href="{safe_url}" target="_blank" rel="noopener noreferrer" '
        'style="font-size:16px;">Open Salesforce Dashboard</a>'
    )
# Generate function
# Section headers that may appear in model output. PROMPT_TEMPLATE asks for
# "Checklist (Day-by-Day):"; the bare "Checklist:" form is accepted too.
_CHECKLIST_HEADERS = ("Checklist (Day-by-Day):", "Checklist:")
_SUGGESTIONS_HEADER = "Suggestions:"


def _find_first_header(text, headers, start=0):
    """Return (index, header) of the earliest of *headers* in text[start:], or (-1, "")."""
    hits = [(text.find(header, start), header) for header in headers]
    found = [hit for hit in hits if hit[0] != -1]
    return min(found) if found else (-1, "")


def _section(text, start_headers, end_headers):
    """Return the stripped text after the first start header, up to the next end header or EOF."""
    pos, header = _find_first_header(text, start_headers)
    if pos == -1:
        return ""
    body_start = pos + len(header)
    end, _ = _find_first_header(text, end_headers, body_start)
    # A missing end header means "to the end"; slicing to -1 would drop the last character.
    stop = end if end != -1 else len(text)
    return text[body_start:stop].strip()


def _split_generation(generated):
    """Split newly generated model text into (checklist, suggestions); either may be ""."""
    checklist = _section(generated, _CHECKLIST_HEADERS, (_SUGGESTIONS_HEADER,))
    suggestions = _section(generated, (_SUGGESTIONS_HEADER,), _CHECKLIST_HEADERS)
    if _SUGGESTIONS_HEADER not in generated:
        # PROMPT_TEMPLATE ends inside an open "Suggestions:" list, so text the model writes
        # before any checklist header is a continuation of that list.
        cut, _ = _find_first_header(generated, _CHECKLIST_HEADERS)
        suggestions = generated[:cut if cut != -1 else len(generated)].strip()
    return checklist, suggestions


def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
    """Produce the (checklist, suggestions) pair shown in the UI.

    Returns ("❗ Please fill all fields.", "") if any input is None or blank.
    Otherwise the model is prompted with PROMPT_TEMPLATE and only the newly
    generated text is parsed. Any section the model did not produce, or both
    sections when generation raises, are filled by the rule-based fallbacks.
    """
    fields = (role, supervisor_name, project_id, milestones, reflection)
    if not all(field is not None and str(field).strip() for field in fields):
        return "❗ Please fill all fields.", ""

    prompt = PROMPT_TEMPLATE.format(
        role=role,
        project_id=project_id,
        milestones=milestones,
        reflection=reflection,
    )
    # NOTE(review): with truncation=True an over-long prompt loses its tail (the output
    # format section), assuming the tokenizer's default right-side truncation — confirm.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)

    checklist = suggestions = ""
    try:
        with torch.no_grad():
            outputs = model.generate(
                inputs['input_ids'],
                attention_mask=inputs.get('attention_mask'),
                max_new_tokens=150,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                pad_token_id=tokenizer.pad_token_id,
            )
        # generate() returns prompt + continuation; decode only the new tokens so the
        # example headers inside PROMPT_TEMPLATE are not mistaken for model output.
        prompt_length = inputs['input_ids'].shape[-1]
        generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        checklist, suggestions = _split_generation(generated)
    except Exception as e:
        # Fall through to the rule-based fallbacks rather than returning blank outputs.
        print(f"⚠️ Generation error: {e}")

    if not checklist:
        checklist = generate_fallback_checklist(role, milestones)
    if not suggestions:
        suggestions = generate_fallback_suggestions(reflection)
    return checklist, suggestions
# Fallback generation for checklist and suggestions
def generate_fallback_checklist(role, milestones):
    """Build a rule-based checklist for when the model produces none.

    Each non-blank comma-separated entry of *milestones* becomes
    "- Ensure progress on <entry>". If there is no usable entry (None, empty,
    or only commas/whitespace), a single "- Perform daily safety inspection"
    item is returned. ``role`` is accepted for interface compatibility but unused.
    """
    kpis = [kpi.strip() for kpi in (milestones or "").split(",")]
    # Drop blanks from inputs like "a,,b" or a trailing comma.
    kpis = [kpi for kpi in kpis if kpi]
    if not kpis:
        return "- Perform daily safety inspection"
    return "\n".join(f"- Ensure progress on {kpi}" for kpi in kpis)
def generate_fallback_suggestions(reflection):
    """Derive focus suggestions from keywords in the reflection log.

    Matching is case-insensitive and rules are applied in order. If no rule
    matches, two generic suggestions are returned. The lines are joined with "\\n".
    """
    lowered = reflection.lower()
    # (keywords, suggestions added when any keyword occurs in the reflection)
    rules = (
        (("student", "learning"), (
            "- Ensure students are logging incidents consistently",
            "- Provide guidance on timely incident recording",
        )),
        (("incident",), (
            "- Follow up on reported incidents with corrective actions",
        )),
    )
    picked = [
        line
        for keywords, lines in rules
        if any(keyword in lowered for keyword in keywords)
        for line in lines
    ]
    defaults = ["- Monitor team coordination", "- Review safety protocols with the team"]
    return "\n".join(picked or defaults)
# Interface
def create_interface():
    """Build the Gradio Blocks UI and wire its event handlers.

    Roles are fetched from Salesforce once when the UI is built; the
    "Refresh Roles" button fetches them again.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    roles = get_roles_from_salesforce()
    with gr.Blocks(theme="soft") as demo:
        gr.Markdown("## 🧠 AI-Powered Supervisor Assistant")
        with gr.Row():
            role = gr.Dropdown(choices=roles, label="Role")
            supervisor_name = gr.Dropdown(choices=[], label="Supervisor Name")
            project_id = gr.Textbox(label="Project ID", interactive=False)
        milestones = gr.Textbox(
            label="Milestones (comma-separated KPIs)",
            placeholder="E.g. Safety training, daily inspection",
        )
        reflection = gr.Textbox(
            label="Reflection Log", lines=4, placeholder="Any concerns, delays, updates..."
        )
        with gr.Row():
            generate = gr.Button("Generate")
            clear = gr.Button("Clear")
            refresh = gr.Button("🔄 Refresh Roles")
            dashboard_btn = gr.Button("Dashboard")
        checklist_output = gr.Textbox(label="✅ Daily Checklist")
        suggestions_output = gr.Textbox(label="💡 Focus Suggestions")
        dashboard_link = gr.HTML("")

        # A new role invalidates the current supervisor: reload the choices AND clear the
        # selection, so a supervisor (and project) from the previous role cannot linger.
        role.change(
            fn=lambda r: gr.update(choices=get_supervisor_name_by_role(r), value=None),
            inputs=role,
            outputs=supervisor_name,
        )
        supervisor_name.change(fn=get_projects_for_supervisor, inputs=supervisor_name, outputs=project_id)
        generate.click(
            fn=generate_outputs,
            inputs=[role, supervisor_name, project_id, milestones, reflection],
            outputs=[checklist_output, suggestions_output],
        )
        # Dropdowns are reset to None (no selection) rather than "", which is not one of
        # their choices. The generated outputs and dashboard link are cleared as well.
        clear.click(
            fn=lambda: (None, None, "", "", "", "", "", ""),
            inputs=None,
            outputs=[
                role, supervisor_name, project_id, milestones, reflection,
                checklist_output, suggestions_output, dashboard_link,
            ],
        )
        refresh.click(fn=lambda: gr.update(choices=get_roles_from_salesforce()), outputs=role)
        dashboard_btn.click(
            fn=open_dashboard,
            inputs=[role, supervisor_name, project_id],
            outputs=dashboard_link,
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    create_interface().launch()