Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import os | |
import json | |
import base64 | |
import random | |
from streamlit_pdf_viewer import pdf_viewer | |
from langchain.prompts import PromptTemplate | |
from datetime import datetime | |
from pathlib import Path | |
from openai import OpenAI | |
from dotenv import load_dotenv | |
import warnings | |
warnings.filterwarnings('ignore') | |
# Load environment variables (e.g. from a local .env file) before reading any
# of them, then point the OpenAI-compatible client at the Hugging Face
# Inference API.
load_dotenv()
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get("TOKEN2"),  # Hugging Face API token
)
# Create the working directories used by the app. exist_ok=True makes this
# idempotent and avoids the check-then-create race of an os.path.exists() test.
for dir_name in ("data", "feedback"):
    os.makedirs(dir_name, exist_ok=True)
# Global stylesheet injected on every script run: full-width green buttons
# plus helper classes (e.g. .sidebar-info for the sidebar model note).
_CUSTOM_CSS = """
<style>
.stButton > button {
width: 100%;
margin-bottom: 10px;
background-color: #4CAF50;
color: white;
border: none;
padding: 10px;
border-radius: 5px;
}
.task-button {
background-color: #2196F3 !important;
}
.stSelectbox {
margin-bottom: 20px;
}
.output-container {
padding: 20px;
border-radius: 5px;
border: 1px solid #ddd;
margin: 10px 0;
}
.status-container {
padding: 10px;
border-radius: 5px;
margin: 10px 0;
}
.sidebar-info {
padding: 10px;
background-color: #f0f2f6;
border-radius: 5px;
margin: 10px 0;
}
.feedback-button {
background-color: #ff9800 !important;
}
.feedback-container {
padding: 15px;
background-color: #f5f5f5;
border-radius: 5px;
margin: 15px 0;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Helper functions | |
def read_csv_with_encoding(file, encodings=None):
    """Read a CSV into a DataFrame, trying several text encodings in turn.

    Args:
        file: A path, or a binary file-like object (e.g. a Streamlit upload),
            accepted by ``pandas.read_csv``.
        encodings: Optional sequence of encoding names to try, in order.
            Defaults to utf-8, latin1, iso-8859-1 and cp1252.

    Returns:
        The DataFrame parsed with the first encoding that decodes the file.

    Raises:
        UnicodeDecodeError: if none of the encodings can decode the file.
        ValueError: if ``encodings`` is empty.
    """
    if encodings is None:
        # Note: latin1 maps every byte to a character, so with the default
        # list the encodings after it are effectively never reached.
        encodings = ['utf-8', 'latin1', 'iso-8859-1', 'cp1252']

    # A failed attempt leaves a stream partially consumed; remember where it
    # started so every attempt parses from the same position.
    seekable = callable(getattr(file, "seekable", None)) and file.seekable()
    start = file.tell() if seekable else None

    last_error = None
    for encoding in encodings:
        if start is not None:
            file.seek(start)
        try:
            return pd.read_csv(file, encoding=encoding)
        except UnicodeDecodeError as err:
            last_error = err

    if last_error is None:
        raise ValueError("No encodings were given to try.")
    # UnicodeDecodeError requires all five constructor arguments; building it
    # with only a message raised TypeError instead of the documented error.
    raise UnicodeDecodeError(
        last_error.encoding,
        last_error.object,
        last_error.start,
        last_error.end,
        "Failed to read file with any supported encoding",
    ) from last_error
#def save_feedback(feedback_data): | |
#feedback_file = 'feedback/user_feedback.csv' | |
#feedback_df = pd.DataFrame([feedback_data]) | |
#if os.path.exists(feedback_file): | |
#feedback_df.to_csv(feedback_file, mode='a', header=False, index=False) | |
#else: | |
#feedback_df.to_csv(feedback_file, index=False) | |
def reset_conversation():
    """Clear the chat history and send the user back to the task picker.

    Empties ``conversation`` and ``messages`` in ``st.session_state`` and drops
    any stored ``task_choice`` so the task-selection buttons render again.
    """
    for history_key in ("conversation", "messages"):
        st.session_state[history_key] = []
    st.session_state.pop("task_choice", None)
#new 24 March | |
#user_input = st.text_input("Enter your prompt:") | |
###########33 | |
# Seed per-session state on the first run only; later reruns keep the values
# already stored. Factories give each key a fresh empty object.
_SESSION_DEFAULTS = {
    "messages": list,
    "examples_to_classify": list,
    "system_role": str,
}
for state_key, make_default in _SESSION_DEFAULTS.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = make_default()

# Main app title
st.title("π€π¦ Text Data Labeling and Generation App")
# def embed_pdf_sidebar(pdf_path): | |
# with open(pdf_path, "rb") as f: | |
# base64_pdf = base64.b64encode(f.read()).decode('utf-8') | |
# pdf_display = f""" | |
# <iframe src="data:application/pdf;base64,{base64_pdf}" | |
# width="100%" height="400" type="application/pdf"></iframe> | |
# """ | |
# st.markdown(pdf_display, unsafe_allow_html=True) | |
# | |
# Sidebar settings | |
# Sidebar: instructions download, model/temperature settings, a reset button,
# the current-model note and a link to the external feedback form.
with st.sidebar:
    st.title("βοΈ Settings")

    st.markdown("### πData Generation and Labeling Instructions")
    # Opening a missing PDF used to raise FileNotFoundError and take down the
    # whole app on every run; show a notice instead.
    instructions_pdf = Path("User instructions.pdf")
    if instructions_pdf.is_file():
        st.download_button(
            label="π Download Instructions PDF",
            data=instructions_pdf.read_bytes(),
            file_name="User instructions.pdf",
            mime="application/pdf",
        )
    else:
        st.warning("Instructions PDF (User instructions.pdf) was not found.")

    selected_model = st.selectbox(
        "Select Model",
        [
            "meta-llama/Llama-3.3-70B-Instruct",
            "meta-llama/Llama-3.2-3B-Instruct",
            "meta-llama/Llama-4-Scout-17B-16E-Instruct",
            "meta-llama/Meta-Llama-3-8B-Instruct",
            "meta-llama/Llama-3.1-70B-Instruct",
        ],
        key='model_select',
    )
    temperature = st.slider(
        "Temperature",
        0.0, 1.0, 0.7,
        help="Controls randomness in generation",
    )
    st.button("π New Conversation", on_click=reset_conversation)

    # selected_model comes from the fixed list above, so interpolating it into
    # raw HTML is safe here.
    with st.container():
        st.markdown(f"""
<div class="sidebar-info">
<h4>Current Model: {selected_model}</h4>
<p><em>Note: Generated content may be inaccurate or false. Check important info.</em></p>
</div>
""", unsafe_allow_html=True)

    feedback_url = "https://docs.google.com/forms/d/e/1FAIpQLSdZ_5mwW-pjqXHgxR0xriyVeRhqdQKgb5c-foXlYAV55Rilsg/viewform?usp=header"
    st.sidebar.markdown(
        f'<a href="{feedback_url}" target="_blank"><button style="width: 100%;">Feedback Form</button></a>',
        unsafe_allow_html=True,
    )
