tonyhui2234 commited on
Commit
cecbcaf
·
verified ·
1 Parent(s): 71b65a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -19
app.py CHANGED
@@ -7,6 +7,7 @@ from PIL import Image
7
  from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
8
  import re
9
 
 
10
  # Define maximum dimensions for the fortune image (in pixels)
11
  MAX_SIZE = (400, 400)
12
 
@@ -14,11 +15,11 @@ MAX_SIZE = (400, 400)
14
  if "button_count_temp" not in st.session_state:
15
  st.session_state.button_count_temp = 0
16
 
17
- # Set page configuration
18
  st.set_page_config(page_title="Fortuen Stick Enquiry", layout="wide")
19
  st.title("Fortuen Stick Enquiry")
20
 
21
- # Initialize session state variables
22
  if "submitted_text" not in st.session_state:
23
  st.session_state.submitted_text = False
24
  if "fortune_number" not in st.session_state:
@@ -32,6 +33,7 @@ if "cfu_explain_text" not in st.session_state:
32
  if "stick_clicked" not in st.session_state:
33
  st.session_state.stick_clicked = False
34
 
 
35
  if "fortune_data" not in st.session_state:
36
  try:
37
  st.session_state.fortune_data = pd.read_csv("/home/user/app/resources/detail.csv")
@@ -39,9 +41,11 @@ if "fortune_data" not in st.session_state:
39
  st.error(f"Error loading CSV: {e}")
40
  st.session_state.fortune_data = None
41
 
 
 
42
  def load_finetuned_classifier_model(question):
43
  label_list = ["Geomancy", "Lost Property", "Personal Well-Being", "Future Prospect", "Traveling"]
44
- # Create a mapping dictionary to convert the default "LABEL_x" output.
45
  mapping = {f"LABEL_{i}": label for i, label in enumerate(label_list)}
46
 
47
  pipe = pipeline("text-classification", model="tonyhui2234/CustomModel_classifier_model_10")
@@ -50,7 +54,7 @@ def load_finetuned_classifier_model(question):
50
  print(predicted_label)
51
  return predicted_label
52
 
53
- # Define your inference function
54
  def generate_answer(question, fortune):
55
  tokenizer = AutoTokenizer.from_pretrained("tonyhui2234/finetuned_model_text_gen")
56
  model = AutoModelForSeq2SeqLM.from_pretrained("tonyhui2234/finetuned_model_text_gen")
@@ -67,30 +71,36 @@ def generate_answer(question, fortune):
67
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
  return answer
69
 
 
70
  def analysis(row_detail, classifiy, question):
71
- # Use the classifier's output (e.g. "Personal Well-Being") in the regex.
72
  pattern = re.compile(re.escape(classifiy) + r":\s*(.*?)(?:\.|$)", re.IGNORECASE)
73
  match = pattern.search(row_detail)
74
  if match:
75
  result = match.group(1)
76
- # If you want to generate a custom answer, you can call generate_answer()
77
  return generate_answer(question, result)
78
  else:
79
  return "Heaven's secret cannot be revealed."
80
 
 
81
  def check_sentence_is_english_model(question):
82
- pipe_english = pipeline("text-classification", model="papluca/xlm-roberta-base-language-detection")
83
  return pipe_english(question)[0]['label'] == 'en'
84
 
 
85
  def check_sentence_is_question_model(question):
86
  pipe_question = pipeline("text-classification", model="shahrukhx01/question-vs-statement-classifier")
87
  return pipe_question(question)[0]['label'] == 'LABEL_1'
88
 
 
 
89
  def submit_text_callback():
90
  question = st.session_state.get("user_sentence", "")
91
  # Clear any previous error message
92
  st.session_state.error_message = ""
93
 
 
94
  if not check_sentence_is_english_model(question):
95
  st.session_state.error_message = "Please enter in English!"
96
  st.session_state.button_count_temp = 0
@@ -101,18 +111,20 @@ def submit_text_callback():
101
  st.session_state.button_count_temp = 0
102
  return
103
 
 
104
  if st.session_state.button_count_temp == 0:
105
  st.session_state.error_message = "Please take a moment to quietly reflect on your question in your mind, then click submit again!"
