Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from simple_salesforce import Salesforce | |
import os | |
from dotenv import load_dotenv | |
# Load environment variables (from a local .env file, if one is present) via python-dotenv.
load_dotenv()

# The Salesforce credentials are mandatory: stop immediately if any is unset or empty.
required_env_vars = ['SF_USERNAME', 'SF_PASSWORD', 'SF_SECURITY_TOKEN']
missing_vars = [name for name in required_env_vars if os.getenv(name) in (None, "")]
if missing_vars:
    raise EnvironmentError(f"Missing required environment variables: {missing_vars}")
# Defaults


def env_flag(name, default):
    """Read a boolean environment variable.

    Returns *default* when the variable is unset. Otherwise the value is
    interpreted case-insensitively (surrounding whitespace ignored): "true",
    "1", "yes" and "on" mean True; anything else, including "", means False.
    """
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in {"true", "1", "yes", "on"}


KPI_FLAG_DEFAULT = env_flag('KPI_FLAG', True)
# Raises ValueError at import time if ENGAGEMENT_SCORE is not a valid float (fail fast).
ENGAGEMENT_SCORE_DEFAULT = float(os.getenv('ENGAGEMENT_SCORE', '85.0'))
# Load the model and tokenizer. distilgpt2 is used because it is small, which
# keeps response times short.
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)

# The tokenizer is later called with padding=True, which needs a pad token.
# If the tokenizer has none (typical for GPT-2-style models), reuse EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
# Keep the model config's pad id in sync with the tokenizer.
model.config.pad_token_id = tokenizer.pad_token_id
# Refined Prompt to generate day-by-day tasks based on milestones.
# Filled with str.format(role=..., project_id=..., milestones=..., reflection=...),
# so any literal "{" or "}" added to this text must be doubled.
# The template ends inside the "Suggestions:" list, so a causal LM's continuation
# starts there. Note that the checklist heading used here is "Checklist (Day-by-Day):".
PROMPT_TEMPLATE = """You are an AI assistant for construction supervisors. Given the role, project, milestones, and a reflection log, generate:
1. A Daily Checklist with clear and concise tasks based on the role and milestones.
Split the checklist into day-by-day tasks for a specified time period (e.g., one week).
2. Focus Suggestions based on concerns or keywords in the reflection log. Provide at least 2 suggestions.
Inputs:
Role: {role}
Project ID: {project_id}
Milestones: {milestones}
Reflection Log: {reflection}
Output Format:
Checklist (Day-by-Day):
- Day 1:
- Task 1
- Task 2
- Day 2:
- Task 1
- Task 2
...
Suggestions:
-
"""
# Salesforce Functions

# Escape table for SOQL string literals. The backslash must be escaped as well as
# the quote: escaping only "'" lets an input containing "\'" close the literal.
_SOQL_ESCAPES = str.maketrans({
    "\\": "\\\\",
    "'": "\\'",
    '"': '\\"',
    "\n": "\\n",
    "\r": "\\r",
    "\t": "\\t",
})


def soql_quote(value):
    """Return *value* (converted with str()) as a single-quoted, escaped SOQL literal."""
    return "'" + str(value).translate(_SOQL_ESCAPES) + "'"


def _salesforce():
    """Open a Salesforce session using the SF_* environment variables."""
    return Salesforce(
        username=os.getenv('SF_USERNAME'),
        password=os.getenv('SF_PASSWORD'),
        security_token=os.getenv('SF_SECURITY_TOKEN'),
        domain=os.getenv('SF_DOMAIN', 'login'),
    )


def get_roles_from_salesforce():
    """Return the distinct non-null Supervisor__c.Role__c values, sorted.

    On any error the error is printed and a fixed default role list is returned.
    """
    try:
        sf = _salesforce()
        result = sf.query("SELECT Role__c FROM Supervisor__c WHERE Role__c != NULL")
        # Sorted so the dropdown order is stable between refreshes.
        return sorted({record['Role__c'] for record in result['records']})
    except Exception as e:
        print(f"⚠️ Error fetching roles: {e}")
        return ["Site Manager", "Safety Officer", "Project Lead"]


def get_supervisor_name_by_role(role):
    """Return the Names of Supervisor__c records whose Role__c equals *role*.

    Returns [] without contacting Salesforce when *role* is empty/None (e.g. after
    the UI is cleared), and [] after printing the error if the query fails.
    """
    if not role:
        return []
    try:
        sf = _salesforce()
        result = sf.query(f"SELECT Name FROM Supervisor__c WHERE Role__c = {soql_quote(role)}")
        return [record['Name'] for record in result['records']]
    except Exception as e:
        print(f"⚠️ Error fetching supervisor names: {e}")
        return []