# Replay the stored chat history so earlier turns stay visible across reruns.
for past_message in st.session_state.messages:
    role, text = past_message["role"], past_message["content"]
    with st.chat_message(role):
        st.markdown(text)
# Main content: until a task is chosen, offer the two task buttons side by
# side; clicking one records the choice in session state.
if 'task_choice' not in st.session_state:
    task_buttons = (
        ("π Data Generation", "gen_button", "Generate new data", "Data Generation"),
        ("π·οΈ Data Labeling", "label_button", "Label existing data", "Data Labeling"),
    )
    for column, (caption, widget_key, tooltip, task_name) in zip(st.columns(2), task_buttons):
        with column:
            if st.button(caption, key=widget_key, help=tooltip):
                st.session_state.task_choice = task_name
if "task_choice" in st.session_state: | |
if st.session_state.task_choice == "Data Generation": | |
st.header("π Data Generation")

# Domain: a preset, or "Custom" with a user-typed name that must not be blank.
domain_options = ["Restaurant reviews", "E-Commerce reviews", "News", "AG News", "Tourism", "Custom"]
domain_selection = st.selectbox("Domain", domain_options)
if domain_selection != "Custom":
    domain = domain_selection
    custom_domain_valid = True
else:
    domain = st.text_input("Specify custom domain")
    custom_domain_valid = bool(domain.strip())
    if not custom_domain_valid:
        st.error("Please specify a domain name.")

classification_type = st.selectbox(
    "Classification Type",
    ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"],
)
# Label state filled in by the classification-type branches that follow:
# the label names, whether they validated, and the validation messages.
labels = []
labels_valid = False
errors = []


def validate_binary_labels(labels):
    """Return validation messages for a pair of binary class names.

    Each name must be non-blank, and the two must differ once surrounding
    whitespace is stripped and case is ignored. An empty list means valid.
    """
    keys = [name.strip().lower() for name in labels]
    problems = []
    if not labels[0].strip():
        problems.append("First class name is required.")
    if not labels[1].strip():
        problems.append("Second class name is required.")
    # Only report a duplicate when every name is present.
    if all(keys) and keys[0] == keys[1]:
        problems.append("Class names must be different.")
    return problems
