geethareddy commited on
Commit
34607a1
·
verified ·
1 Parent(s): de5735a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -90
app.py CHANGED
@@ -1,14 +1,9 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from joblib import Memory
5
  import datetime
6
 
7
- # Initialize cache
8
- cache_dir = "./cache"
9
- memory = Memory(cache_dir, verbose=0)
10
-
11
- # Load pre-trained model and tokenizer (allow online download)
12
  model_name = "distilgpt2"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
  model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -17,22 +12,18 @@ model = AutoModelForCausalLM.from_pretrained(model_name)
17
  tokenizer.pad_token = tokenizer.eos_token
18
  model.config.pad_token_id = tokenizer.eos_token_id
19
 
20
- # Define a prompt template (structured format)
21
  PROMPT_TEMPLATE = """You are an AI coach for construction supervisors. Based on the following inputs, generate a daily checklist, focus suggestions, and a motivational quote. Format your response with clear labels as follows:
22
 
23
  Checklist:
24
- - Item 1
25
- - Item 2
26
 
27
  Suggestions:
28
- - Suggestion 1
29
- - Suggestion 2
30
 
31
  Quote:
32
  - Your motivational quote here
33
 
34
- Now, generate the checklist, suggestions, and quote for the following inputs:
35
-
36
  Inputs:
37
  Role: {role}
38
  Project: {project_id}
@@ -40,113 +31,75 @@ Milestones: {milestones}
40
  Reflection: {reflection}
41
  """
42
 
43
- # Cache reset check
44
- last_reset = datetime.date.today()
45
-
46
- def reset_cache_if_new_day():
47
- global last_reset
48
- today = datetime.date.today()
49
- if today > last_reset:
50
- memory.clear()
51
- last_reset = today
52
-
53
- # Cached generation function with improved parsing and context-aware fallbacks
54
- @memory.cache
55
  def generate_outputs(role, project_id, milestones, reflection):
56
- reset_cache_if_new_day()
57
-
58
- # Validate inputs
59
  if not all([role, project_id, milestones, reflection]):
60
  return "Error: All fields are required.", "", ""
61
 
62
- # Create prompt
 
 
 
 
 
 
 
 
 
 
 
63
  prompt = PROMPT_TEMPLATE.format(
64
  role=role,
65
  project_id=project_id,
66
  milestones=milestones,
67
- reflection=reflection
68
- )
69
-
70
- # Tokenize with attention_mask
71
- inputs = tokenizer(
72
- prompt,
73
- return_tensors="pt",
74
- max_length=512,
75
- truncation=True,
76
- padding=True,
77
- return_attention_mask=True
78
  )
 
 
 
79
 
80
- # Generate with attention_mask
81
- outputs = model.generate(
82
- inputs["input_ids"],
83
- attention_mask=inputs["attention_mask"],
84
- max_length=1500,
85
- num_return_sequences=1,
86
- no_repeat_ngram_size=2,
87
- do_sample=True,
88
- top_p=0.9,
89
- temperature=0.8,
90
- pad_token_id=tokenizer.eos_token_id
91
- )
92
 
93
- # Decode generated text
94
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
95
-
96
- # Parse the output using labels
97
  checklist = "No checklist generated."
98
  suggestions = "No suggestions generated."
99
  quote = "No quote generated."
100
-
101
- # Look for sections using labels
102
  if "Checklist:" in generated_text:
103
  checklist_start = generated_text.find("Checklist:") + len("Checklist:")
104
  suggestions_start = generated_text.find("Suggestions:")
105
- if suggestions_start == -1:
106
- suggestions_start = len(generated_text)
107
  checklist = generated_text[checklist_start:suggestions_start].strip()
108
- if not checklist:
109
- checklist = "No checklist generated."
110
 
111
  if "Suggestions:" in generated_text:
112
  suggestions_start = generated_text.find("Suggestions:") + len("Suggestions:")
113
  quote_start = generated_text.find("Quote:")
114
- if quote_start == -1:
115
- quote_start = len(generated_text)
116
  suggestions = generated_text[suggestions_start:quote_start].strip()
117
- if not suggestions:
118
- suggestions = "No suggestions generated."
119
 
120
  if "Quote:" in generated_text:
121
  quote_start = generated_text.find("Quote:") + len("Quote:")
122
  quote = generated_text[quote_start:].strip()