def get_projects_for_supervisor(supervisor_name):
    """Return the Name of one Project__c linked to the named supervisor, or "".

    Looks up a Supervisor__c by Name, then a Project__c whose Supervisor_ID__c is
    that supervisor's Id. Both queries use LIMIT 1 with no ORDER BY, so if several
    records match, which one is returned is not defined here. Returns "" when
    *supervisor_name* is empty/None, when nothing matches, or on any error.
    """
    if not supervisor_name:
        return ""
    try:
        sf = _salesforce()
        supervisor_result = sf.query(
            f"SELECT Id FROM Supervisor__c WHERE Name = {soql_quote(supervisor_name)} LIMIT 1"
        )
        if supervisor_result['totalSize'] == 0:
            return ""
        supervisor_id = supervisor_result['records'][0]['Id']
        project_result = sf.query(
            f"SELECT Name FROM Project__c WHERE Supervisor_ID__c = {soql_quote(supervisor_id)} LIMIT 1"
        )
        return project_result['records'][0]['Name'] if project_result['totalSize'] > 0 else ""
    except Exception as e:
        print(f"⚠️ Error fetching project: {e}")
        return ""


def field_exists(sf, object_name, field_name):
    """Return True if *field_name* is among the fields of Salesforce object *object_name*.

    Uses the object's describe() metadata; prints the error and returns False if
    the describe call fails (e.g. unknown object).
    """
    try:
        obj_desc = getattr(sf, object_name).describe()
        return field_name in [field['name'] for field in obj_desc['fields']]
    except Exception as e:
        print(f"⚠️ Error checking field {field_name}: {e}")
        return False
# Visualforce page that renders the supervisor dashboard.
DASHBOARD_BASE_URL = "https://aicoachforsitesupervisors-dev-ed--c.develop.vf.force.com/apex/DashboardPage"


def generate_salesforce_dashboard_url(supervisor_name, project_id, base_url=DASHBOARD_BASE_URL):
    """Build the dashboard URL with supervisorName and projectId query parameters.

    Values are converted with str() and percent-encoded, so names containing
    spaces, "&", "=" or "#" cannot corrupt the query string. *base_url*
    defaults to the Visualforce DashboardPage.
    """
    from urllib.parse import quote, urlencode

    # quote_via=quote encodes spaces as %20 rather than "+".
    query = urlencode(
        {"supervisorName": supervisor_name, "projectId": project_id},
        quote_via=quote,
    )
    return f"{base_url}?{query}"
# Dashboard button function
def open_dashboard(role, supervisor_name, project_id):
    """Return an HTML link that opens the Salesforce dashboard in a new tab.

    *role* is accepted because the UI passes it, but the link does not use it.
    """
    import html

    dashboard_url = generate_salesforce_dashboard_url(supervisor_name, project_id)
    # Escape for the attribute context so a quote in the URL cannot break out of href.
    safe_url = html.escape(dashboard_url, quote=True)
    return (
        f'<a href="{safe_url}" target="_blank" rel="noopener noreferrer" '
        f'style="font-size:16px;">Open Salesforce Dashboard</a>'
    )
# Generate function

# Checklist headings the model may emit; the prompt template itself uses the first one.
CHECKLIST_HEADERS = ("Checklist (Day-by-Day):", "Checklist:")
SUGGESTIONS_HEADER = "Suggestions:"


def extract_between(text, start, end=None):
    """Return the stripped text between *start* and the next *end* marker.

    Returns "" if *start* does not occur. If *end* is None/empty, or does not
    occur after *start*, the slice runs to the end of *text*.
    """
    begin = text.find(start)
    if begin == -1:
        return ""
    begin += len(start)
    stop = text.find(end, begin) if end else -1
    if stop == -1:
        # Without this, str.find's -1 would silently drop the last character.
        stop = len(text)
    return text[begin:stop].strip()


def parse_generation(generated):
    """Split the model's continuation into ``(checklist, suggestions)`` text.

    The checklist is the text after a checklist heading, up to "Suggestions:".
    Suggestions are the text after "Suggestions:" (stopping at a later checklist
    heading). If no "Suggestions:" heading was generated, the text before any
    checklist heading is used, because the prompt already ends inside the
    "Suggestions:" list. Either part may be "".
    """
    checklist = ""
    checklist_pos = len(generated)
    for header in CHECKLIST_HEADERS:
        pos = generated.find(header)
        if pos != -1:
            checklist = extract_between(generated, header, SUGGESTIONS_HEADER)
            checklist_pos = pos
            break

    suggestions_pos = generated.find(SUGGESTIONS_HEADER)
    if suggestions_pos == -1:
        suggestions = generated[:checklist_pos].strip()
    else:
        start = suggestions_pos + len(SUGGESTIONS_HEADER)
        stop = checklist_pos if checklist_pos > suggestions_pos else len(generated)
        suggestions = generated[start:stop].strip()
    return checklist, suggestions