# Build the label set for the chosen classification type and validate it.
# Every branch leaves `labels` (names used in the prompt) and `errors`
# (messages shown to the user); `labels_valid` is derived for all types below.
if classification_type == "Sentiment Analysis":
    st.write("### Sentiment Analysis Labels (Fixed)")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.text_input("First class", "Positive", disabled=True)
    with col2:
        st.text_input("Second class", "Negative", disabled=True)
    with col3:
        st.text_input("Third class", "Neutral", disabled=True)
    labels = ["Positive", "Negative", "Neutral"]
    errors = []
elif classification_type == "Binary Classification":
    st.write("### Binary Classification Labels")
    col1, col2 = st.columns(2)
    with col1:
        label_1 = st.text_input("First class", "Positive")
    with col2:
        label_2 = st.text_input("Second class", "Negative")
    # Strip once so surrounding whitespace never reaches the prompt.
    labels = [label_1.strip(), label_2.strip()]
    errors = validate_binary_labels(labels)
    # One st.error per message: a "\n"-joined string renders as one line in markdown.
    for error in errors:
        st.error(error)
    if not errors:
        st.success("Binary class names are valid and unique!")
elif classification_type == "Multi-Class Classification":
    st.write("### Multi-Class Classification Labels")
    default_labels_by_domain = {
        "News": ["Political", "Sports", "Entertainment", "Technology", "Business"],
        "AG News": ["World", "Sports", "Business", "Sci/Tech"],
        "Tourism": ["Accommodation", "Transportation", "Tourist Attractions",
                    "Food & Dining", "Local Experience", "Adventure Activities",
                    "Wellness & Spa", "Eco-Friendly Practices", "Family-Friendly",
                    "Luxury Tourism"],
        "Restaurant reviews": ["Italian", "French", "American"],
        "E-Commerce reviews": ["Mobile Phones & Accessories", "Laptops & Computers", "Kitchen & Dining",
                               "Beauty & Personal Care", "Home & Furniture", "Clothing & Fashion",
                               "Shoes & Handbags", "Health & Wellness", "Electronics & Gadgets",
                               "Books & Stationery", "Toys & Games", "Sports & Fitness",
                               "Grocery & Gourmet Food", "Watches & Accessories", "Baby Products"],
    }
    num_classes = st.slider("Number of classes", 3, 15, 3)
    # Prefill from the domain's defaults; unknown/custom domains start blank.
    defaults = default_labels_by_domain.get(domain, [])
    labels = []
    errors = []
    cols = st.columns(3)
    for i in range(num_classes):
        with cols[i % 3]:
            default_value = defaults[i] if i < len(defaults) else ""
            label_input = st.text_input(f"Class {i+1}", default_value)
            normalized_label = label_input.strip().title()
            if not normalized_label:
                errors.append(f"Class {i+1} name is required.")
            else:
                labels.append(normalized_label)
    # Duplicate check is explicitly case-insensitive.
    if len(labels) != len({label.lower() for label in labels}):
        errors.append("Labels names must be unique (case-insensitive, normalized to Title Case).")
    for error in errors:
        st.error(error)
    if not errors:
        st.success("All Labels names are valid and unique!")

# Previously only the multi-class branch set this flag.
labels_valid = not errors
# Generation parameters: word-count bounds for each generated example.
col1, col2 = st.columns(2)
with col1:
    min_words = st.number_input("Min words", 1, 100, 20)
with col2:
    # The default must lie within [min_words, 100]. A fixed default of 50
    # makes Streamlit raise as soon as the user picks a minimum above 50.
    max_words = st.number_input("Max words", min_words, 100, max(min_words, 50))
# Optional few-shot examples. Each expander with both fields filled in
# contributes a {"content": ..., "label": ...} dict to few_shot_examples.
use_few_shot = st.toggle("Use few-shot examples")
few_shot_examples = []
if use_few_shot:
    few_shot_count = st.slider("Number of few-shot examples", 1, 10, 1)
    for idx in range(few_shot_count):
        with st.expander(f"Example {idx + 1}"):
            content = st.text_area("Content", key=f"few_shot_content_{idx}")
            label = st.selectbox("Label", labels, key=f"few_shot_label_{idx}")
        if content and label:
            few_shot_examples.append({"content": content, "label": label})

num_to_generate = st.number_input("Number of examples", 1, 200, 10)
# System role: a default persona built from the current settings; the user
# may edit it, and an emptied text area falls back to the default.
default_system_role = f"You are a professional {classification_type} expert, your role is to generate text examples for {domain} domain. Always generate unique diverse examples and do not repeat the generated data. The generated text should be between {min_words} to {max_words} words long."
system_role = st.text_area(
    "Modify System Role (optional)",
    value=default_system_role,
    key="system_role_input",
)
st.session_state['system_role'] = system_role or default_system_role

