tonyhui2234 commited on
Commit
630851e
·
verified ·
1 Parent(s): 03ba6c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +266 -0
app.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import random
3
+ import pandas as pd
4
+ import requests
5
+ from io import BytesIO
6
+ from PIL import Image
7
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
8
+ import re
9
+ import time
10
+
11
+ # --------------------------- Configuration & Session State ---------------------------
12
+ # Define maximum dimensions for the fortune image (in pixels)
13
+ MAX_SIZE = (400, 400)
14
+
15
+ # Initialize button click count in session state
16
+ if "button_count_temp" not in st.session_state:
17
+ st.session_state.button_count_temp = 0
18
+
19
+ # Set page configuration and title
20
+ st.set_page_config(page_title="Fortune Stick Enquiry", layout="wide")
21
+ st.title("Fortune Stick Enquiry")
22
+
23
+ # Initialize session state variables for managing application state
24
+ if "submitted_text" not in st.session_state:
25
+ st.session_state.submitted_text = False
26
+ if "fortune_number" not in st.session_state:
27
+ st.session_state.fortune_number = None
28
+ if "fortune_row" not in st.session_state:
29
+ st.session_state.fortune_row = None
30
+ if "error_message" not in st.session_state:
31
+ st.session_state.error_message = ""
32
+ if "cfu_explain_text" not in st.session_state:
33
+ st.session_state.cfu_explain_text = ""
34
+ if "stick_clicked" not in st.session_state:
35
+ st.session_state.stick_clicked = False
36
+
37
+ # Load fortune details from CSV file into session state
38
+ if "fortune_data" not in st.session_state:
39
+ try:
40
+ st.session_state.fortune_data = pd.read_csv("/home/user/app/resources/detail.csv")
41
+ except Exception as e:
42
+ st.error(f"Error loading CSV: {e}")
43
+ st.session_state.fortune_data = None
44
+
45
+ # --------------------------- Model Functions ---------------------------
46
+ # Function to load a fine-tuned classifier model and predict a label based on the question
47
+ def load_finetuned_classifier_model(question):
48
+ label_list = ["Geomancy", "Lost Property", "Personal Well-Being", "Future Prospect", "Traveling"]
49
+ # Mapping to convert default "LABEL_x" outputs to meaningful labels
50
+ mapping = {f"LABEL_{i}": label for i, label in enumerate(label_list)}
51
+
52
+ pipe = pipeline("text-classification", model="tonyhui2234/CustomModel_classifier_model_10")
53
+ prediction = pipe(question)[0]['label']
54
+ predicted_label = mapping.get(prediction, prediction)
55
+ return predicted_label
56
+
57
+ # Function to generate a detailed answer by combining the user's question and the fortune detail
58
+ def generate_answer(question, fortune):
59
+ # Start measuring runtime
60
+ start_time = time.perf_counter()
61
+ tokenizer = AutoTokenizer.from_pretrained("tonyhui2234/finetuned_model_text_gen")
62
+ model = AutoModelForSeq2SeqLM.from_pretrained("tonyhui2234/finetuned_model_text_gen", device_map="auto")
63
+ input_text = "Question: " + question + " Fortune: " + fortune
64
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
65
+ outputs = model.generate(
66
+ **inputs,
67
+ max_length=256,
68
+ num_beams=4,
69
+ early_stopping=True,
70
+ repetition_penalty=2.0,
71
+ no_repeat_ngram_size=3
72
+ )
73
+ answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
74
+
75
+ # Stop measuring runtime
76
+ run_time = time.perf_counter() - start_time
77
+ print(f"Runtime: {run_time:.4f} seconds")
78
+ return answer
79
+
80
+ # Function that combines analysis with regex to extract the related fortune detail and then generate an answer
81
+ def analysis(row_detail, classifiy, question):
82
+ # Use the classifier's output to match the corresponding detail in the fortune data
83
+ pattern = re.compile(re.escape(classifiy) + r":\s*(.*?)(?:\.|$)", re.IGNORECASE)
84
+ match = pattern.search(row_detail)
85
+ if match:
86
+ result = match.group(1)
87
+ # Generate a custom answer based on the matched fortune detail and the user's question
88
+ return generate_answer(question, result)
89
+ else:
90
+ return "Heaven's secret cannot be revealed."
91
+
92
+ # Function to check if the input sentence is in English using a language detection model
93
+ def check_sentence_is_english_model(question):
94
+ pipe_english = pipeline("text-classification", model="eleldar/language-detection")
95
+ return pipe_english(question)[0]['label'] == 'en'
96
+
97
+ # Function to check if the input sentence is a question using a question vs. statement classifier
98
+ def check_sentence_is_question_model(question):
99
+ pipe_question = pipeline("text-classification", model="shahrukhx01/question-vs-statement-classifier")
100
+ return pipe_question(question)[0]['label'] == 'LABEL_1'
101
+
102
+ # --------------------------- Callback Functions ---------------------------
103
+ # Callback for when the submit button is clicked
104
+ def submit_text_callback():
105
+ question = st.session_state.get("user_sentence", "")
106
+ # Clear any previous error message
107
+ st.session_state.error_message = ""
108
+
109
+ # Validate that the input is in English and is a question
110
+ if not check_sentence_is_english_model(question):
111
+ st.session_state.error_message = "Please enter in English!"
112
+ st.session_state.button_count_temp = 0
113
+ return
114
+
115
+ if not check_sentence_is_question_model(question):
116
+ st.session_state.error_message = "This is not a question. Please enter again!"
117
+ st.session_state.button_count_temp = 0
118
+ return
119
+
120
+ # Require a second confirmation click to proceed
121
+ if st.session_state.button_count_temp == 0:
122
+ st.session_state.error_message = "Please take a moment to quietly reflect on your question in your mind, then click submit again!"
123
+ st.session_state.button_count_temp = 1
124
+ return
125
+
126
+ # If validations pass, set submission flag and reset click counter
127
+ st.session_state.submitted_text = True
128
+ st.session_state.button_count_temp = 0
129
+
130
+ # Randomly generate a fortune number between 1 and 100
131
+ st.session_state.fortune_number = random.randint(1, 100)
132
+
133
+ # Retrieve corresponding fortune details from the CSV based on the generated number
134
+ df = st.session_state.fortune_data
135
+ row_detail = ''
136
+ if df is not None:
137
+ matching_row = df[df['CNumber'] == st.session_state.fortune_number]
138
+ if not matching_row.empty:
139
+ row = matching_row.iloc[0]
140
+ row_detail = row.get("Detail", "No detail available.")
141
+ st.session_state.fortune_row = {
142
+ "Header": row.get("Header", "N/A"),
143
+ "Luck": row.get("Luck", "N/A"),
144
+ "Description": row.get("Description", "No description available."),
145
+ "Detail": row_detail,
146
+ "HeaderLink": row.get("link", None)
147
+ }
148
+ else:
149
+ st.session_state.fortune_row = {
150
+ "Header": "N/A",
151
+ "Luck": "N/A",
152
+ "Description": "No description available.",
153
+ "Detail": "No detail available.",
154
+ "HeaderLink": None
155
+ }
156
+
157
+ # Function to load and resize a local image file
158
+ def load_and_resize_image(path, max_size=MAX_SIZE):
159
+ try:
160
+ img = Image.open(path)
161
+ img.thumbnail(max_size, Image.Resampling.LANCZOS)
162
+ return img
163
+ except Exception as e:
164
+ st.error(f"Error loading image: {e}")
165
+ return None
166
+
167
+ # Function to download an image from a URL and resize it
168
+ def download_and_resize_image(url, max_size=MAX_SIZE):
169
+ try:
170
+ response = requests.get(url)
171
+ response.raise_for_status()
172
+ image_bytes = BytesIO(response.content)
173
+ img = Image.open(image_bytes)
174
+ img.thumbnail(max_size, Image.Resampling.LANCZOS)
175
+ return img
176
+ except Exception as e:
177
+ st.error(f"Error loading image from URL: {e}")
178
+ return None
179
+
180
+ # Callback for when the 'Cfu Explain' button is clicked
181
+ def stick_enquiry_callback():
182
+ # Retrieve the user's question and ensure fortune data is available
183
+ question = st.session_state.get("user_sentence", "")
184
+ if not st.session_state.fortune_row:
185
+ st.error("Fortune data is not available. Please submit your question first.")
186
+ return
187
+ row_detail = st.session_state.fortune_row.get("Detail", "No detail available.")
188
+
189
+ # Classify the question to determine which fortune detail to use
190
+ classifiy = load_finetuned_classifier_model(question)
191
+ # Generate an explanation based on the classification and fortune detail
192
+ cfu_explain = analysis(row_detail, classifiy, question)
193
+ # Save the generated explanation for display
194
+ st.session_state.cfu_explain_text = cfu_explain
195
+ st.session_state.stick_clicked = True
196
+
197
+ # --------------------------- Layout & Display ---------------------------
198
+ # Define the main layout with two columns: left for user input and right for fortune display
199
+ left_col, _, right_col = st.columns([3, 1, 5])
200
+
201
+ # ---- Left Column: User Input and Interaction ----
202
+ with left_col:
203
+ left_top = st.container()
204
+ left_bottom = st.container()
205
+
206
+ # Top container: Question input and submission button
207
+ with left_top:
208
+ st.text_area("Enter your question in English", key="user_sentence", height=150)
209
+ st.button("submit", key="submit_button", on_click=submit_text_callback)
210
+ if st.session_state.error_message:
211
+ st.error(st.session_state.error_message)
212
+
213
+ # Bottom container: Button to trigger explanation and display the generated answer
214
+ if st.session_state.submitted_text:
215
+ with left_bottom:
216
+ # Add spacing for better visual separation
217
+ for _ in range(5):
218
+ st.write("")
219
+ col1, col2, col3 = st.columns(3)
220
+ with col2:
221
+ st.button("Cfu Explain", key="stick_button", on_click=stick_enquiry_callback)
222
+ if st.session_state.stick_clicked:
223
+ # Display the generated explanation text
224
+ st.text_area(' ', value=st.session_state.cfu_explain_text, height=300, disabled=True)
225
+
226
+ # ---- Right Column: Fortune Display (Image and Details) ----
227
+ with right_col:
228
+ with st.container():
229
+ col_left, col_center, col_right = st.columns([1, 2, 1])
230
+ with col_center:
231
+ # Display fortune image based on fortune data availability
232
+ if st.session_state.submitted_text and st.session_state.fortune_row:
233
+ header_link = st.session_state.fortune_row.get("HeaderLink")
234
+ if header_link:
235
+ img_from_url = download_and_resize_image(header_link)
236
+ if img_from_url:
237
+ st.image(img_from_url, use_container_width=False)
238
+ else:
239
+ img = load_and_resize_image("/home/user/app/resources/error.png")
240
+ if img:
241
+ st.image(img, use_container_width=False)
242
+ else:
243
+ img = load_and_resize_image("/home/user/app/resources/error.png")
244
+ if img:
245
+ st.image(img, use_container_width=False)
246
+ else:
247
+ img = load_and_resize_image("/home/user/app/resources/fortune.png")
248
+ if img:
249
+ st.image(img, caption="Your Fortune", use_container_width=False)
250
+ with st.container():
251
+ # Display fortune details: Number, Luck, Description, and Detail
252
+ if st.session_state.fortune_row:
253
+ luck_text = st.session_state.fortune_row.get("Luck", "N/A")
254
+ description_text = st.session_state.fortune_row.get("Description", "No description available.")
255
+ detail_text = st.session_state.fortune_row.get("Detail", "No detail available.")
256
+
257
+ summary = f"""
258
+ <div style="font-size: 28px; font-weight: bold;">
259
+ Fortune stick number: {st.session_state.fortune_number}<br>
260
+ Luck: {luck_text}
261
+ </div>
262
+ """
263
+ st.markdown(summary, unsafe_allow_html=True)
264
+
265
+ st.text_area("Description", value=description_text, height=150, disabled=True)
266
+ st.text_area("Detail", value=detail_text, height=150, disabled=True)