Update app.py
app.py CHANGED
@@ -2,24 +2,36 @@ import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from joblib import Memory
-import hashlib
 import datetime
 
 # Initialize cache
 cache_dir = "./cache"
 memory = Memory(cache_dir, verbose=0)
 
-# Load
-model_name = "distilgpt2"
+# Load pre-trained model and tokenizer (using distilgpt2 for testing)
+model_name = "distilgpt2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
 
-# Define prompt template
-PROMPT_TEMPLATE = """
+# Define a more explicit prompt template
+PROMPT_TEMPLATE = """You are an AI coach for construction supervisors. Based on the following inputs, generate a daily checklist, focus suggestions, and a motivational quote. Format your response with clear labels as follows:
+
+Checklist:
+- Item 1
+- Item 2
+Suggestions:
+- Suggestion 1
+- Suggestion 2
+Quote:
+- Your motivational quote here
+
+Inputs:
+Role: {role}
 Project: {project_id}
 Milestones: {milestones}
 Reflection: {reflection}
-"""
+
+Now, generate the checklist, suggestions, and quote."""
 
 # Cache reset check
 last_reset = datetime.date.today()
@@ -31,14 +43,14 @@ def reset_cache_if_new_day():
         memory.clear()
         last_reset = today
 
-# Cached generation function
+# Cached generation function with improved parsing
 @memory.cache
 def generate_outputs(role, project_id, milestones, reflection):
     reset_cache_if_new_day()
 
     # Validate inputs
     if not all([role, project_id, milestones, reflection]):
-        return "Error: All fields are required."
+        return "Error: All fields are required.", "", ""
 
     # Create prompt
     prompt = PROMPT_TEMPLATE.format(
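
One thing worth flagging in this hunk: generate_outputs is wrapped in @memory.cache, and joblib returns a stored result without executing the function body on a cache hit, so the reset_cache_if_new_day() call inside it only runs on cache misses. A minimal sketch of one way around that, reusing the module's own memory and reset_cache_if_new_day; the private helper name is illustrative, not from the commit:

@memory.cache
def _generate_outputs_cached(role, project_id, milestones, reflection):
    # Model call and parsing exactly as in the diff below.
    ...

def generate_outputs(role, project_id, milestones, reflection):
    reset_cache_if_new_day()  # now runs on every call, including cache hits
    return _generate_outputs_cached(role, project_id, milestones, reflection)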
@@ -52,25 +64,58 @@ def generate_outputs(role, project_id, milestones, reflection):
     inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
     outputs = model.generate(
         inputs["input_ids"],
-        max_length=
+        max_length=1500,  # Increased for longer outputs
         num_return_sequences=1,
         no_repeat_ngram_size=2,
         do_sample=True,
         top_p=0.9,
-        temperature=0.
+        temperature=0.8  # Slightly higher for more creativity
     )
 
-    # Decode
+    # Decode generated text
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    sections = generated_text.split("\n\n")
 
-
-
-
+    # Parse the output using labels
+    checklist = "No checklist generated."
+    suggestions = "No suggestions generated."
+    quote = "No quote generated."
+
+    # Look for sections using labels
+    if "Checklist:" in generated_text:
+        checklist_start = generated_text.find("Checklist:") + len("Checklist:")
+        suggestions_start = generated_text.find("Suggestions:")
+        if suggestions_start == -1:
+            suggestions_start = len(generated_text)
+        checklist = generated_text[checklist_start:suggestions_start].strip()
+        if not checklist:
+            checklist = "No checklist generated."
+
+    if "Suggestions:" in generated_text:
+        suggestions_start = generated_text.find("Suggestions:") + len("Suggestions:")
+        quote_start = generated_text.find("Quote:")
+        if quote_start == -1:
+            quote_start = len(generated_text)
+        suggestions = generated_text[suggestions_start:quote_start].strip()
+        if not suggestions:
+            suggestions = "No suggestions generated."
+
+    if "Quote:" in generated_text:
+        quote_start = generated_text.find("Quote:") + len("Quote:")
+        quote = generated_text[quote_start:].strip()
+        if not quote:
+            quote = "No quote generated."
+
+    # Fallback if sections are still empty (due to model not being fine-tuned)
+    if checklist == "No checklist generated.":
+        checklist = "- Review milestones.\n- Assign tasks to team."
+    if suggestions == "No suggestions generated.":
+        suggestions = "- Coordinate with suppliers.\n- Plan for contingencies."
+    if quote == "No quote generated.":
+        quote = "- Keep pushing forward!"
 
     return checklist, suggestions, quote
 
-# Gradio interface
+# Gradio interface (unchanged from original)
 def create_interface():
     with gr.Blocks() as demo:
         gr.Markdown("# Construction Supervisor AI Coach")
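
The three label-scanning blocks all follow the same find/slice/fallback shape, and since every "No X generated." placeholder is immediately overwritten by a hard-coded fallback, the placeholders never reach the UI. A more compact equivalent, assuming (as the prompt requests) that each label appears at most once and in Checklist/Suggestions/Quote order; parse_sections is a hypothetical helper, not part of the commit:

def parse_sections(generated_text):
    # str.partition returns empty strings for missing labels, which reduces
    # the fallback handling to a single "or" per section.
    _, _, rest = generated_text.partition("Checklist:")
    checklist, _, rest = rest.partition("Suggestions:")
    suggestions, _, quote = rest.partition("Quote:")
    return (
        checklist.strip() or "- Review milestones.\n- Assign tasks to team.",
        suggestions.strip() or "- Coordinate with suppliers.\n- Plan for contingencies.",
        quote.strip() or "- Keep pushing forward!",
    )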
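
A quick way to exercise the new path locally (the sample inputs are made up; this block is not part of the commit). Note that distilgpt2 has a 1024-token context window, so max_length=1500 can push generation past what its position embeddings support; passing max_new_tokens to model.generate is the usual way to bound output length independently of prompt length:

if __name__ == "__main__":
    # e.g. model.generate(..., max_new_tokens=256) bounds only the new tokens.
    checklist, suggestions, quote = generate_outputs(
        "Site Supervisor",
        "PRJ-042",
        "Foundation poured; framing starts Monday",
        "Crew coordination was slow yesterday",
    )
    print("Checklist:\n" + checklist)
    print("\nSuggestions:\n" + suggestions)
    print("\nQuote:\n" + quote)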