user_prompt = st.text_area("User Prompt (optional)")

# Prompt skeleton; each {placeholder} is filled by prompt_template.format below.
generation_template_text = "".join([
    "{system_role}\n",
    "- Use the following parameters:\n",
    "- Generate {num_examples} examples\n",
    "- Each example should be between {min_words} to {max_words} words long\n",
    "- Use these labels: {labels}.\n",
    "- Generate the examples in this format: 'Example text. Label: label'\n",
    "- Do not include word counts or any additional information\n",
    "- Always use your creativity and intelligence to generate unique and diverse text data\n",
    "- Write unique examples every time.\n",
    "- DO NOT REPEAT your gnerated text. \n",
    "- For each Output, describe it once and move to the next.\n",
    "- List each Output only once, and avoid repeating details.\n",
    "- Additional instructions: {user_prompt}\n\n",
    "- Use the following examples as a reference in the generation process\n\n {few_shot_examples}. \n",
    "- Think step by step, generate numbered examples, and check each newly generated example to ensure it has not been generated before. If it has, modify it",
])
prompt_template = PromptTemplate(
    input_variables=["system_role", "classification_type", "domain", "num_examples",
                     "min_words", "max_words", "labels", "user_prompt", "few_shot_examples"],
    template=generation_template_text,
)

# Joining an empty list already yields "", so no special case is needed.
few_shot_block = "\n".join(
    f"{ex['content']}\nLabel: {ex['label']}" for ex in few_shot_examples
)
system_prompt = prompt_template.format(
    system_role=st.session_state['system_role'],
    classification_type=classification_type,
    domain=domain,
    num_examples=num_to_generate,
    min_words=min_words,
    max_words=max_words,
    labels=", ".join(labels),
    user_prompt=user_prompt,
    few_shot_examples=few_shot_block,
)
st.session_state['system_prompt'] = system_prompt

# Read-only preview of exactly what will be sent to the model.
st.write("System Prompt:")
st.text_area(
    "Current System Prompt",
    value=st.session_state['system_prompt'],
    height=400,
    disabled=True,
)
if st.button("π― Generate Examples"):
    # Validation gate: the checks previously only warned and generation ran
    # anyway. `errors` holds the label-validation messages computed above.
    if domain_selection == "Custom" and not domain.strip():
        st.warning("Custom domain name is required.")
    elif len({lbl.strip().lower() for lbl in labels}) != len(labels):
        st.warning("Class names must be unique.")
    elif any(not lbl.strip() for lbl in labels):
        st.warning("All class labels must be filled in.")
    elif errors:
        st.warning("Please fix the label errors before generating examples.")
    else:
        with st.spinner("Generating examples..."):
            # Broad catch: this is the UI boundary, so report any failure to the user.
            try:
                stream = client.chat.completions.create(
                    model=selected_model,
                    messages=[{"role": "system", "content": st.session_state['system_prompt']}],
                    temperature=temperature,
                    stream=True,
                    # NOTE(review): 80000 may exceed the context/output limit of
                    # some of the listed models on the HF endpoint; confirm.
                    max_tokens=80000,
                    top_p=0.9,
                )
                st.session_state.messages.append({"role": "user", "content": system_prompt})
                response = st.write_stream(stream)
                st.session_state.messages.append({"role": "assistant", "content": response})
                # Always keep the latest response (previously only the first was stored).
                st.session_state.response = response
                for state_key, initial in (('generated_examples', []),
                                           ('generated_examples_csv', None),
                                           ('generated_examples_json', None)):
                    if state_key not in st.session_state:
                        st.session_state[state_key] = initial

                # Parse "Example text. Label: label" lines; the last "Label:"
                # on a line separates the text from its label.
                examples_list = []
                for line in response.split('\n'):
                    if not line.strip():
                        continue
                    parts = line.rsplit('Label:', 1)
                    if len(parts) != 2:
                        continue
                    text, label = parts[0].strip(), parts[1].strip()
                    if text and label:
                        examples_list.append({
                            'text': text,
                            'label': label,
                            'system_prompt': st.session_state.system_prompt,
                            'system_role': st.session_state.system_role,
                            'task_type': 'Data Generation',
                            'Use few-shot example?': 'Yes' if use_few_shot else 'No',
                        })

                if examples_list:
                    st.session_state.generated_examples = examples_list
                    df = pd.DataFrame(examples_list)
                    st.session_state.generated_examples_csv = df.to_csv(index=False).encode('utf-8')
                    st.session_state.generated_examples_json = json.dumps(examples_list, indent=2).encode('utf-8')
                    st.download_button(
                        "π₯ Download Generated Examples (CSV)",
                        st.session_state.generated_examples_csv,
                        "generated_examples.csv",
                        "text/csv",
                        key='download-csv-persistent',
                    )
                    st.markdown("""
<div style='text-align: left; margin:15px 0; font-weight: 600; color: #666;'>. . . . . . or</div>
""", unsafe_allow_html=True)
                    st.download_button(
                        "π₯ Download Generated Examples (JSON)",
                        st.session_state.generated_examples_json,
                        "generated_examples.json",
                        "application/json",
                        key='download-json-persistent',
                    )
                else:
                    st.warning("No labeled examples could be parsed from the model response.")
            except Exception as e:
                st.error("An error occurred during generation.")
                st.error(f"Details: {e}")