def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
    """Generate a day-by-day checklist and focus suggestions for the form inputs.

    Returns ``(checklist, suggestions)`` for the two output boxes. If any input
    is empty, returns ("❗ Please fill all fields.", ""). A section the model does
    not produce is filled with the rule-based fallbacks. A generation failure is
    printed and also falls back, instead of leaving both boxes blank.
    """
    if not all([role, supervisor_name, project_id, milestones, reflection]):
        return "❗ Please fill all fields.", ""
    prompt = PROMPT_TEMPLATE.format(
        role=role,
        project_id=project_id,
        milestones=milestones,
        reflection=reflection,
    )
    # NOTE(review): truncation keeps the first 512 tokens, so a very long
    # milestone/reflection text can cut off the "Output Format" tail of the prompt.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
    generated = ""
    try:
        with torch.no_grad():
            outputs = model.generate(
                inputs['input_ids'],
                # pad_token may equal eos_token, so the mask cannot be inferred; pass it explicitly.
                attention_mask=inputs['attention_mask'],
                max_new_tokens=150,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens. The prompt contains its own
        # copies of the section headings, which the parser would otherwise match.
        prompt_length = inputs['input_ids'].shape[1]
        generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    except Exception as e:
        print(f"⚠️ Generation error: {e}")

    checklist, suggestions = parse_generation(generated)
    if not checklist:
        checklist = generate_fallback_checklist(role, milestones)
    if not suggestions:
        suggestions = generate_fallback_suggestions(reflection)
    return checklist, suggestions
# Fallback generation for checklist and suggestions
def generate_fallback_checklist(role, milestones):
    """Build a bullet checklist from comma-separated milestones.

    Each non-blank milestone becomes "- Ensure progress on <milestone>". If there
    are none (empty, None, or only commas/whitespace), a single default
    safety-inspection task is returned. *role* is currently unused and is kept
    for interface compatibility.
    """
    # Skip blank entries so "a,,b" or a trailing comma does not yield empty tasks.
    kpis = [kpi.strip() for kpi in (milestones or "").split(",")]
    checklist_items = [f"- Ensure progress on {kpi}" for kpi in kpis if kpi]
    if not checklist_items:
        checklist_items = ["- Perform daily safety inspection"]
    return "\n".join(checklist_items)
def generate_fallback_suggestions(reflection):
    """Derive focus suggestions from keywords in the reflection log.

    Matching is case-insensitive. "student" or "learning" adds two
    incident-logging tips, and "incident" adds a follow-up tip. If nothing
    matched, two generic suggestions are returned. Output is one "- " bullet
    per line.
    """
    text = reflection.lower()
    picked = []
    if any(keyword in text for keyword in ("student", "learning")):
        picked.extend([
            "- Ensure students are logging incidents consistently",
            "- Provide guidance on timely incident recording",
        ])
    if "incident" in text:
        picked.append("- Follow up on reported incidents with corrective actions")
    return "\n".join(picked or [
        "- Monitor team coordination",
        "- Review safety protocols with the team",
    ])
# Interface
def create_interface():
    """Build the Gradio UI for the supervisor assistant and return it unlaunched.

    Roles are fetched from Salesforce once when the UI is built, and the
    "🔄 Refresh Roles" button fetches them again. Choosing a role loads that
    role's supervisors, and choosing a supervisor fills the read-only Project ID box.
    """
    roles = get_roles_from_salesforce()
    with gr.Blocks(theme="soft") as demo:
        gr.Markdown("## 🧠 AI-Powered Supervisor Assistant")
        with gr.Row():
            role = gr.Dropdown(choices=roles, label="Role")
            supervisor_name = gr.Dropdown(choices=[], label="Supervisor Name")
            project_id = gr.Textbox(label="Project ID", interactive=False)
        milestones = gr.Textbox(label="Milestones (comma-separated KPIs)", placeholder="E.g. Safety training, daily inspection")
        reflection = gr.Textbox(label="Reflection Log", lines=4, placeholder="Any concerns, delays, updates...")
        with gr.Row():
            generate = gr.Button("Generate")
            clear = gr.Button("Clear")
            refresh = gr.Button("🔄 Refresh Roles")
            dashboard_btn = gr.Button("Dashboard")
        checklist_output = gr.Textbox(label="✅ Daily Checklist")
        suggestions_output = gr.Textbox(label="💡 Focus Suggestions")
        dashboard_link = gr.HTML("")

        # Reset the supervisor selection when the role changes. Otherwise a
        # supervisor from the previous role, and that supervisor's project, stay selected.
        role.change(
            fn=lambda r: gr.update(choices=get_supervisor_name_by_role(r), value=None),
            inputs=role,
            outputs=supervisor_name,
        )
        supervisor_name.change(fn=get_projects_for_supervisor, inputs=supervisor_name, outputs=project_id)
        generate.click(
            fn=generate_outputs,
            inputs=[role, supervisor_name, project_id, milestones, reflection],
            outputs=[checklist_output, suggestions_output],
        )
        # Dropdowns are cleared with None because "" is not one of their choices.
        # The generated outputs and the dashboard link are cleared along with the inputs.
        clear.click(
            fn=lambda: (None, None, "", "", "", "", "", ""),
            inputs=None,
            outputs=[
                role, supervisor_name, project_id, milestones, reflection,
                checklist_output, suggestions_output, dashboard_link,
            ],
        )
        refresh.click(fn=lambda: gr.update(choices=get_roles_from_salesforce()), outputs=role)
        dashboard_btn.click(fn=open_dashboard, inputs=[role, supervisor_name, project_id], outputs=dashboard_link)
    return demo
if __name__ == "__main__":
    # Running as a script: build the UI and start the Gradio server.
    create_interface().launch()