# Source: Hugging Face Spaces app (Space status when captured: Sleeping)
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from simple_salesforce import Salesforce | |
import os | |
from dotenv import load_dotenv | |
# Load environment variables (python-dotenv's load_dotenv reads a .env file
# into os.environ, if one is found).
load_dotenv()

# Fail fast at import time when Salesforce credentials are missing: every
# Salesforce helper below logs in with these three values.
required_env_vars = ['SF_USERNAME', 'SF_PASSWORD', 'SF_SECURITY_TOKEN']
missing_vars = [var for var in required_env_vars if not os.getenv(var)]
if missing_vars:
    raise EnvironmentError(f"Missing required environment variables: {missing_vars}")

# Configurable values written to KPI_Flag__c and Engagement_Score__c.
# KPI_FLAG is parsed case-insensitively so "true", "1", "yes" and "on" all
# enable it (previously only the exact string "True" did). Default: True.
KPI_FLAG_DEFAULT = os.getenv('KPI_FLAG', 'True').strip().lower() in {'true', '1', 'yes', 'y', 'on'}

# ENGAGEMENT_SCORE must parse as a float; default 85.0.
_engagement_score_raw = os.getenv('ENGAGEMENT_SCORE', '85.0')
try:
    ENGAGEMENT_SCORE_DEFAULT = float(_engagement_score_raw)
except ValueError as err:
    raise ValueError(
        f"ENGAGEMENT_SCORE must be a number, got {_engagement_score_raw!r}"
    ) from err
# Initialize model and tokenizer (from_pretrained fetches/caches the weights).
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# GPT-2-family tokenizers ship without a pad token, but generate_outputs
# tokenizes with padding=True and passes pad_token_id, so one is needed.
if tokenizer.pad_token is None:
    if tokenizer.eos_token:
        # Reuse EOS as padding: no new vocabulary entry is required.
        tokenizer.pad_token = tokenizer.eos_token
    else:
        # "[PAD]" is not in such a model's vocabulary. Register it as a real
        # special token and grow the embedding matrix so its id is valid.
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        model.resize_token_embeddings(len(tokenizer))

# Keep the model config in sync so generation does not warn about padding.
model.config.pad_token_id = tokenizer.pad_token_id
# Prompt template for generating structured output.
# Placeholders: {role}, {project_id}, {milestones}, {reflection} are the raw
# form inputs. {milestones_list} and {suggestions_list} are pre-rendered bullet
# lists; the template itself supplies the first "- " of each list.
# The string ends with a trailing newline after the quote placeholder.
PROMPT_TEMPLATE = "\n".join([
    "You are an AI coach for construction supervisors. Based on the following inputs, generate a daily checklist, focus suggestions, and a motivational quote.",
    "Inputs:",
    "Role: {role}",
    "Project: {project_id}",
    "Milestones: {milestones}",
    "Reflection: {reflection}",
    "Format your response clearly like this:",
    "Checklist:",
    "- {milestones_list}",
    "Suggestions:",
    "- {suggestions_list}",
    "Quote:",
    "- Your motivational quote here",
    "",
])
# Function to get all roles from Salesforce
def get_roles_from_salesforce():
    """Return the distinct, non-null ``Role__c`` values found on ``Supervisor__c``.

    Roles are returned sorted so the dropdown order is stable between runs;
    set-based de-duplication alone yields an arbitrary order.

    Returns:
        list[str]: the roles, or a hard-coded fallback list if the Salesforce
        login or query fails for any reason (the error is printed).
    """
    try:
        sf = Salesforce(
            username=os.getenv('SF_USERNAME'),
            password=os.getenv('SF_PASSWORD'),
            security_token=os.getenv('SF_SECURITY_TOKEN'),
            domain=os.getenv('SF_DOMAIN', 'login'),
        )
        # query_all follows result pagination, so roles that only appear
        # beyond the first batch of records are not silently dropped.
        result = sf.query_all("SELECT Role__c FROM Supervisor__c WHERE Role__c != NULL")
        roles = sorted({record['Role__c'] for record in result.get('records', [])})
        print(f"β Fetched {len(roles)} unique roles from Salesforce")
        return roles
    except Exception as e:
        print(f"β οΈ Error fetching roles from Salesforce: {e}")
        print("Using fallback roles...")
        return ["Site Manager", "Safety Officer", "Project Lead"]  # Match actual active roles