# Labeling process
elif st.session_state.task_choice == "Data Labeling": | |
st.header("π·οΈ Data Labeling")

# Domain: a preset, or "Custom" with a user-typed name that must not be blank.
labeling_domains = ["Restaurant reviews", "E-Commerce reviews", "News", "AG News", "Tourism", "Custom"]
domain_selection = st.selectbox("Domain", labeling_domains)
if domain_selection != "Custom":
    domain = domain_selection
    custom_domain_valid = True
else:
    domain = st.text_input("Specify custom domain")
    custom_domain_valid = bool(domain.strip())
    if not custom_domain_valid:
        st.error("Please specify a domain name.")

# Labeling additionally offers Named Entity Recognition.
classification_type = st.selectbox(
    "Classification Type",
    [
        "Sentiment Analysis",
        "Binary Classification",
        "Multi-Class Classification",
        "Named Entity Recognition (NER)",
    ],
)
# Build the label (or entity) set for the chosen task and validate it. Every
# branch leaves `labels` and `errors`; `labels_valid` is derived for all
# types after the chain.
labels = []
labels_valid = False
errors = []
if classification_type == "Sentiment Analysis":
    st.write("### Sentiment Analysis Labels (Fixed)")
    col1, col2, col3 = st.columns(3)
    with col1:
        label_1 = st.text_input("First class", "Positive", disabled=True)
    with col2:
        label_2 = st.text_input("Second class", "Negative", disabled=True)
    with col3:
        label_3 = st.text_input("Third class", "Neutral", disabled=True)
    labels = ["Positive", "Negative", "Neutral"]
elif classification_type == "Binary Classification":
    st.write("### Binary Classification Labels")
    col1, col2 = st.columns(2)
    with col1:
        label_1 = st.text_input("First class", "Positive")
    with col2:
        label_2 = st.text_input("Second class", "Negative")
    # Strip once; all checks below use the stripped names.
    label_1, label_2 = label_1.strip(), label_2.strip()
    labels = [label_1, label_2]
    if not label_1:
        errors.append("First class name is required.")
    if not label_2:
        errors.append("Second class name is required.")
    # Case-insensitive duplicate check, only when both names are present.
    if label_1 and label_2 and label_1.lower() == label_2.lower():
        errors.append("Class names must be different.")
    for error in errors:
        st.error(error)
    if not errors:
        st.success("Binary class names are valid and unique!")
elif classification_type == "Multi-Class Classification":
    st.write("### Multi-Class Classification Labels")
    default_labels_by_domain = {
        "News": ["Political", "Sports", "Entertainment", "Technology", "Business"],
        "AG News": ["World", "Sports", "Business", "Sci/Tech"],
        "Tourism": ["Accommodation", "Transportation", "Tourist Attractions",
                    "Food & Dining", "Local Experience", "Adventure Activities",
                    "Wellness & Spa", "Eco-Friendly Practices", "Family-Friendly",
                    "Luxury Tourism"],
        "Restaurant reviews": ["Italian", "French", "American"],
        "E-Commerce reviews": ["Mobile Phones & Accessories", "Laptops & Computers", "Kitchen & Dining",
                               "Beauty & Personal Care", "Home & Furniture", "Clothing & Fashion",
                               "Shoes & Handbags", "Health & Wellness", "Electronics & Gadgets",
                               "Books & Stationery", "Toys & Games", "Sports & Fitness",
                               "Grocery & Gourmet Food", "Watches & Accessories", "Baby Products"],
    }
    # Cap at 15 to match the Data Generation tab; a cap of 10 hid five of the
    # fifteen E-Commerce defaults.
    num_classes = st.slider("Select the number of classes (labels)", min_value=3, max_value=15, value=3)
    # Prefill from the domain's defaults; unknown/custom domains start blank.
    defaults = default_labels_by_domain.get(domain, [])
    cols = st.columns(3)
    for i in range(num_classes):
        with cols[i % 3]:
            default_value = defaults[i] if i < len(defaults) else ""
            label_input = st.text_input(f"Label {i + 1}", default_value)
            normalized_label = label_input.strip().title()
            if not normalized_label:
                errors.append(f"Label {i + 1} is required.")
            else:
                labels.append(normalized_label)
    if len(labels) != len({label.lower() for label in labels}):
        errors.append("Label names must be unique (case-insensitive).")
    for error in errors:
        st.error(error)
    if not errors:
        st.success("All label names are valid and unique!")
