# Construction Supervisor AI Coach — Hugging Face Spaces (Gradio) app.
# Generates a daily checklist, focus suggestions, and a motivational quote
# for construction supervisors using a locally loaded distilgpt2 model.
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import datetime

# Preload the model and tokenizer at import time so the first user request
# does not pay the download/initialization cost.
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# distilgpt2 ships without a pad token; alias it to EOS so that padding in
# tokenization and generation does not emit warnings or use an invalid id.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
# Prompt template sent to the language model. It embeds both the structured
# output format (Checklist/Suggestions/Quote labels, which generate_outputs
# later parses back out of the generated text) and the raw user inputs.
PROMPT_TEMPLATE = """You are an AI coach for construction supervisors. Based on the following inputs, generate a daily checklist, focus suggestions, and a motivational quote. Format your response with clear labels as follows:

Checklist:
- {milestones_list}

Suggestions:
- {suggestions_list}

Quote:
- Your motivational quote here

Inputs:
Role: {role}
Project: {project_id}
Milestones: {milestones}
Reflection: {reflection}
"""
def _suggestions_for(reflection):
    """Return canned focus suggestions keyed off keywords in the reflection log."""
    text = reflection.lower()
    if "delays" in text:
        return "- Consider adjusting timelines to accommodate delays.\n- Communicate delays to all relevant stakeholders."
    if "weather" in text:
        return "- Ensure team has rain gear.\n- Monitor weather updates for possible further delays."
    if "equipment" in text:
        return "- Inspect all equipment to ensure no malfunctions.\n- Schedule maintenance if necessary."
    return ""


def _extract_section(text, label, next_label):
    """Return the stripped text between `label` and `next_label` (or end of text).

    Returns None when `label` is absent or the section is empty. Searching for
    `next_label` starts AFTER `label` so a later occurrence is found even if the
    labels also appear earlier (e.g. in the echoed prompt). The original code
    searched from position 0 and mis-sliced with -1 when a label was missing.
    """
    start = text.find(label)
    if start == -1:
        return None
    start += len(label)
    end = text.find(next_label, start) if next_label else -1
    if end == -1:
        end = len(text)
    section = text[start:end].strip()
    return section or None


def generate_outputs(role, project_id, milestones, reflection):
    """Generate a daily checklist, focus suggestions, and a motivational quote.

    Args:
        role: Supervisor role selected in the UI (e.g. "Foreman").
        project_id: Free-text project identifier.
        milestones: Comma-separated list of milestone/KPI strings.
        reflection: Free-text reflection log; keywords in it drive suggestions.

    Returns:
        A (checklist, suggestions, quote) tuple of strings. If any input is
        missing, returns ("Error: All fields are required.", "", "").
    """
    # Validate inputs to ensure no missing fields.
    if not all([role, project_id, milestones, reflection]):
        return "Error: All fields are required.", "", ""

    # Build the bullet list and keyword-driven suggestions for the prompt.
    milestones_list = "\n- ".join(m.strip() for m in milestones.split(","))
    suggestions_list = _suggestions_for(reflection)

    prompt = PROMPT_TEMPLATE.format(
        role=role,
        project_id=project_id,
        milestones=milestones,
        reflection=reflection,
        milestones_list=milestones_list,
        suggestions_list=suggestions_list,
    )

    # Tokenize; truncate long prompts so they fit the model context.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True, padding=True)

    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            # Pass the attention mask so padding is handled correctly.
            attention_mask=inputs["attention_mask"],
            # Fix: the original used max_length=512, which counts PROMPT
            # tokens too — a 512-token prompt left zero room to generate.
            max_new_tokens=200,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tokenizer.eos_token_id,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Parse labeled sections back out of the generated text; fall back to a
    # placeholder when a section is missing or empty.
    checklist = _extract_section(generated_text, "Checklist:", "Suggestions:") or "No checklist generated."
    suggestions = _extract_section(generated_text, "Suggestions:", "Quote:") or "No suggestions generated."
    quote = _extract_section(generated_text, "Quote:", None) or "No quote generated."

    return checklist, suggestions, quote
def create_interface():
    """Build and return the Gradio Blocks UI for the AI coach.

    Wires the Generate button to `generate_outputs` and the Clear button to
    reset both the input fields and the three output boxes (the original
    Clear left stale outputs on screen).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Construction Supervisor AI Coach")
        gr.Markdown("Enter details to generate a daily checklist, focus suggestions, and a motivational quote.")

        with gr.Row():
            role = gr.Dropdown(choices=["Supervisor", "Foreman", "Project Manager"], label="Role")
            project_id = gr.Textbox(label="Project ID")

        milestones = gr.Textbox(label="Milestones (comma-separated KPIs)")
        reflection = gr.Textbox(label="Reflection Log", lines=5)

        with gr.Row():
            submit = gr.Button("Generate")
            clear = gr.Button("Clear")

        checklist_output = gr.Textbox(label="Daily Checklist")
        suggestions_output = gr.Textbox(label="Focus Suggestions")
        quote_output = gr.Textbox(label="Motivational Quote")

        submit.click(
            fn=generate_outputs,
            inputs=[role, project_id, milestones, reflection],
            outputs=[checklist_output, suggestions_output, quote_output],
        )
        # Reset inputs AND outputs; one empty string per component.
        clear.click(
            fn=lambda: ("", "", "", "", "", "", ""),
            inputs=None,
            outputs=[role, project_id, milestones, reflection,
                     checklist_output, suggestions_output, quote_output],
        )
    return demo
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()