Zul001 committed on
Commit
92a04df
·
verified ·
1 Parent(s): 125c7b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -60
app.py CHANGED
@@ -1,17 +1,13 @@
1
- #importing libraries
2
  import gradio as gr
3
  import tensorflow.keras as keras
4
  import time
5
  import keras_nlp
6
  import os
7
 
8
-
9
  model_path = "Zul001/HydroSense_Gemma_Finetuned_Model"
10
  gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(f"hf://{model_path}")
11
 
12
-
13
  reset_triggered = False
14
-
15
  custom_css = """
16
  @import url('https://fonts.googleapis.com/css2?family=Edu+AU+VIC+WA+NT+Dots:[email protected]&family=Give+You+Glory&family=Sofia&family=Sunshiney&family=Vujahday+Script&display=swap');
17
  .gradio-container, .gradio-container * {
@@ -31,27 +27,53 @@ function refresh() {
31
  }
32
  """
33
 
34
-
35
  previous_sessions = []
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def post_process_output(prompt, result):
38
- # Remove the prompt if it's repeated at the beginning of the answer
39
  answer = result.strip()
40
  if answer.startswith(prompt):
41
  answer = answer[len(prompt):].strip()
42
 
43
- # Remove any leading colons or whitespace
44
  answer = answer.lstrip(':')
45
-
46
- # Ensure the answer starts with a capital letter
47
  answer = answer.capitalize()
48
 
49
- # Ensure the answer ends with a period if it doesn't already
50
  if not answer.endswith('.'):
51
  answer += '.'
52
 
53
  return f"{answer}"
54
-
55
 
56
  def add_session(prompt):
57
  global previous_sessions
@@ -62,60 +84,36 @@ def add_session(prompt):
62
 
63
  return "\n".join(previous_sessions) # Return only the session logs as a string
64
 
 
 
 
 
 
 
 
 
 
65
 
 
 
66
 
67
  def inference(prompt):
68
  global reset_triggered
69
  if reset_triggered:
70
- #do nothing
71
  return "", ""
72
 
73
- prompt_text = prompt
74
- generated_text = gemma_lm.generate(prompt_text)
75
-
76
- #Apply post-processing
77
- formatted_output = post_process_output(prompt_text, generated_text)
78
- print(formatted_output)
79
 
80
- #adding a bit of delay
81
  time.sleep(1)
 
82
  result = formatted_output
83
- sessions = add_session(prompt_text)
84
  return result, sessions
85
 
86
-
87
- # def inference(prompt):
88
-
89
- # time.sleep(1)
90
- # result = "Your Result"
91
- # # sessions = add_session(prompt)
92
- # return result
93
-
94
-
95
- # def remember(prompt, result):
96
- # global memory
97
- # # Store the session as a dictionary
98
- # session = {'prompt': prompt, 'result': result}
99
- # memory.append(session)
100
-
101
- # # Update previous_sessions for display
102
- # session_display = [f"Q: {s['prompt']} \nA: {s['result']}" for s in memory]
103
-
104
- # return "\n\n".join(session_display) # Return formatted sessions as a string
105
-
106
-
107
-
108
- def clear_sessions():
109
- global previous_sessions
110
- previous_sessions.clear()
111
- return "\n".join(previous_sessions)
112
-
113
- def clear_fields():
114
- global reset_triggered
115
- reset_triggered = True
116
- return "", "" # Return empty strings to clear the prompt and output fields
117
-
118
-
119
  with gr.Blocks(theme='gradio/soft', css=custom_css) as demo:
120
 
121
  gr.Markdown("<center><h1>HydroSense LLM Demo</h1></center>")
@@ -135,9 +133,6 @@ with gr.Blocks(theme='gradio/soft', css=custom_css) as demo:
135
  generate_btn = gr.Button("Generate Answer", variant="primary", size="sm")
136
  reset_btn = gr.Button("Clear Content", variant="secondary", size="sm", elem_id="primary")
137
 
138
-
139
-
140
-
141
  generate_btn.click(
142
  fn=inference,
143
  inputs=[prompt],
@@ -156,19 +151,16 @@ with gr.Blocks(theme='gradio/soft', css=custom_css) as demo:
156
  outputs=[prompt, output]
157
  )
158
 
159
-
160
- # Button to clear the prompt and output fields
161
  add_button.click(
162
  fn=clear_fields, # Only call the clear_fields function
163
  inputs=None, # No inputs needed
164
  outputs=[prompt, output] # Clear the prompt and output fields
165
  )
166
 
167
-
168
  clear_session.click(
169
  fn=clear_sessions,
170
  inputs=None,
171
  outputs=[session_list]
172
  )
173
 
174
- demo.launch(share=True)
 
 
1
  import gradio as gr
2
  import tensorflow.keras as keras
3
  import time
4
  import keras_nlp
5
  import os
6
 
 
7
  model_path = "Zul001/HydroSense_Gemma_Finetuned_Model"
8
  gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(f"hf://{model_path}")
9
 
 
10
  reset_triggered = False
 
11
  custom_css = """
12
  @import url('https://fonts.googleapis.com/css2?family=Edu+AU+VIC+WA+NT+Dots:[email protected]&family=Give+You+Glory&family=Sofia&family=Sunshiney&family=Vujahday+Script&display=swap');
13
  .gradio-container, .gradio-container * {
 
27
  }
28
  """
29
 
 
30
  previous_sessions = []
31
 