elif classification_type == "Named Entity Recognition (NER)":
    ner_entities = [
        "PERSON - Names of people, fictional characters, historical figures",
        "ORG - Companies, institutions, agencies, teams",
        "LOC - Physical locations (mountains, oceans, etc.)",
        "GPE - Countries, cities, states, political regions",
        "DATE - Calendar dates, years, centuries",
        "TIME - Times, durations",
        "MONEY - Monetary values with currency",
    ]
    # All entity types are selected by default.
    selected_entities = st.multiselect(
        "Select entities to recognize",
        ner_entities,
        default=list(ner_entities),
        key="ner_entity_selection",
    )
    # Keep only the entity tag (the text before " - ").
    labels = [entity.split(" - ")[0] for entity in selected_entities]
    if not labels:
        st.warning("Please select at least one entity type")
        labels = ["PERSON"]  # Fallback when nothing is selected

labels_valid = not errors
#NNew edit | |
# elif classification_type == "Multi-Class Classification": | |
# st.write("### Multi-Class Classification Labels") | |
# default_labels_by_domain = { | |
# "News": ["Political", "Sports", "Entertainment", "Technology", "Business"], | |
# "AG News": ["World", "Sports", "Business", "Sci/Tech"], | |
# "Tourism": ["Accommodation", "Transportation", "Tourist Attractions", | |
# "Food & Dining", "Local Experience", "Adventure Activities", | |
# "Wellness & Spa", "Eco-Friendly Practices", "Family-Friendly", | |
# "Luxury Tourism"], | |
# "Restaurant reviews": ["Italian", "French", "American"] | |
# } | |
# num_classes = st.slider("Number of classes", 3, 10, 3) | |
# # Get defaults for selected domain, or empty list | |
# defaults = default_labels_by_domain.get(domain, []) | |
# labels = [] | |
# errors = [] | |
# cols = st.columns(3) | |
# for i in range(num_classes): | |
# with cols[i % 3]: | |
# default_value = defaults[i] if i < len(defaults) else "" | |
# label_input = st.text_input(f"Class {i+1}", default_value) | |
# normalized_label = label_input.strip().title() | |
# if not normalized_label: | |
# errors.append(f"Class {i+1} name is required.") | |
# else: | |
# labels.append(normalized_label) | |
# # Check for duplicates (case-insensitive) | |
# if len(labels) != len(set(labels)): | |
# errors.append("Labels names must be unique (case-insensitive, normalized to Title Case).") | |
# # Show validation results | |
# if errors: | |
# for error in errors: | |
# st.error(error) | |
# else: | |
# st.success("All Labels names are valid and unique!") | |
# labels_valid = not errors # Will be True only if there are no label errors | |
# else: | |
# num_classes = st.slider("Number of classes", 3, 23, 3, key="label_num_classes") | |
# labels = [] | |
# cols = st.columns(3) | |
# for i in range(num_classes): | |
# with cols[i % 3]: | |
# label = st.text_input(f"Class {i+1}", f"Class_{i+1}", key=f"label_class_{i}") | |
# labels.append(label) | |
# Optional few-shot guidance, stored as ready-to-insert "text\nLabel: label"
# strings for the labeling prompt.
use_few_shot = st.toggle("Use few-shot examples for labeling")
few_shot_examples = []
if use_few_shot:
    num_few_shot = st.slider("Number of few-shot examples", 1, 10, 1)
    for shot in range(1, num_few_shot + 1):
        with st.expander(f"Few-shot Example {shot}"):
            content = st.text_area("Content", key=f"label_few_shot_content_{shot - 1}")
            label = st.selectbox("Label", labels, key=f"label_few_shot_label_{shot - 1}")
        if content and label:
            few_shot_examples.append(f"{content}\nLabel: {label}")
num_examples = st.number_input("Number of examples to classify", 1, 100, 1)
if num_examples <= 20:
    # Small batches: one text area per example; blank areas are skipped.
    entered = (st.text_area(f"Example {n + 1}", key=f"example_{n}") for n in range(num_examples))
    examples_to_classify = [ex for ex in entered if ex]