# Function to get supervisor's Name (Auto Number) by role
def get_supervisor_name_by_role(role):
    """Return the ``Name`` of every ``Supervisor__c`` whose ``Role__c`` equals *role*.

    Args:
        role: the selected role. None or an empty string means "nothing
            selected". This happens when the dropdown is cleared; the function
            then returns ``[]`` without contacting Salesforce.

    Returns:
        list[str]: supervisor names. The list is empty when none match or on any
        Salesforce error, which is printed rather than raised.
    """
    if not role:
        return []
    try:
        sf = Salesforce(
            username=os.getenv('SF_USERNAME'),
            password=os.getenv('SF_PASSWORD'),
            security_token=os.getenv('SF_SECURITY_TOKEN'),
            domain=os.getenv('SF_DOMAIN', 'login'),
        )
        # SOQL string literals use backslash escapes. Escape backslashes first,
        # then single quotes, so the value cannot terminate the literal early.
        # Escaping quotes alone would let an input ending in "\" break out.
        safe_role = role.replace("\\", "\\\\").replace("'", "\\'")
        # query_all follows pagination so large roles are not truncated.
        result = sf.query_all(f"SELECT Name FROM Supervisor__c WHERE Role__c = '{safe_role}'")
        supervisor_names = [record['Name'] for record in result.get('records', [])]
        if not supervisor_names:
            print("β No matching supervisors found.")
            return []
        print(f"β Found supervisors: {supervisor_names} for role: {role}")
        return supervisor_names
    except Exception as e:
        print(f"β οΈ Error fetching supervisor names: {e}")
        return []
# Function to get project IDs and names assigned to selected supervisor
def get_projects_for_supervisor(supervisor_name):
    """Return the ``Name`` of one ``Project__c`` linked to the given supervisor.

    The supervisor is resolved by ``Supervisor__c.Name``. Then a single
    ``Project__c`` whose ``Supervisor_ID__c`` equals that record's Id is
    fetched (``LIMIT 1``, so only one project is ever returned).

    Args:
        supervisor_name: the selected supervisor ``Name``. None or "" means
            nothing is selected. That happens whenever the dropdown is reset;
            the function then returns "" without contacting Salesforce.

    Returns:
        str: the project name, or "" when nothing matches or on any
        Salesforce error, which is printed rather than raised.
    """
    if not supervisor_name:
        return ""
    try:
        # Use the selected supervisor name to fetch the associated project
        sf = Salesforce(
            username=os.getenv('SF_USERNAME'),
            password=os.getenv('SF_PASSWORD'),
            security_token=os.getenv('SF_SECURITY_TOKEN'),
            domain=os.getenv('SF_DOMAIN', 'login'),
        )
        # Escape backslashes before quotes: SOQL literals use backslash escapes.
        safe_name = supervisor_name.replace("\\", "\\\\").replace("'", "\\'")
        # Step 1: Get the Salesforce record ID of the supervisor based on the Name
        supervisor_result = sf.query(f"SELECT Id FROM Supervisor__c WHERE Name = '{safe_name}' LIMIT 1")
        if supervisor_result['totalSize'] == 0:
            print("β No supervisor found with the given name.")
            return ""
        supervisor_id = supervisor_result['records'][0]['Id']
        # Step 2: Query Project__c records where Supervisor_ID__c matches the supervisor's record ID
        # (the Id comes from Salesforce itself, so it needs no escaping).
        project_result = sf.query(f"SELECT Name FROM Project__c WHERE Supervisor_ID__c = '{supervisor_id}' LIMIT 1")
        if project_result['totalSize'] == 0:
            print("β No project found for supervisor.")
            return ""
        project_name = project_result['records'][0]['Name']
        print(f"β Found project: {project_name} for supervisor: {supervisor_name}")
        return project_name
    except Exception as e:
        print(f"β οΈ Error fetching project for supervisor: {e}")
        return ""
