Spaces:
Sleeping
Sleeping
import streamlit as st | |
import random | |
import pandas as pd | |
import requests | |
from io import BytesIO | |
from PIL import Image | |
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM | |
import re | |
import time | |
# --------------------------- Configuration & CSS ---------------------------
# Bounding box (width, height) used as the default max_size by the image helpers.
MAX_SIZE = (450, 450)

# Stylesheet injected into the page: a gradient background rule, a "card"
# container rule, and a rule forcing black text on inputs, text areas,
# markdown and labels.
_PAGE_STYLE = """
<style>
.reportview-container {
background: linear-gradient(135deg, #f6d365, #fda085);
}
.card {
background: rgba(255, 255, 255, 0.95);
padding: 30px;
border-radius: 12px;
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
max-width: 800px;
margin: auto;
text-align: center;
}
/* Force all text to be black */
body, input, textarea, .stMarkdown, label {
color: black !important;
}
</style>
"""

# The page config call stays ahead of every other Streamlit command in the script.
st.set_page_config(page_title="🔮 Divine Fortune Teller", page_icon=":crystal_ball:")
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)
# --------------------------- Session State Initialization ---------------------------
# Seed every session key this script reads. A key that already exists is left
# alone, so values survive Streamlit reruns.
_SESSION_DEFAULTS = {
    "submitted": False,
    "error_message": "",
    "question": "",
    "fortune_number": None,
    "fortune_row": None,
    "button_count_temp": 0,
    "cfu_explain_text": "",
    # Set by random_draw()/explain_callback() and read by the layout code;
    # seeded here so that read can never hit a missing key.
    "show_explain": False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --------------------------- Load Fortune CSV ---------------------------
# Read the fortune table only when the session does not hold one yet. On any
# failure the error is shown and None is stored instead of a DataFrame.
if "fortune_data" not in st.session_state:
    try:
        _fortune_table = pd.read_csv("/home/user/app/resources/detail.csv")
    except Exception as e:
        st.error(f"Error loading CSV: {e}")
        _fortune_table = None
    st.session_state.fortune_data = _fortune_table
# --------------------------- Helper Functions --------------------------- | |
def load_and_resize_image(path, max_size=MAX_SIZE):
    """
    Open a local image file and shrink it in place to fit inside ``max_size``.

    Args:
        path: Location of the image file on disk.
        max_size: (width, height) bounding box passed to ``Image.thumbnail``.

    Returns:
        The resized PIL image, or None if opening or resizing raised (the
        error is reported through ``st.error``).
    """
    try:
        picture = Image.open(path)
        picture.thumbnail(max_size, Image.Resampling.LANCZOS)
    except Exception as e:
        st.error(f"Error loading image: {e}")
        return None
    return picture
def download_and_resize_image(url, max_size=MAX_SIZE, timeout=10):
    """
    Download an image over HTTP and shrink it to fit inside ``max_size``.

    Args:
        url: Address of the image to fetch.
        max_size: (width, height) bounding box passed to ``Image.thumbnail``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, a stalled server would block the script run indefinitely.

    Returns:
        The resized PIL image, or None if the request, the HTTP status check
        or the decoding failed (the error is reported through ``st.error``).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        img.thumbnail(max_size, Image.Resampling.LANCZOS)
        return img
    except Exception as e:
        st.error(f"Error loading image from URL: {e}")
        return None
def display_text_field(label, text, height):
    """
    Render a heading followed by a fixed-height, vertically scrollable box.

    Args:
        label: Heading shown above the box.
        text: Content placed inside the box; inserted into the HTML as-is.
        height: Box height in pixels.
    """
    markup = f"""
<h6 style="display: block; margin-top: 10px;">{label}</h6>
<div style="border: 1px solid #ccc; border-radius: 4px; background-color: #f0f0f0;
padding: 10px; height: {height}px; overflow-y: auto; color: black; font-size: 16px;">
<div>{text}</div>
</div>
"""
    st.markdown(markup, unsafe_allow_html=True)
# --------------------------- Model Functions --------------------------- | |
@st.cache_resource
def _get_theme_classifier():
    """Build the theme-classification pipeline once and reuse it across reruns."""
    return pipeline("text-classification", model="tonyhui2234/CustomModel_classifier_model_10")


def load_finetuned_classifier_model(question):
    """
    Classify the user's question into one of the predefined fortune themes.

    The pipeline reports labels of the form ``LABEL_<i>``; these are mapped
    onto the theme names by position. A label outside that mapping is returned
    unchanged.

    Args:
        question: The question text to classify.

    Returns:
        The theme name for the top prediction, or the raw label if it is not
        in the mapping.
    """
    label_list = ["Geomancy", "Lost Property", "Personal Well-Being", "Future Prospect", "Traveling"]
    mapping = {f"LABEL_{i}": label for i, label in enumerate(label_list)}
    # Fetching the cached pipeline avoids reloading the model on every call.
    prediction = _get_theme_classifier()(question)[0]['label']
    return mapping.get(prediction, prediction)
@st.cache_resource
def _get_text_gen_components():
    """Load the seq2seq tokenizer and model once and reuse them across reruns."""
    tokenizer = AutoTokenizer.from_pretrained("tonyhui2234/finetuned_model_text_gen")
    model = AutoModelForSeq2SeqLM.from_pretrained("tonyhui2234/finetuned_model_text_gen", device_map="auto")
    return tokenizer, model


def generate_answer(question, fortune):
    """
    Generate an explanation from the question and the selected fortune text.

    The prompt has the form ``"Question: <question> Fortune: <fortune>"`` and
    is decoded with beam search. The elapsed time is printed to stdout.

    Args:
        question: The user's question.
        fortune: The fortune text extracted for the question's theme.

    Returns:
        The decoded model output with special tokens removed.
    """
    start_time = time.perf_counter()
    tokenizer, model = _get_text_gen_components()
    input_text = "Question: " + question + " Fortune: " + fortune
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    # device_map="auto" may place the model on an accelerator; the inputs have
    # to be on the same device or generate() fails.
    inputs = inputs.to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=256,
        num_beams=4,
        early_stopping=True,
        repetition_penalty=2.0,
        no_repeat_ngram_size=3,
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    run_time = time.perf_counter() - start_time
    print(f"Runtime: {run_time:.4f} seconds")
    return answer
def _extract_theme_text(row_detail, classifiy):
    """Return the text after "<theme>:" in the detail string, or None if there is none."""
    # A missing CSV cell can arrive as a non-string (e.g. NaN), which the
    # regex search cannot handle.
    if not isinstance(row_detail, str) or not isinstance(classifiy, str):
        return None
    pattern = re.compile(re.escape(classifiy) + r":\s*(.*?)(?:\.|$)", re.IGNORECASE)
    match = pattern.search(row_detail)
    if match is None:
        return None
    extracted = match.group(1)
    # An empty capture gives the generator nothing to work with.
    return extracted if extracted.strip() else None


def analysis(row_detail, classifiy, question):
    """
    Explain the part of the fortune detail that belongs to the given theme.

    The detail text is searched, case-insensitively, for ``"<theme>:"``. The
    text from there up to the next period (or the end of the string) is passed
    to ``generate_answer`` together with the question.

    Args:
        row_detail: Full detail text of the drawn fortune.
        classifiy: Theme name to look for in the detail text.
        question: The user's question.

    Returns:
        The generated explanation, or "Heaven's secret cannot be revealed."
        when the detail is not a string, the theme is not found, or the
        extracted text is empty.
    """
    theme_text = _extract_theme_text(row_detail, classifiy)
    if theme_text is None:
        return "Heaven's secret cannot be revealed."
    return generate_answer(question, theme_text)
@st.cache_resource
def _get_language_detector():
    """Build the language-detection pipeline once and reuse it across reruns."""
    return pipeline("text-classification", model="eleldar/language-detection")


def check_sentence_is_english_model(question):
    """
    Check whether the given text is detected as English.

    Args:
        question: The text to inspect.

    Returns:
        True when the top predicted label is ``'en'``, otherwise False.
    """
    # Fetching the cached pipeline avoids reloading the model on every call.
    return _get_language_detector()(question)[0]['label'] == 'en'
@st.cache_resource
def _get_question_detector():
    """Build the question-vs-statement pipeline once and reuse it across reruns."""
    return pipeline("text-classification", model="shahrukhx01/question-vs-statement-classifier")


def check_sentence_is_question_model(question):
    """
    Check whether the given text is classified as a question.

    Args:
        question: The text to inspect.

    Returns:
        True when the top predicted label is ``'LABEL_1'`` (the label this
        app treats as "question"), otherwise False.
    """
    # Fetching the cached pipeline avoids reloading the model on every call.
    return _get_question_detector()(question)[0]['label'] == 'LABEL_1'
# --------------------------- Callback Functions --------------------------- | |
def random_draw():
    """
    Draw a fortune number between 1 and 100 and store the matching CSV row.

    The first row whose ``CNumber`` equals the drawn number is copied into
    ``st.session_state.fortune_row``; placeholder values are stored when no
    row matches. If no fortune table is loaded, only an error message is
    recorded. In every case the session is marked as submitted and the
    explanation panel is hidden.
    """
    state = st.session_state
    state.fortune_number = random.randint(1, 100)
    table = state.fortune_data

    if table is None:
        state.error_message = "Fortune data is not available."
    else:
        hits = table[table['CNumber'] == state.fortune_number]
        if hits.empty:
            # No row carries the drawn number: fall back to placeholders.
            state.fortune_row = {
                "Header": "N/A",
                "Luck": "N/A",
                "Description": "No description available.",
                "Detail": "No detail available.",
                "HeaderLink": None,
            }
        else:
            first_hit = hits.iloc[0]
            state.fortune_row = {
                "Header": first_hit.get("Header", "N/A"),
                "Luck": first_hit.get("Luck", "N/A"),
                "Description": first_hit.get("Description", "No description available."),
                "Detail": first_hit.get("Detail", "No detail available."),
                "HeaderLink": first_hit.get("link", None),
            }

    state.submitted = True
    state.show_explain = False
def submit_callback():
    """
    Handle a click on "Submit" in the landing view.

    The text stored under the ``question_input`` key is stripped and checked in
    order: non-empty, English, phrased as a question. A failed check records a
    message in session state and returns. The first click that passes all
    checks only asks the user to reflect and click again; the next click
    stores the question and calls ``random_draw()``.
    """
    state = st.session_state
    asked = state.get("question_input", "").strip()
    if asked == "":
        state.error_message = "Please enter a valid question."
        state.submitted = False
        return

    # Model-backed checks, run in order; the first failure ends the callback.
    checks = (
        (check_sentence_is_english_model, "Please enter in English!"),
        (check_sentence_is_question_model, "This is not a question. Please enter again!"),
    )
    for passes, complaint in checks:
        if not passes(asked):
            state.error_message = complaint
            state.button_count_temp = 0
            return

    if state.button_count_temp == 0:
        # First valid click: ask for a moment of reflection before drawing.
        state.error_message = "Please take a moment to quietly reflect on your question in your mind, then click submit again!"
        state.button_count_temp = 1
        return

    state.error_message = ""
    state.question = asked
    state.button_count_temp = 0
    random_draw()
def resubmit_callback():
    """
    Handle a click on "Resubmit" in the result view.

    The text stored under the ``resubmit_input`` key goes through the same
    checks as the first submission (non-empty, English, phrased as a
    question) and the same two-click "reflect first" step. After that the
    stored question is updated if it changed and ``random_draw()`` is called.
    """
    state = st.session_state
    revised = state.get("resubmit_input", "").strip()
    if not revised:
        state.error_message = "Please enter a valid question."
        return

    # Model-backed checks, run in order; the first failure ends the callback.
    checks = (
        (check_sentence_is_english_model, "Please enter in English!"),
        (check_sentence_is_question_model, "This is not a question. Please enter again!"),
    )
    for passes, complaint in checks:
        if not passes(revised):
            state.error_message = complaint
            state.button_count_temp = 0
            return

    if state.button_count_temp == 0:
        # First valid click: ask for a moment of reflection before drawing.
        state.error_message = "Please take a moment to quietly reflect on your question in your mind, then click submit again!"
        state.button_count_temp = 1
        return

    state.error_message = ""
    if revised != state.question:
        state.question = revised
    # NOTE(review): a new fortune is drawn even when the question text is
    # unchanged -- confirm this is intended.
    state.button_count_temp = 0
    random_draw()
def explain_callback():
    """
    Handle a click on "CFU Explain".

    Classifies the stored question into a fortune theme, passes the drawn
    fortune's detail text, the theme and the question to ``analysis()``, then
    stores the result in session state and turns the explanation panel on.
    Shows an error and returns early when no fortune has been drawn yet.
    """
    state = st.session_state
    asked = state.question
    drawn = state.fortune_row
    if not drawn:
        st.error("Fortune data is not available. Please submit your question first.")
        return

    detail = drawn.get("Detail", "No detail available.")
    theme = load_finetuned_classifier_model(asked)
    print(f"classify Checking: {theme}\nQuestion: {asked}")
    state.cfu_explain_text = analysis(detail, theme, asked)
    state.show_explain = True
# --------------------------- Layout & Display ---------------------------
# The page has two views, selected by st.session_state.submitted:
#   * landing view: cover image, question box and Submit button;
#   * result view: editable question, fortune image and text, Explain button.
st.title("🔮 Divine Fortune Teller")
if not st.session_state.submitted:
    st.image("/home/user/app/resources/front.png", use_container_width=True)
    st.text_input("Ask your fortune question...", key="question_input")
    st.button("Submit", on_click=submit_callback)
    # Messages recorded by submit_callback are shown under the button.
    if st.session_state.error_message:
        st.error(st.session_state.error_message)
else:
    st.text_input("Your Question", value=st.session_state.question, key="resubmit_input")
    st.button("Resubmit", on_click=resubmit_callback)
    # Messages recorded by resubmit_callback / random_draw are shown here.
    if st.session_state.error_message:
        st.error(st.session_state.error_message)
    col1, col2 = st.columns([2, 3])
    with col1:
        # Left column: the fortune image from the row's link, or the bundled
        # fallback image when there is no link or the download failed.
        if st.session_state.fortune_row and st.session_state.fortune_row.get("HeaderLink"):
            img_from_url = download_and_resize_image(st.session_state.fortune_row.get("HeaderLink"))
            if img_from_url:
                # Empty heading placed above the image.
                st.markdown("<h6> </h6>", unsafe_allow_html=True)
                st.image(img_from_url, use_container_width=False)
            else:
                default_img = load_and_resize_image("/home/user/app/resources/error.png")
                if default_img:
                    st.image(default_img, caption="Default image", use_container_width=False)
        else:
            default_img = load_and_resize_image("/home/user/app/resources/error.png")
            if default_img:
                st.image(default_img, caption="Default image", use_container_width=False)
    with col2:
        # Right column: stick number, luck, description and detail of the draw.
        if st.session_state.fortune_row:
            luck_text = st.session_state.fortune_row.get("Luck", "N/A")
            summary = f"""
<div style="font-size: 24px; font-weight: bold;">
Fortune Stick Number: {st.session_state.fortune_number}<br>
Luck: {luck_text}
</div>
"""
            st.markdown(summary, unsafe_allow_html=True)
            description_text = st.session_state.fortune_row.get("Description", "No description available.")
            detail_text = st.session_state.fortune_row.get("Detail", "No detail available.")
            # Replace text_area with our custom text field
            display_text_field("Description:", description_text, 180)
            display_text_field("Detail:", detail_text, 180)
    # NOTE(review): the nesting of the Explain button and panel could not be
    # recovered with certainty; they are placed below both columns here.
    # Confirm against the deployed layout.
    st.button("CFU Explain", on_click=explain_callback)
    if st.session_state.show_explain:
        display_text_field("Explanation:", st.session_state.cfu_explain_text, 200)