123
- if not quote:
124
- quote = "No quote generated."
125
-
126
- # Context-aware fallbacks based on inputs
127
- if checklist == "No checklist generated.":
128
- checklist_items = []
129
- milestone_list = [m.strip() for m in milestones.split(",")]
130
- for i, milestone in enumerate(milestone_list, 1):
131
- checklist_items.append(f"- {milestone} by {8 + i*2} AM.")
132
- checklist_items.append("- Check equipment status before end of day.")
133
- checklist = "\n".join(checklist_items)
134
-
135
- if suggestions == "No suggestions generated.":
136
- suggestions_items = []
137
- if "equipment issues" in reflection.lower():
138
- suggestions_items.append("- Schedule equipment maintenance to avoid future delays.")
139
- if "suppliers" in reflection.lower():
140
- suggestions_items.append("- Set up a morning call with suppliers to confirm timelines.")
141
- suggestions_items.append("- Brief the team on tomorrow’s goals during the daily huddle.")
142
- suggestions = "\n".join(suggestions_items if suggestions_items else ["- Coordinate with the team.", "- Plan for contingencies."])
143
-
144
- if quote == "No quote generated.":
145
- quote = "- Keep building—every step forward counts!"
146
-
147
  return checklist, suggestions, quote
148
 
149
- # Gradio interface
150
  def create_interface():
151
  with gr.Blocks() as demo:
152
  gr.Markdown("# Construction Supervisor AI Coach")
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
  import datetime
5
 
6
+ # Initialize model and tokenizer (preloading them for quicker response)
 
 
 
 
7
  model_name = "distilgpt2"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
 
12
  tokenizer.pad_token = tokenizer.eos_token
13
  model.config.pad_token_id = tokenizer.eos_token_id
14
 
15
+ # Define a more contextual prompt template
16
  PROMPT_TEMPLATE = """You are an AI coach for construction supervisors. Based on the following inputs, generate a daily checklist, focus suggestions, and a motivational quote. Format your response with clear labels as follows:
17
 
18
  Checklist:
19
+ - {milestones_list}
 
20
 
21
  Suggestions:
22
+ - {suggestions_list}
 
23
 
24
  Quote:
25
  - Your motivational quote here
26
 
 
 
27
  Inputs:
28
  Role: {role}
29
  Project: {project_id}
 
31
  Reflection: {reflection}
32
  """
33
 
34
+ # Function to generate outputs based on inputs
 
 
 
 
 
 
 
 
 
 
 
35
  def generate_outputs(role, project_id, milestones, reflection):
36
+ # Validate inputs to ensure no missing fields
 
 
37
  if not all([role, project_id, milestones, reflection]):
38
  return "Error: All fields are required.", "", ""
39
 
40
+ # Create prompt from template
41
+ milestones_list = "\n- ".join([m.strip() for m in milestones.split(",")])
42
+
43
+ suggestions_list = ""
44
+ if "delays" in reflection.lower():
45
+ suggestions_list = "- Consider adjusting timelines to accommodate delays.\n- Communicate delays to all relevant stakeholders."
46
+ elif "weather" in reflection.lower():
47
+ suggestions_list = "- Ensure team has rain gear.\n- Monitor weather updates for possible further delays."
48
+ elif "equipment" in reflection.lower():
49
+ suggestions_list = "- Inspect all equipment to ensure no malfunctions.\n- Schedule maintenance if necessary."
50
+
51
+ # Create final prompt
52
  prompt = PROMPT_TEMPLATE.format(
53
  role=role,
54
  project_id=project_id,
55
  milestones=milestones,
56
+ reflection=reflection,
57
+ milestones_list=milestones_list,
58
+ suggestions_list=suggestions_list
 
 
 
 
 
 
 
 
59
  )
60
+
61
+ # Tokenize inputs for model processing
62
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True, padding=True)
63
 
64
+ # Generate response from the model
65
+ with torch.no_grad():
66
+ outputs = model.generate(
67
+ inputs['input_ids'],
68
+ max_length=512,
69
+ num_return_sequences=1,
70
+ no_repeat_ngram_size=2,
71
+ do_sample=True,
72
+ top_p=0.9,
73
+ temperature=0.8,
74
+ pad_token_id=tokenizer.eos_token_id
75
+ )
76
 
77
+ # Decode the response
78
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
79
+
80
+ # Parse the output and ensure it is structured
81
  checklist = "No checklist generated."
82
  suggestions = "No suggestions generated."
83
  quote = "No quote generated."
84
+
 
85
  if "Checklist:" in generated_text:
86
  checklist_start = generated_text.find("Checklist:") + len("Checklist:")
87
  suggestions_start = generated_text.find("Suggestions:")
 
 
88
  checklist = generated_text[checklist_start:suggestions_start].strip()
 
 
89
 
90
  if "Suggestions:" in generated_text:
91
  suggestions_start = generated_text.find("Suggestions:") + len("Suggestions:")
92
  quote_start = generated_text.find("Quote:")
 
 
93
  suggestions = generated_text[suggestions_start:quote_start].strip()
 
 
94
 
95
  if "Quote:" in generated_text:
96
  quote_start = generated_text.find("Quote:") + len("Quote:")
97
  quote = generated_text[quote_start:].strip()
98
+
99
+ # Return structured outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  return checklist, suggestions, quote
101
 
102
+ # Gradio interface for fast user interaction
103
  def create_interface():
104
  with gr.Blocks() as demo:
105
  gr.Markdown("# Construction Supervisor AI Coach")