else:
    # Large batches: one example per line, blank lines dropped, extra lines
    # beyond the requested count ignored.
    examples_text = st.text_area(
        "Enter examples (one per line)",
        height=300,
        help="Enter each example on a new line",
    )
    non_blank = [ex.strip() for ex in examples_text.split('\n') if ex.strip()] if examples_text else []
    examples_to_classify = non_blank[:num_examples]
# System role for labeling: an editable default; an emptied text area falls
# back to the default.
default_system_role = f"You are a professional {classification_type} expert, your role is to classify the provided text examples for {domain} domain."
system_role = st.text_area(
    "Modify System Role (optional)",
    value=default_system_role,
    key="system_role_input",
)
st.session_state['system_role'] = system_role or default_system_role

user_prompt = st.text_area("User prompt (optional)", key="label_instructions")

# Prompt fragments: few-shot blocks separated by blank lines, and the
# examples to classify as a 1-based numbered list. Joining an empty list
# yields "", so no special case is needed.
few_shot_text = "\n\n".join(few_shot_examples)
examples_text = "\n".join(
    f"{position}. {example}" for position, example in enumerate(examples_to_classify, start=1)
)
# Choose the prompt skeleton for the selected task. NER asks the model to list
# entities per example; every other task asks for one label per example.
# The placeholders are filled later via label_prompt_template.format(...).
if classification_type == "Named Entity Recognition (NER)":
    _template_inputs = ["system_role", "labels", "few_shot_examples", "examples", "domain", "user_prompt"]
    _template_lines = (
        "{system_role}\n",
        "- For each text example provided, identify all entities of the requested types.\n",
        "- Use the following entities: {labels}.\n",
        "- Return each example followed by the entities you found in this format: 'Example text.\n Entities: [ENTITY_TYPE: entity text\n, ENTITY_TYPE: entity text\n, ...] or [No entities found]'\n",
        "- If no entities of the requested types are found, indicate 'No entities found' in this text.\n",
        "- Be precise about entity boundaries - don't include unnecessary words.\n",
        "- Do not provide any additional information or explanations.\n",
        "- Additional instructions:\n {user_prompt}\n\n",
        "- Use user few-shot examples as guidance if provided:\n{few_shot_examples}\n\n",
        "- Examples to analyze:\n{examples}\n\n",
        "Output:\n",
    )
else:
    _template_inputs = ["system_role", "classification_type", "labels", "few_shot_examples", "examples", "domain", "user_prompt"]
    _template_lines = (
        "{system_role}\n",
        "- Use the following instructions:\n",
        "- Use the following labels: {labels}.\n",
        "- Return the classified text followed by the label in this format: 'text. Label: [label]'\n",
        "- Do not provide any additional information or explanations\n",
        "- User prompt:\n {user_prompt}\n\n",
        "- Use user provided examples as guidence in the classification process:\n\n {few_shot_examples}\n",
        "- Examples to classify:\n{examples}\n\n",
        "- Think step by step then classify the examples",
    )
label_prompt_template = PromptTemplate(
    input_variables=_template_inputs,
    template="".join(_template_lines),
)
# Build the prompt the "Label Data" action sends and show it read-only so the
# user can review it before running. It is filled with the same fragments the
# labeling request uses (numbered `examples_text`, blank-line separated
# `few_shot_text`), so the preview and the stored copy match what the model
# actually receives.
_preview_kwargs = {
    "system_role": st.session_state['system_role'],
    "domain": domain,
    "labels": ", ".join(labels),
    "few_shot_examples": few_shot_text,
    "examples": examples_text,
    "user_prompt": user_prompt,
}
if classification_type != "Named Entity Recognition (NER)":
    # Only the non-NER template declares a classification_type input.
    _preview_kwargs["classification_type"] = classification_type
system_prompt = label_prompt_template.format(**_preview_kwargs)

# Keep the previewed prompt in session state for later steps.
st.session_state['system_prompt'] = system_prompt
st.write("System Prompt:")
st.text_area(
    "System Prompt",
    value=st.session_state['system_prompt'],
    height=300,
    disabled=True,
)
def _parse_labeled_lines(raw_response, marker, carry_text=False):
    """Extract (text, value) pairs from a line-oriented model response.

    A line of the form "<text> <marker> <value>" yields (text, value), using
    the last occurrence of *marker* on the line. When *carry_text* is true,
    a line that starts with *marker* (nothing before it) is paired with the
    most recent preceding non-empty line that did not contain *marker*. That
    is the layout the NER prompt requests ("Example text.\\n Entities: [...]").

    Args:
        raw_response: full text returned by the model.
        marker: separator such as "Label:" or "Entities:".
        carry_text: allow the text and the marker to sit on consecutive lines.

    Returns:
        List of (text, value) tuples; both parts stripped and non-empty.
    """
    pairs = []
    pending_text = ""
    for raw_line in raw_response.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        head, found, tail = line.rpartition(marker)
        if not found:
            if carry_text:
                # Possibly the text of an example whose marker is on the next line.
                pending_text = line
            continue
        text = head.strip() or pending_text
        pending_text = ""
        value = tail.strip()
        if text and value:
            pairs.append((text, value))
    return pairs