106
  st.session_state.button_count_temp = 1
107
  return
108
 
 
109
  st.session_state.submitted_text = True
110
- st.session_state.button_count_temp = 0 # Reset the counter once submission is accepted
111
 
112
- # Randomly generate a number from 1 to 100
113
  st.session_state.fortune_number = random.randint(1, 100)
114
 
115
- # Look up the row in the CSV where CNumber matches the generated fortune number.
116
  df = st.session_state.fortune_data
117
  row_detail = ''
118
  if df is not None:
@@ -137,6 +149,7 @@ def submit_text_callback():
137
  }
138
  print(row_detail)
139
 
 
140
  def load_and_resize_image(path, max_size=MAX_SIZE):
141
  try:
142
  img = Image.open(path)
@@ -146,6 +159,7 @@ def load_and_resize_image(path, max_size=MAX_SIZE):
146
  st.error(f"Error loading image: {e}")
147
  return None
148
 
 
149
  def download_and_resize_image(url, max_size=MAX_SIZE):
150
  try:
151
  response = requests.get(url)
@@ -158,49 +172,58 @@ def download_and_resize_image(url, max_size=MAX_SIZE):
158
  st.error(f"Error loading image from URL: {e}")
159
  return None
160
 
 
161
  def stick_enquiry_callback():
162
- # Retrieve the user's question and the fortune detail
163
  question = st.session_state.get("user_sentence", "")
164
  if not st.session_state.fortune_row:
165
  st.error("Fortune data is not available. Please submit your question first.")
166
  return
167
  row_detail = st.session_state.fortune_row.get("Detail", "No detail available.")
168
- # Run the classifier model after the image has loaded
 
169
  classifiy = load_finetuned_classifier_model(question)
170
- # Generate the explanation using the analysis function
171
  cfu_explain = analysis(row_detail, classifiy, question)
172
- # Save the returned value in session state for later display
173
  st.session_state.cfu_explain_text = cfu_explain
174
  st.session_state.stick_clicked = True
175
 
176
- # Main layout: Left (input) and Right (fortune display)
 
177
  left_col, _, right_col = st.columns([3, 1, 5])
178
 
179
- # ---- Left Column ----
180
  with left_col:
181
  left_top = st.container()
182
  left_bottom = st.container()
 
 
183
  with left_top:
184
  st.text_area("Enter your question in English", key="user_sentence", height=150)
185
  st.button("submit", key="submit_button", on_click=submit_text_callback)
186
  if st.session_state.error_message:
187
  st.error(st.session_state.error_message)
 
 
188
  if st.session_state.submitted_text:
189
  with left_bottom:
 
190
  for _ in range(5):
191
  st.write("")
192
  col1, col2, col3 = st.columns(3)
193
  with col2:
194
  st.button("Cfu Explain", key="stick_button", on_click=stick_enquiry_callback)
195
  if st.session_state.stick_clicked:
196
- # Display the explanation text saved from analysis()
197
  st.text_area(' ', value=st.session_state.cfu_explain_text, height=300, disabled=True)
198
 
199
- # ---- Right Column ----
200
  with right_col:
201
  with st.container():
202
  col_left, col_center, col_right = st.columns([1, 2, 1])
203
  with col_center:
 
204
  if st.session_state.submitted_text and st.session_state.fortune_row:
205
  header_link = st.session_state.fortune_row.get("HeaderLink")
206
  if header_link:
@@ -220,6 +243,7 @@ with right_col:
220
  if img:
221
  st.image(img, caption="Your Fortune", use_container_width=False)
222
  with st.container():
 
223
  if st.session_state.fortune_row:
224
  luck_text = st.session_state.fortune_row.get("Luck", "N/A")
225
  description_text = st.session_state.fortune_row.get("Description", "No description available.")
@@ -234,4 +258,4 @@ with right_col:
234
  st.markdown(summary, unsafe_allow_html=True)
235
 
236
  st.text_area("Description", value=description_text, height=150, disabled=True)
237
- st.text_area("Detail", value=detail_text, height=150, disabled=True)
 
7
  from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
8
  import re
9
 
10
+ # --------------------------- Configuration & Session State ---------------------------
11
  # Define maximum dimensions for the fortune image (in pixels)