def _suggestions_for_reflection(reflection):
    """Return canned focus tips (without bullets) triggered by keywords in *reflection*.

    Keywords are matched case-insensitively in the order "delays", "weather",
    "equipment", and only the first match contributes. An empty list means no
    keyword matched.
    """
    text = reflection.lower()
    if "delays" in text:
        return [
            "Consider adjusting timelines to accommodate delays.",
            "Communicate delays to all relevant stakeholders.",
        ]
    if "weather" in text:
        return [
            "Ensure team has rain gear.",
            "Monitor weather updates for possible further delays.",
        ]
    if "equipment" in text:
        return [
            "Inspect all equipment to ensure no malfunctions.",
            "Schedule maintenance if necessary.",
        ]
    return []


def _extract_section(text, start_marker, end_marker):
    """Return the stripped text between *start_marker* and *end_marker* in *text*.

    Returns None if *start_marker* does not occur. If *end_marker* is None or
    does not occur after the start marker, the section runs to the end of
    *text*. This avoids the ``text[start:-1]`` slice that would drop the last
    character.
    """
    start = text.find(start_marker)
    if start == -1:
        return None
    start += len(start_marker)
    end = text.find(end_marker, start) if end_marker else -1
    if end == -1:
        end = len(text)
    return text[start:end].strip()


# Function to generate AI-based coaching output
def generate_outputs(role, supervisor_name, project_id, milestones, reflection):
    """Generate a checklist, focus suggestions and a quote, then log them to Salesforce.

    Args:
        role, supervisor_name, project_id: selections from the form.
        milestones: comma-separated KPIs; each non-empty item becomes a bullet.
        reflection: free text; keywords in it select canned suggestions.

    Returns:
        tuple[str, str, str]: (checklist, suggestions, quote) for the three
        output boxes. On failure it returns ("Error: All fields are required.", "", "")
        or ("Error: Failed to generate outputs.", "", "").
    """
    # Whitespace-only text boxes count as missing.
    fields = [role, supervisor_name, project_id, milestones, reflection]
    if not all(value and str(value).strip() for value in fields):
        return "Error: All fields are required.", "", ""

    # Seed the prompt with the user's milestones and the keyword-based tips.
    # PROMPT_TEMPLATE already writes the first "- " of each list, so items are
    # joined with "\n- " and must not carry their own leading bullet;
    # otherwise the output shows "- - ...".
    milestone_items = [m.strip() for m in milestones.split(",") if m.strip()]
    tips = _suggestions_for_reflection(reflection)
    prompt = PROMPT_TEMPLATE.format(
        role=role,
        project_id=project_id,
        milestones=milestones,
        reflection=reflection,
        milestones_list="\n- ".join(milestone_items),
        suggestions_list="\n- ".join(tips),
    )
    seeded_checklist = "\n".join(f"- {item}" for item in milestone_items)
    seeded_suggestions = "\n".join(f"- {tip}" for tip in tips)

    # Tokenize input (prompt capped at 512 tokens).
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True, padding=True)
    prompt_length = inputs['input_ids'].shape[1]
    try:
        with torch.no_grad():
            outputs = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=1024,  # prompt + generated tokens combined
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_p=0.9,
                temperature=0.8,
                pad_token_id=tokenizer.pad_token_id,
            )
        # outputs[0] starts with the prompt tokens. Decode only the new tokens;
        # otherwise the prompt's own "Checklist:/Suggestions:/Quote:" headers
        # would be parsed instead of the model's answer.
        completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    except Exception as e:
        print(f"β οΈ Error during model generation: {e}")
        return "Error: Failed to generate outputs.", "", ""

    # Prefer sections the model actually wrote. Fall back to the content
    # seeded into the prompt when the model ignored the requested format.
    checklist = _extract_section(completion, "Checklist:\n", "Suggestions:") or seeded_checklist
    suggestions = _extract_section(completion, "Suggestions:\n", "Quote:") or seeded_suggestions
    quote = _extract_section(completion, "Quote:\n", None) or completion.strip() or "Not found"

    # Saving is best-effort: a Salesforce failure must not hide the generated
    # text from the user.
    try:
        save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name)
    except Exception as e:
        print(f"Error saving coaching output to Salesforce: {e}")
    return checklist, suggestions, quote