32
class ChatState():
    """Tracks a multi-turn Gemma chat session.

    Accumulates user and model turns and renders them in the Gemma
    chat-template format so each generation continues the conversation.
    """

    # Gemma chat-template control tokens (runtime strings — must match the
    # model's expected template exactly; do not alter).
    __START_TURN_USER__ = "<start_of_turn>user\n"
    __START_TURN_MODEL__ = "<start_of_turn>model\n"
    __END_TURN__ = "<end_of_turn>\n"

    def __init__(self, model, system=""):
        """model: a causal LM exposing .generate(prompt, max_length=...).
        system: optional system preamble prepended to every full prompt."""
        self.model = model
        self.system = system
        self.history = []  # list of already-templated turn strings, in order

    def add_to_history_as_user(self, message):
        # A user turn is explicitly closed with the end-of-turn token.
        self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

    def add_to_history_as_model(self, message):
        # The model's output is stored as-is; it may already carry its own
        # end-of-turn token from generation.
        self.history.append(self.__START_TURN_MODEL__ + message)

    def get_history(self):
        """Return the whole conversation as one templated string."""
        # was "".join([*self.history]) — the unpacking built a pointless copy
        return "".join(self.history)

    def get_full_prompt(self):
        """History plus an opening model tag, so generation answers as the model."""
        prompt = self.get_history() + self.__START_TURN_MODEL__
        if len(self.system) > 0:
            prompt = self.system + "\n" + prompt
        return prompt

    def send_message(self, message):
        """Record *message* as a user turn, generate a reply, record and return it."""
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        response = self.model.generate(prompt, max_length=2048)
        # Keep only the newly generated text. The original used
        # response.replace(prompt, ""), which would also delete any later
        # accidental occurrence of the prompt inside the reply; strip the
        # echoed prefix exactly once instead.
        if response.startswith(prompt):
            result = response[len(prompt):]
        else:
            result = response
        self.add_to_history_as_model(result)
        return result
64
+
65
def post_process_output(prompt, result):
    """Clean raw model output for display.

    Strips an echoed *prompt* prefix, removes leading colons and the
    whitespace after them, uppercases the first character, and guarantees
    a trailing period.
    """
    answer = result.strip()
    if answer.startswith(prompt):
        answer = answer[len(prompt):].strip()

    # Drop leading colons AND the whitespace that typically follows them.
    # The original's plain .lstrip(':') left "  text" with a leading space.
    answer = answer.lstrip(':').lstrip()

    # Uppercase only the first character. str.capitalize() lowercased the
    # rest of the string, mangling acronyms/proper nouns ("pH", "HydroSense").
    if answer:
        answer = answer[0].upper() + answer[1:]

    if not answer.endswith('.'):
        answer += '.'

    return answer
 
77
 
78
  def add_session(prompt):
79
  global previous_sessions
 
84
 
85
  return "\n".join(previous_sessions) # Return only the session logs as a string
86
 
87
def clear_sessions():
    """Empty the stored session log and return the (now blank) display string."""
    global previous_sessions
    # Clear in place so every reference to the shared list sees the reset.
    del previous_sessions[:]
    return "\n".join(previous_sessions)
91
+
92
def clear_fields():
    """Flag a pending reset and blank both textboxes.

    Sets the module-level reset_triggered so the next inference call is
    skipped; the two empty strings map onto the [prompt, output] components.
    """
    global reset_triggered
    reset_triggered = True
    return ("", "")
96
 
97
+ # Initialize the ChatState
98
+ chat_state = ChatState(gemma_lm)
99
 
100
def inference(prompt):
    """Generate an answer for *prompt* through the shared ChatState.

    Returns (answer_text, session_log) for the output and session-list
    components. Honors a pending reset by consuming the flag and
    returning blanks.
    """
    global reset_triggered
    if reset_triggered:
        # Consume the reset flag. The original left it True forever, so a
        # single "Clear Content" click permanently disabled inference.
        reset_triggered = False
        return "", ""

    # Use the model's reply itself, not the whole templated history — the
    # original post-processed chat_state.get_history(), so the displayed
    # answer contained raw <start_of_turn> chat-template markup.
    response = chat_state.send_message(prompt)
    formatted_output = post_process_output(prompt, response)

    # Small delay for a more natural response cadence.
    time.sleep(1)

    result = formatted_output
    # Log the user's question, not the raw templated transcript.
    sessions = add_session(prompt)
    return result, sessions
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  with gr.Blocks(theme='gradio/soft', css=custom_css) as demo:
118
 
119
  gr.Markdown("<center><h1>HydroSense LLM Demo</h1></center>")
 
133
  generate_btn = gr.Button("Generate Answer", variant="primary", size="sm")
134
  reset_btn = gr.Button("Clear Content", variant="secondary", size="sm", elem_id="primary")
135
 
 
 
 
136
  generate_btn.click(
137
  fn=inference,
138
  inputs=[prompt],
 
151
  outputs=[prompt, output]
152
  )
153
 
 
 
154
  add_button.click(
155
  fn=clear_fields, # Only call the clear_fields function
156
  inputs=None, # No inputs needed
157
  outputs=[prompt, output] # Clear the prompt and output fields
158
  )
159
 
 
160
  clear_session.click(
161
  fn=clear_sessions,
162
  inputs=None,
163
  outputs=[session_list]
164
  )
165
 
166
+ demo.launch(share=True)  # share=True opens a public tunnel URL — confirm intended outside local dev