12
  MAX_SIZE = (400, 400)
13
 
 
15
  if "button_count_temp" not in st.session_state:
16
  st.session_state.button_count_temp = 0
17
 
18
+ # Set page configuration and title
19
  st.set_page_config(page_title="Fortuen Stick Enquiry", layout="wide")
20
  st.title("Fortuen Stick Enquiry")
21
 
22
+ # Initialize session state variables for managing application state
23
  if "submitted_text" not in st.session_state:
24
  st.session_state.submitted_text = False
25
  if "fortune_number" not in st.session_state:
 
33
  if "stick_clicked" not in st.session_state:
34
  st.session_state.stick_clicked = False
35
 
36
+ # Load fortune details from CSV file into session state
37
  if "fortune_data" not in st.session_state:
38
  try:
39
  st.session_state.fortune_data = pd.read_csv("/home/user/app/resources/detail.csv")
 
41
  st.error(f"Error loading CSV: {e}")
42
  st.session_state.fortune_data = None
43
 
44
+ # --------------------------- Model Functions ---------------------------
45
+ # Function to load a fine-tuned classifier model and predict a label based on the question
46
  def load_finetuned_classifier_model(question):
47
  label_list = ["Geomancy", "Lost Property", "Personal Well-Being", "Future Prospect", "Traveling"]
48
+ # Mapping to convert default "LABEL_x" outputs to meaningful labels
49
  mapping = {f"LABEL_{i}": label for i, label in enumerate(label_list)}
50
 
51
  pipe = pipeline("text-classification", model="tonyhui2234/CustomModel_classifier_model_10")
 
54
  print(predicted_label)
55
  return predicted_label
56
 
57
+ # Function to generate a detailed answer by combining the user's question and the fortune detail
58
  def generate_answer(question, fortune):
59
  tokenizer = AutoTokenizer.from_pretrained("tonyhui2234/finetuned_model_text_gen")
60
  model = AutoModelForSeq2SeqLM.from_pretrained("tonyhui2234/finetuned_model_text_gen")
 
71
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
72
  return answer
73
 
74
+ # Function that combines analysis with regex to extract the related fortune detail and then generate an answer
75
  def analysis(row_detail, classifiy, question):
76
+ # Use the classifier's output to match the corresponding detail in the fortune data
77
  pattern = re.compile(re.escape(classifiy) + r":\s*(.*?)(?:\.|$)", re.IGNORECASE)
78
  match = pattern.search(row_detail)
79
  if match:
80
  result = match.group(1)
81
+ # Generate a custom answer based on the matched fortune detail and the user's question
82
  return generate_answer(question, result)
83
  else:
84
  return "Heaven's secret cannot be revealed."
85
 
86
+ # Function to check if the input sentence is in English using a language detection model
87
  def check_sentence_is_english_model(question):
88
+ pipe_english = pipeline("text-classification", model="eleldar/language-detection")
89
  return pipe_english(question)[0]['label'] == 'en'
90
 
91
+ # Function to check if the input sentence is a question using a question vs. statement classifier
92
  def check_sentence_is_question_model(question):
93
  pipe_question = pipeline("text-classification", model="shahrukhx01/question-vs-statement-classifier")
94
  return pipe_question(question)[0]['label'] == 'LABEL_1'
95
 
96
+ # --------------------------- Callback Functions ---------------------------
97
+ # Callback for when the submit button is clicked
98
  def submit_text_callback():
99
  question = st.session_state.get("user_sentence", "")
100
  # Clear any previous error message
101
  st.session_state.error_message = ""
102
 
103
+ # Validate that the input is in English and is a question
104
  if not check_sentence_is_english_model(question):
105
  st.session_state.error_message = "Please enter in English!"
106
  st.session_state.button_count_temp = 0
 
111
  st.session_state.button_count_temp = 0
112
  return
113
 
114
+ # Require a second confirmation click to proceed
115
  if st.session_state.button_count_temp == 0:
116
  st.session_state.error_message = "Please take a moment to quietly reflect on your question in your mind, then click submit again!"
117
  st.session_state.button_count_temp = 1
118
  return
119
 