# Function to check if a field exists in a Salesforce object
def field_exists(sf, object_name, field_name):
    """Tell whether the SObject *object_name* on connection *sf* has *field_name*.

    The answer comes from the object's ``describe()`` metadata. Any failure
    (unknown object, API error, malformed metadata) is printed and reported
    as ``False``.
    """
    try:
        sobject = getattr(sf, object_name)
        described = sobject.describe()
        known_names = [entry['name'] for entry in described['fields']]
    except Exception as e:
        print(f"β οΈ Error checking if field {field_name} exists in {object_name}: {e}")
        return False
    return field_name in known_names
# Function to create a record in Salesforce
def save_to_salesforce(role, project_id, milestones, reflection, checklist, suggestions, quote, supervisor_name):
    """Create one ``Supervisor_AI_Coaching__c`` record for a generated coaching session.

    Steps:
    1. Resolve the ``Supervisor__c`` Id by ``Name``.
    2. Resolve the ``Project__c`` Id by ``Name``.
    3. Create the coaching record linked to both.

    Text fields are truncated to 255 characters. ``Milestones_KPIs__c`` is
    written only if that field exists on the object. ``role`` and ``quote`` are
    accepted but not written to any field.

    All errors are caught and printed, together with the payload if one was
    built. The function always returns None.
    """
    # Bound before the try so the error handler can print it even when the
    # failure happens before the payload is built (e.g. login or lookup).
    data = None
    try:
        sf = Salesforce(
            username=os.getenv('SF_USERNAME'),
            password=os.getenv('SF_PASSWORD'),
            security_token=os.getenv('SF_SECURITY_TOKEN'),
            domain=os.getenv('SF_DOMAIN', 'login'),
        )
        # SOQL literals use backslash escapes: escape backslashes first, then
        # single quotes, so the values cannot terminate the literal early.
        safe_supervisor = supervisor_name.replace("\\", "\\\\").replace("'", "\\'")
        safe_project = project_id.replace("\\", "\\\\").replace("'", "\\'")
        # Step 1: Get the Salesforce record ID for the supervisor
        supervisor_result = sf.query(f"SELECT Id FROM Supervisor__c WHERE Name = '{safe_supervisor}' LIMIT 1")
        if supervisor_result['totalSize'] == 0:
            print(f"β No supervisor found with Name: {supervisor_name}")
            return
        supervisor_id = supervisor_result['records'][0]['Id']
        # Step 2: Get the Salesforce record ID for the project
        project_result = sf.query(f"SELECT Id FROM Project__c WHERE Name = '{safe_project}' LIMIT 1")
        if project_result['totalSize'] == 0:
            print(f"β No project found with Name: {project_id}")
            return
        project_record_id = project_result['records'][0]['Id']
        # Truncate text fields to avoid exceeding Salesforce field length limits.
        # NOTE(review): 255 is an assumed limit for every field; confirm it
        # against the org's actual field definitions.
        MAX_TEXT_LENGTH = 255
        checklist = checklist[:MAX_TEXT_LENGTH] if checklist else ""
        suggestions = suggestions[:MAX_TEXT_LENGTH] if suggestions else ""
        reflection = reflection[:MAX_TEXT_LENGTH] if reflection else ""
        # Prepare data for Salesforce with explicit mapping
        data = {
            'Supervisor_ID__c': supervisor_id,  # Lookup field expects the record ID of Supervisor__c
            'Project_ID__c': project_record_id,  # Lookup field expects the record ID of Project__c
            'Daily_Checklist__c': checklist,  # Maps to the generated Daily Checklist
            'Suggested_Tips__c': suggestions,  # Maps to the generated Focus Suggestions
            'Reflection_Log__c': reflection,  # Maps to the Reflection Log input
            'KPI_Flag__c': KPI_FLAG_DEFAULT,  # Configurable via .env
            'Engagement_Score__c': ENGAGEMENT_SCORE_DEFAULT,  # Configurable via .env
        }
        # Check if Milestones_KPIs__c field exists before mapping
        if field_exists(sf, 'Supervisor_AI_Coaching__c', 'Milestones_KPIs__c'):
            data['Milestones_KPIs__c'] = milestones[:MAX_TEXT_LENGTH] if milestones else ""
        else:
            print("β οΈ Milestones_KPIs__c field does not exist in Supervisor_AI_Coaching__c. Skipping mapping.")
        # Create record
        response = sf.Supervisor_AI_Coaching__c.create(data)
        print("β Record created successfully in Salesforce.")
        print("Record ID:", response['id'])
    except Exception as e:
        print(f"β Error saving to Salesforce: {e}")
        print("Data being sent:", data)
        # simple_salesforce errors carry the raw API response in .content.
        if hasattr(e, 'content'):
            print("Salesforce API response:", e.content)
