neerajkalyank commited on
Commit
9e6531c
·
verified ·
1 Parent(s): e4ced37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -16
app.py CHANGED
@@ -2,24 +2,36 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from joblib import Memory
5
- import hashlib
6
  import datetime
7
 
8
  # Initialize cache
9
  cache_dir = "./cache"
10
  memory = Memory(cache_dir, verbose=0)
11
 
12
- # Load fine-tuned model and tokenizer (replace with your model path)
13
- model_name = "distilgpt2" # Placeholder; use your fine-tuned model from Hugging Face Hub
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
15
  model = AutoModelForCausalLM.from_pretrained(model_name)
16
 
17
- # Define prompt template
18
- PROMPT_TEMPLATE = """Role: {role}
 
 
 
 
 
 
 
 
 
 
 
 
19
  Project: {project_id}
20
  Milestones: {milestones}
21
  Reflection: {reflection}
22
- Generate a daily checklist, focus suggestions, and a motivational quote for a construction supervisor."""
 
23
 
24
  # Cache reset check
25
  last_reset = datetime.date.today()
@@ -31,14 +43,14 @@ def reset_cache_if_new_day():
31
  memory.clear()
32
  last_reset = today
33
 
34
- # Cached generation function
35
  @memory.cache
36
  def generate_outputs(role, project_id, milestones, reflection):
37
  reset_cache_if_new_day()
38
 
39
  # Validate inputs
40
  if not all([role, project_id, milestones, reflection]):
41
- return "Error: All fields are required."
42
 
43
  # Create prompt
44
  prompt = PROMPT_TEMPLATE.format(
@@ -52,25 +64,58 @@ def generate_outputs(role, project_id, milestones, reflection):
52
  inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
53
  outputs = model.generate(
54
  inputs["input_ids"],
55
- max_length=1000,
56
  num_return_sequences=1,
57
  no_repeat_ngram_size=2,
58
  do_sample=True,
59
  top_p=0.9,
60
- temperature=0.7
61
  )
62
 
63
- # Decode and parse output
64
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
65
- sections = generated_text.split("\n\n")
66
 
67
- checklist = sections[0] if len(sections) > 0 else "No checklist generated."
68
- suggestions = sections[1] if len(sections) > 1 else "No suggestions generated."
69
- quote = sections[2] if len(sections) > 2 else "No quote generated."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  return checklist, suggestions, quote
72
 
73
- # Gradio interface
74
  def create_interface():
75
  with gr.Blocks() as demo:
76
  gr.Markdown("# Construction Supervisor AI Coach")
 
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from joblib import Memory
 
5
  import datetime
6
 
7
  # Initialize cache
8
  cache_dir = "./cache"
9
  memory = Memory(cache_dir, verbose=0)
10
 
11
+ # Load pre-trained model and tokenizer (using distilgpt2 for testing)
12
+ model_name = "distilgpt2"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
  model = AutoModelForCausalLM.from_pretrained(model_name)
15
 
16
+ # Define a more explicit prompt template
17
+ PROMPT_TEMPLATE = """You are an AI coach for construction supervisors. Based on the following inputs, generate a daily checklist, focus suggestions, and a motivational quote. Format your response with clear labels as follows:
18
+
19
+ Checklist:
20
+ - Item 1
21
+ - Item 2
22
+ Suggestions:
23
+ - Suggestion 1
24
+ - Suggestion 2
25
+ Quote:
26
+ - Your motivational quote here
27
+
28
+ Inputs:
29
+ Role: {role}
30
  Project: {project_id}
31
  Milestones: {milestones}
32
  Reflection: {reflection}
33
+
34
+ Now, generate the checklist, suggestions, and quote."""
35
 
36
  # Cache reset check
37
  last_reset = datetime.date.today()
 
43
  memory.clear()
44
  last_reset = today
45
 
46
+ # Cached generation function with improved parsing
47
  @memory.cache
48
  def generate_outputs(role, project_id, milestones, reflection):
49
  reset_cache_if_new_day()
50
 
51
  # Validate inputs
52
  if not all([role, project_id, milestones, reflection]):
53
+ return "Error: All fields are required.", "", ""
54
 
55
  # Create prompt
56
  prompt = PROMPT_TEMPLATE.format(
 
64
  inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
65
  outputs = model.generate(
66
  inputs["input_ids"],
67
+ max_length=1500, # Increased for longer outputs
68
  num_return_sequences=1,
69
  no_repeat_ngram_size=2,
70
  do_sample=True,
71
  top_p=0.9,
72
+ temperature=0.8 # Slightly higher for more creativity
73
  )
74
 
75
+ # Decode generated text
76
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
77
 
78
+ # Parse the output using labels
79
+ checklist = "No checklist generated."
80
+ suggestions = "No suggestions generated."
81
+ quote = "No quote generated."
82
+
83
+ # Look for sections using labels
84
+ if "Checklist:" in generated_text:
85
+ checklist_start = generated_text.find("Checklist:") + len("Checklist:")
86
+ suggestions_start = generated_text.find("Suggestions:")
87
+ if suggestions_start == -1:
88
+ suggestions_start = len(generated_text)
89
+ checklist = generated_text[checklist_start:suggestions_start].strip()
90
+ if not checklist:
91
+ checklist = "No checklist generated."
92
+
93
+ if "Suggestions:" in generated_text:
94
+ suggestions_start = generated_text.find("Suggestions:") + len("Suggestions:")
95
+ quote_start = generated_text.find("Quote:")
96
+ if quote_start == -1:
97
+ quote_start = len(generated_text)
98
+ suggestions = generated_text[suggestions_start:quote_start].strip()
99
+ if not suggestions:
100
+ suggestions = "No suggestions generated."
101
+
102
+ if "Quote:" in generated_text:
103
+ quote_start = generated_text.find("Quote:") + len("Quote:")
104
+ quote = generated_text[quote_start:].strip()
105
+ if not quote:
106
+ quote = "No quote generated."
107
+
108
+ # Fallback if sections are still empty (due to model not being fine-tuned)
109
+ if checklist == "No checklist generated.":
110
+ checklist = "- Review milestones.\n- Assign tasks to team."
111
+ if suggestions == "No suggestions generated.":
112
+ suggestions = "- Coordinate with suppliers.\n- Plan for contingencies."
113
+ if quote == "No quote generated.":
114
+ quote = "- Keep pushing forward!"
115
 
116
  return checklist, suggestions, quote
117
 
118
+ # Gradio interface (unchanged from original)
119
  def create_interface():
120
  with gr.Blocks() as demo:
121
  gr.Markdown("# Construction Supervisor AI Coach")