# On click: send the prompt, stream the reply, parse it into rows, then offer
# CSV/JSON downloads and a preview table.
if st.button("🏷️ Label Data"):
    if examples_to_classify:
        with st.spinner("Labeling data..."):
            is_ner = classification_type == "Named Entity Recognition (NER)"
            # Fill the template with the numbered examples and few-shot text
            # built earlier in the script.
            prompt_kwargs = {
                "system_role": st.session_state['system_role'],
                "domain": domain,
                "labels": ", ".join(labels),
                "few_shot_examples": few_shot_text,
                "examples": examples_text,
                "user_prompt": user_prompt,
            }
            if not is_ner:
                # Only the non-NER template declares classification_type.
                prompt_kwargs["classification_type"] = classification_type
            system_prompt = label_prompt_template.format(**prompt_kwargs)
            try:
                stream = client.chat.completions.create(
                    model=selected_model,
                    messages=[{"role": "system", "content": system_prompt}],
                    temperature=temperature,
                    stream=True,
                    max_tokens=20000,
                    top_p=0.9,
                )
                # Chat log; guard in case it was not initialised elsewhere.
                if "messages" not in st.session_state:
                    st.session_state.messages = []
                st.session_state.messages.append({"role": "user", "content": system_prompt})
                response = st.write_stream(stream)
                st.session_state.messages.append({"role": "assistant", "content": response})

                # Always record this run's prompt and reply. A "not in
                # session_state" guard would keep values from an earlier run,
                # and the saved rows below must carry the prompt actually sent.
                st.session_state.system_prompt = system_prompt
                st.session_state.response = response
                if 'generated_examples' not in st.session_state:
                    st.session_state.generated_examples = []
                if 'generated_examples_csv' not in st.session_state:
                    st.session_state.generated_examples_csv = None
                if 'generated_examples_json' not in st.session_state:
                    st.session_state.generated_examples_json = None

                if is_ner:
                    marker, value_key, task_type = "Entities:", "entities", "Named Entity Recognition (NER)"
                else:
                    marker, value_key, task_type = "Label:", "label", "Data Labeling"
                few_shot_flag = 'Yes' if use_few_shot else 'No'
                labeled_examples = [
                    {
                        'text': text,
                        value_key: value,
                        'system_prompt': st.session_state.system_prompt,
                        'system_role': st.session_state.system_role,
                        'task_type': task_type,
                        'Use few-shot example?': few_shot_flag,
                    }
                    for text, value in _parse_labeled_lines(response, marker, carry_text=is_ner)
                ]

                if labeled_examples:
                    st.session_state.labeled_examples = labeled_examples
                    df = pd.DataFrame(labeled_examples)
                    st.session_state.labeled_examples_csv = df.to_csv(index=False).encode('utf-8')
                    st.session_state.labeled_examples_json = json.dumps(labeled_examples, indent=2).encode('utf-8')
                    st.download_button(
                        "📥 Download Labeled Examples (CSV)",
                        st.session_state.labeled_examples_csv,
                        "labeled_examples.csv",
                        "text/csv",
                        key='download-labeled-csv',
                    )
                    st.markdown(
                        "<div style='text-align: left; margin:15px 0; font-weight: 600; color: #666;'>. . . . . . or</div>",
                        unsafe_allow_html=True,
                    )
                    st.download_button(
                        "📥 Download Labeled Examples (JSON)",
                        st.session_state.labeled_examples_json,
                        "labeled_examples.json",
                        "application/json",
                        key='download-labeled-json',
                    )
                    st.markdown("##### Labeled Examples Preview")
                    st.dataframe(df, use_container_width=True)
                else:
                    st.warning("No labeled examples could be parsed from the model response.")
            except Exception as e:
                st.error("An error occurred during labeling.")
                st.error(f"Details: {e}")
    else:
        st.warning("Please enter at least one example to classify.")
# Footer: a horizontal rule followed by a centred author credit rendered as raw HTML.
st.markdown("---")
st.markdown(
    """
<div style='text-align: center'>
<p>Made with ❤️ by Wedyan AlSakran 2025</p>
</div>
""",
    unsafe_allow_html=True,
)