120
+ # If validations pass, set submission flag and reset click counter
121
  st.session_state.submitted_text = True
122
+ st.session_state.button_count_temp = 0
123
 
124
+ # Randomly generate a fortune number between 1 and 100
125
  st.session_state.fortune_number = random.randint(1, 100)
126
 
127
+ # Retrieve corresponding fortune details from the CSV based on the generated number
128
  df = st.session_state.fortune_data
129
  row_detail = ''
130
  if df is not None:
 
149
  }
150
  print(row_detail)
151
 
152
+ # Function to load and resize a local image file
153
  def load_and_resize_image(path, max_size=MAX_SIZE):
154
  try:
155
  img = Image.open(path)
 
159
  st.error(f"Error loading image: {e}")
160
  return None
161
 
162
+ # Function to download an image from a URL and resize it
163
  def download_and_resize_image(url, max_size=MAX_SIZE):
164
  try:
165
  response = requests.get(url)
 
172
  st.error(f"Error loading image from URL: {e}")
173
  return None
174
 
175
+ # Callback for when the 'Cfu Explain' button is clicked
176
  def stick_enquiry_callback():
177
+ # Retrieve the user's question and ensure fortune data is available
178
  question = st.session_state.get("user_sentence", "")
179
  if not st.session_state.fortune_row:
180
  st.error("Fortune data is not available. Please submit your question first.")
181
  return
182
  row_detail = st.session_state.fortune_row.get("Detail", "No detail available.")
183
+
184
+ # Classify the question to determine which fortune detail to use
185
  classifiy = load_finetuned_classifier_model(question)
186
+ # Generate an explanation based on the classification and fortune detail
187
  cfu_explain = analysis(row_detail, classifiy, question)
188
+ # Save the generated explanation for display
189
  st.session_state.cfu_explain_text = cfu_explain
190
  st.session_state.stick_clicked = True
191
 
192
+ # --------------------------- Layout & Display ---------------------------
193
+ # Define the main layout with two columns: left for user input and right for fortune display
194
  left_col, _, right_col = st.columns([3, 1, 5])
195
 
196
+ # ---- Left Column: User Input and Interaction ----
197
  with left_col:
198
  left_top = st.container()
199
  left_bottom = st.container()
200
+
201
+ # Top container: Question input and submission button
202
  with left_top:
203
  st.text_area("Enter your question in English", key="user_sentence", height=150)
204
  st.button("submit", key="submit_button", on_click=submit_text_callback)
205
  if st.session_state.error_message:
206
  st.error(st.session_state.error_message)
207
+
208
+ # Bottom container: Button to trigger explanation and display the generated answer
209
  if st.session_state.submitted_text:
210
  with left_bottom:
211
+ # Add spacing for better visual separation
212
  for _ in range(5):
213
  st.write("")
214
  col1, col2, col3 = st.columns(3)
215
  with col2:
216
  st.button("Cfu Explain", key="stick_button", on_click=stick_enquiry_callback)
217
  if st.session_state.stick_clicked:
218
+ # Display the generated explanation text
219
  st.text_area(' ', value=st.session_state.cfu_explain_text, height=300, disabled=True)
220
 
221
+ # ---- Right Column: Fortune Display (Image and Details) ----
222
  with right_col:
223
  with st.container():
224
  col_left, col_center, col_right = st.columns([1, 2, 1])
225
  with col_center:
226
+ # Display fortune image based on fortune data availability
227
  if st.session_state.submitted_text and st.session_state.fortune_row:
228
  header_link = st.session_state.fortune_row.get("HeaderLink")
229
  if header_link:
 
243
  if img:
244
  st.image(img, caption="Your Fortune", use_container_width=False)
245
  with st.container():
246
+ # Display fortune details: Number, Luck, Description, and Detail
247
  if st.session_state.fortune_row:
248
  luck_text = st.session_state.fortune_row.get("Luck", "N/A")
249
  description_text = st.session_state.fortune_row.get("Description", "No description available.")
 
258
  st.markdown(summary, unsafe_allow_html=True)
259
 
260
  st.text_area("Description", value=description_text, height=150, disabled=True)
261
+ st.text_area("Detail", value=detail_text, height=150, disabled=True)