# Gradio Interface
def create_interface():
    """Build the Gradio Blocks app for the supervisor coaching workflow.

    Roles are fetched from Salesforce once at build time and can be reloaded
    with the refresh button. Choosing a role repopulates and resets the
    supervisor dropdown. Choosing a supervisor fills the read-only project
    field. Generate runs generate_outputs, and Clear resets inputs and outputs.

    Returns:
        gr.Blocks: the app, not yet launched.
    """
    # Fetch roles from Salesforce
    roles = get_roles_from_salesforce()
    print(f"Fetched Roles: {roles}")
    with gr.Blocks(theme="soft") as demo:
        gr.Markdown("# ποΈ Construction Supervisor AI Coach")
        gr.Markdown("Enter details to generate a daily checklist, focus suggestions, and a motivational quote.")
        with gr.Row():
            role = gr.Dropdown(choices=roles, label="Role")
            supervisor_name = gr.Dropdown(choices=[], label="Supervisor Name")
            project_id = gr.Textbox(label="Project ID", interactive=False)
        milestones = gr.Textbox(label="Milestones (comma-separated KPIs)")
        reflection = gr.Textbox(label="Reflection Log", lines=4)
        with gr.Row():
            submit = gr.Button("Generate", variant="primary")
            clear = gr.Button("Clear")
            refresh_btn = gr.Button("π Refresh Roles")
        checklist_output = gr.Textbox(label="β Daily Checklist")
        suggestions_output = gr.Textbox(label="π‘ Focus Suggestions")
        quote_output = gr.Textbox(label="β¨ Motivational Quote")

        # When the role changes, load that role's supervisors and reset the
        # selection. Otherwise a supervisor from the previous role (and its
        # project) would stay selected.
        role.change(
            fn=lambda r: gr.update(choices=get_supervisor_name_by_role(r), value=None),
            inputs=[role],
            outputs=[supervisor_name],
        )
        # When the supervisor changes, update the project ID
        supervisor_name.change(
            fn=get_projects_for_supervisor,
            inputs=[supervisor_name],
            outputs=[project_id],
        )
        submit.click(
            fn=generate_outputs,
            inputs=[role, supervisor_name, project_id, milestones, reflection],
            outputs=[checklist_output, suggestions_output, quote_output],
        )
        # Dropdowns are reset to None (no selection). "" is not one of their
        # choices. The previous results are cleared as well.
        clear.click(
            fn=lambda: (None, None, "", "", "", "", "", ""),
            inputs=None,
            outputs=[
                role, supervisor_name, project_id, milestones, reflection,
                checklist_output, suggestions_output, quote_output,
            ],
        )
        refresh_btn.click(
            fn=lambda: gr.update(choices=get_roles_from_salesforce()),
            outputs=role,
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI (which also queries Salesforce for
    # roles), then start the Gradio server.
    app = create_interface()
    app.launch()