|
import streamlit as st |
|
import pandas as pd |
|
import os |
|
import json |
|
import base64 |
|
import random |
|
from streamlit_pdf_viewer import pdf_viewer |
|
from langchain.prompts import PromptTemplate |
|
from datetime import datetime |
|
from pathlib import Path |
|
from openai import OpenAI |
|
from dotenv import load_dotenv |
|
import warnings |
|
|
|
# Hide Python warnings (e.g. library deprecation notices) for the whole app.
warnings.filterwarnings('ignore')

# Pull secrets such as TOKEN2 from a local .env file into os.environ.
load_dotenv()

# OpenAI() raises a confusing "OPENAI_API_KEY" error when api_key is None, so
# check the token this app actually uses and stop with a clear message instead.
if not os.environ.get('TOKEN2'):
    st.error("Missing API token: set the TOKEN2 environment variable (or add it to .env).")
    st.stop()

# OpenAI-compatible client pointed at the Hugging Face inference endpoint.
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get('TOKEN2'),
)

# Make sure the app's working directories exist; exist_ok makes this idempotent
# and avoids the check-then-create race.
for dir_name in ['data', 'feedback']:
    os.makedirs(dir_name, exist_ok=True)
|
|
|
|
|
# Global style overrides for buttons, select boxes and the info/feedback panels,
# injected as a raw <style> element on every script run.
_APP_CSS = """
<style>
.stButton > button {
width: 100%;
margin-bottom: 10px;
background-color: #4CAF50;
color: white;
border: none;
padding: 10px;
border-radius: 5px;
}
.task-button {
background-color: #2196F3 !important;
}
.stSelectbox {
margin-bottom: 20px;
}
.output-container {
padding: 20px;
border-radius: 5px;
border: 1px solid #ddd;
margin: 10px 0;
}
.status-container {
padding: 10px;
border-radius: 5px;
margin: 10px 0;
}
.sidebar-info {
padding: 10px;
background-color: #f0f2f6;
border-radius: 5px;
margin: 10px 0;
}
.feedback-button {
background-color: #ff9800 !important;
}
.feedback-container {
padding: 15px;
background-color: #f5f5f5;
border-radius: 5px;
margin: 15px 0;
}
</style>
"""

st.markdown(_APP_CSS, unsafe_allow_html=True)
|
|
|
|
|
def read_csv_with_encoding(file, encodings=('utf-8', 'latin1', 'iso-8859-1', 'cp1252')):
    """Read a CSV into a DataFrame, trying several text encodings in order.

    Args:
        file: A path, or a readable file-like object such as an uploaded file.
            Seekable objects are rewound to their starting position before each
            attempt, because a failed attempt leaves the read position advanced
            and the next encoding would otherwise see an empty/truncated stream.
        encodings: Codec names to try, in order; the first one that decodes the
            whole file wins. Note that 'latin1' can decode any byte sequence, so
            with the default list the entries after it are effectively fallbacks
            that are never reached.

    Returns:
        pandas.DataFrame: The parsed CSV.

    Raises:
        UnicodeDecodeError: If none of ``encodings`` can decode the file; the
            last underlying decode error is chained as the cause.
        ValueError: If ``encodings`` is empty.
    """
    can_rewind = hasattr(file, 'seek') and hasattr(file, 'tell')
    start = file.tell() if can_rewind else None

    last_error = None
    for encoding in encodings:
        if can_rewind:
            file.seek(start)
        try:
            return pd.read_csv(file, encoding=encoding)
        except UnicodeDecodeError as err:
            last_error = err

    if last_error is None:
        raise ValueError("No encodings were supplied")
    # UnicodeDecodeError needs all five constructor arguments; reuse the details
    # of the last failure so callers catching UnicodeDecodeError still work.
    raise UnicodeDecodeError(
        last_error.encoding,
        last_error.object,
        last_error.start,
        last_error.end,
        "Failed to read file with any supported encoding",
    ) from last_error
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_conversation():
    """Start over: clear the chat history and forget the chosen task.

    Wired to the sidebar "New Conversation" button via ``on_click``.
    Removing ``task_choice`` (rather than blanking it) makes the task picker,
    which only renders while that key is absent, show up again.
    """
    state = st.session_state
    state.conversation = []
    state.messages = []
    state.pop('task_choice', None)
|
|
|
|
|
|
|
|
|
|
|
# Seed the per-session keys the rest of the script reads unconditionally.
# The tuple (and its fresh list defaults) is rebuilt on every run, so no two
# sessions share a mutable default.
for _state_key, _state_default in (
    ("messages", []),
    ("examples_to_classify", []),
    ("system_role", ""),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# The emoji in the original title were mojibake (UTF-8 bytes read as ISO-8859-7);
# restored here as the most likely originals.
st.title("🤖🦙 Text Data Labeling and Generation App")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_INSTRUCTIONS_PDF = Path("User instructions.pdf")

_MODEL_OPTIONS = [
    "meta-llama/Llama-3.3-70B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Llama-3.1-70B-Instruct",
]

_FEEDBACK_URL = "https://docs.google.com/forms/d/e/1FAIpQLSdZ_5mwW-pjqXHgxR0xriyVeRhqdQKgb5c-foXlYAV55Rilsg/viewform?usp=header"

# Sidebar: settings shared by both workflows. `selected_model` and `temperature`
# are module-level names read later when the model is called.
with st.sidebar:
    st.title("⚙️ Settings")

    st.markdown("### 📖 Data Generation and Labeling Instructions")
    # A missing PDF used to raise FileNotFoundError and take down the whole app.
    if _INSTRUCTIONS_PDF.is_file():
        st.download_button(
            label="📄 Download Instructions PDF",
            data=_INSTRUCTIONS_PDF.read_bytes(),
            file_name="User instructions.pdf",
            mime="application/pdf",
        )
    else:
        st.caption("Instructions PDF is not available.")

    selected_model = st.selectbox(
        "Select Model",
        _MODEL_OPTIONS,
        key='model_select',
    )

    temperature = st.slider(
        "Temperature",
        0.0, 1.0, 0.7,
        help="Controls randomness in generation",
    )

    st.button("🔄 New Conversation", on_click=reset_conversation)

    with st.container():
        st.markdown(
            '<div class="sidebar-info">'
            f"<h4>Current Model: {selected_model}</h4>"
            "<p><em>Note: Generated content may be inaccurate or false. Check important info.</em></p>"
            "</div>",
            unsafe_allow_html=True,
        )

    st.markdown(
        f'<a href="{_FEEDBACK_URL}" target="_blank"><button style="width: 100%;">Feedback Form</button></a>',
        unsafe_allow_html=True,
    )
|
|
|
|
|
# Replay the stored chat history (prompts sent and model replies).
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])


def _choose_task(task_name):
    """Button callback: remember which workflow the user picked."""
    st.session_state.task_choice = task_name


# Task picker, shown only until a workflow is chosen. on_click callbacks store
# the choice before the script reruns, so the picker disappears immediately
# instead of lingering for one extra run.
if 'task_choice' not in st.session_state:
    col1, col2 = st.columns(2)
    with col1:
        st.button(
            "📝 Data Generation",
            key="gen_button",
            help="Generate new data",
            on_click=_choose_task,
            args=("Data Generation",),
        )
    with col2:
        st.button(
            "🏷️ Data Labeling",
            key="label_button",
            help="Label existing data",
            on_click=_choose_task,
            args=("Data Labeling",),
        )
|
|
|
# ---------------------------------------------------------------------------
# Task workflows. Each workflow renders its form, builds ONE prompt that is both
# shown to the user and sent to the model, calls the model when its button is
# pressed, and keeps the parsed results in session state so the download
# buttons are still there on later reruns.
# ---------------------------------------------------------------------------

_NER_TASK = "Named Entity Recognition (NER)"

_DOMAIN_OPTIONS = [
    "Restaurant reviews", "E-Commerce reviews", "News", "AG News", "Tourism", "Custom",
]

_SENTIMENT_LABELS = ("Positive", "Negative", "Neutral")

# Suggested multi-class labels for the built-in domains (custom domains start blank).
_DEFAULT_LABELS_BY_DOMAIN = {
    "News": ["Political", "Sports", "Entertainment", "Technology", "Business"],
    "AG News": ["World", "Sports", "Business", "Sci/Tech"],
    "Tourism": [
        "Accommodation", "Transportation", "Tourist Attractions",
        "Food & Dining", "Local Experience", "Adventure Activities",
        "Wellness & Spa", "Eco-Friendly Practices", "Family-Friendly",
        "Luxury Tourism",
    ],
    "Restaurant reviews": ["Italian", "French", "American"],
    "E-Commerce reviews": [
        "Mobile Phones & Accessories", "Laptops & Computers", "Kitchen & Dining",
        "Beauty & Personal Care", "Home & Furniture", "Clothing & Fashion",
        "Shoes & Handbags", "Health & Wellness", "Electronics & Gadgets",
        "Books & Stationery", "Toys & Games", "Sports & Fitness",
        "Grocery & Gourmet Food", "Watches & Accessories", "Baby Products",
    ],
}

# Built-in entity types offered for NER, as "TYPE - description".
_NER_ENTITIES = [
    "PERSON - Names of people, fictional characters, historical figures",
    "ORG - Companies, institutions, agencies, teams",
    "LOC - Physical locations (mountains, oceans, etc.)",
    "GPE - Countries, cities, states, political regions",
    "DATE - Calendar dates, years, centuries",
    "TIME - Times, durations",
    "MONEY - Monetary values with currency",
]

# Built-in NER demonstrations; always included in NER labeling prompts.
_NER_FEW_SHOT = [
    {"content": "Mount Everest is the tallest mountain in the world.", "label": "LOC: Mount Everest"},
    {"content": "The President of the United States visited Paris last summer.", "label": "GPE: United States, GPE: Paris"},
    {"content": "Amazon is expanding its offices in Berlin.", "label": "ORG: Amazon, GPE: Berlin"},
    {"content": "J.K. Rowling wrote the Harry Potter books.", "label": "PERSON: J.K. Rowling"},
    {"content": "Apple was founded in California in 1976.", "label": "ORG: Apple, GPE: California, DATE: 1976"},
    {"content": "The Nile is the longest river in Africa.", "label": "LOC: Nile, GPE: Africa"},
    {"content": "He arrived at 3 PM for the meeting.", "label": "TIME: 3 PM"},
    {"content": "She bought the dress for $200.", "label": "MONEY: $200"},
    {"content": "The event is scheduled for July 4th.", "label": "DATE: July 4th"},
    {"content": "The World Health Organization is headquartered in Geneva.", "label": "ORG: World Health Organization, GPE: Geneva"},
]

_OR_SEPARATOR_HTML = (
    "<div style='text-align: left; margin:15px 0; font-weight: 600; color: #666;'>"
    ". . . . . . or</div>"
)

_GENERATION_TEMPLATE = PromptTemplate(
    input_variables=[
        "system_role", "classification_type", "domain", "num_examples",
        "min_words", "max_words", "labels", "user_prompt",
        "few_shot_examples", "additional_attributes",
    ],
    template=(
        "{system_role}\n"
        "- Use the following parameters:\n"
        "- Generate {num_examples} examples\n"
        "- Each example should be between {min_words} to {max_words} words long\n"
        "- Use these labels: {labels}.\n"
        "- Use the following additional attributes:\n"
        "{additional_attributes}\n"
        "- Generate the examples in this format: 'Example text. Label: label'\n"
        "- Keep each example and its label on a single line\n"
        "- Do not include word counts or any additional information\n"
        "- Always use your creativity and intelligence to generate unique and diverse text data\n"
        "- In sentiment analysis, ensure that the sentiment classification is clearly identified as Positive, Negative, or Neutral. Do not leave the sentiment ambiguous.\n"
        "- In binary classification, use only the two labels listed above. Do not include or imply Neutral as an option.\n"
        "- Write unique examples every time.\n"
        "- DO NOT REPEAT your generated text. \n"
        "- For each Output, describe it once and move to the next.\n"
        "- List each Output only once, and avoid repeating details.\n"
        "- Additional instructions: {user_prompt}\n\n"
        "- Use the following examples as a reference in the generation process\n\n {few_shot_examples}. \n"
        "- Think step by step, generate numbered examples, and check each newly generated example to ensure it has not been generated before. If it has, modify it"
    ),
)

_NER_LABEL_TEMPLATE = PromptTemplate(
    input_variables=["system_role", "labels", "few_shot_examples", "examples", "domain", "user_prompt"],
    template=(
        "{system_role}\n"
        "- You are an expert at Named Entity Recognition (NER) for domain: {domain}.\n"
        "- Use these entity types: {labels}.\n\n"
        "### Output Format:\n"
        "Return each example followed by the entities you found in this format:\n"
        "Example text.\nEntity types:\n"
        "Then group the entities under each label like this:\n"
        "\nPERSON → Angela Merkel, John Smith\n"
        "ORG → Google, United Nations\n"
        "DATE → January 1st, 2023\n"
        "... and so on.\n\n"
        "Each new entities group should be in a new line.\n"
        "If an entity type from the list above is not found, do not write it in your response.\n"
        "- Do NOT output them inline after the text.\n"
        "- Do NOT repeat the sentence.\n"
        "- If no entities are found for a type, skip it.\n"
        "- Keep the format consistent.\n\n"
        "User Instructions:\n{user_prompt}\n\n"
        "Few-shot Examples:\n{few_shot_examples}\n\n"
        "Examples to analyze:\n{examples}"
    ),
)

_CLASSIFICATION_LABEL_TEMPLATE = PromptTemplate(
    input_variables=["system_role", "classification_type", "labels", "few_shot_examples", "examples", "domain", "user_prompt"],
    template=(
        "{system_role}\n"
        "- Use the following instructions:\n"
        "- Use the following labels: {labels}.\n"
        "- Return the classified text followed by the label in this format: 'text. Label: [label]'\n"
        "- Keep each classified text and its label on a single line\n"
        "- Do not provide any additional information or explanations\n"
        "- User prompt:\n {user_prompt}\n\n"
        "- Use user provided examples as guidance in the classification process:\n\n {few_shot_examples}\n"
        "- Examples to classify:\n{examples}\n\n"
        "- Think step by step then classify the examples"
    ),
)


def _one_line(text):
    """Collapse all whitespace (including newlines) in *text* to single spaces."""
    return " ".join(text.split())


def _format_label_examples(examples):
    """Format {'content', 'label'} dicts in the same one-line 'text Label: x'
    shape the model is asked to produce (and that _parse_labeled_lines reads)."""
    return "\n".join(f"{_one_line(ex['content'])} Label: {ex['label']}" for ex in examples)


def _format_ner_examples(examples):
    """Format {'content', 'label'} dicts as NER demonstrations."""
    return "\n\n".join(f"{ex['content']}\n\nEntity types\n{ex['label']}" for ex in examples)


def _strip_enumeration(text):
    """Drop a leading list marker such as '12. ' or '3) ' added by the model."""
    digits = len(text) - len(text.lstrip("0123456789"))
    if digits and text[digits:digits + 1] in (".", ")"):
        return text[digits + 1:].lstrip()
    return text


def _parse_labeled_lines(response, extra_fields):
    """Extract one row per 'text ... Label: label' line of *response*.

    Lines without 'Label:' are ignored (e.g. reasoning the model prints).
    Leading numbering, markdown bold and brackets around the label are removed.
    Each row is {'text', 'label', **extra_fields}.
    """
    rows = []
    for line in response.splitlines():
        text, separator, label = line.rpartition('Label:')
        if not separator:
            continue
        text = _strip_enumeration(text.strip().rstrip("* ").strip())
        label = label.strip(" *[]\"'`.")
        if text and label:
            rows.append({'text': text, 'label': label, **extra_fields})
    return rows


def _select_domain():
    """Render the domain picker.

    Returns:
        (domain, is_valid): the built-in or custom domain name, and False when
        "Custom" is selected but no name has been typed.
    """
    selection = st.selectbox("Domain", _DOMAIN_OPTIONS)
    if selection != "Custom":
        return selection, True
    domain = st.text_input("Specify custom domain").strip()
    if not domain:
        st.error("Please specify a domain name.")
        return domain, False
    return domain, True


def _show_label_feedback(errors, success_message):
    """Show one error per problem, or the success banner when there are none."""
    if errors:
        for error in errors:
            st.error(error)
    else:
        st.success(success_message)


def _render_sentiment_labels():
    """Show the three fixed sentiment classes; return (labels, errors)."""
    st.write("### Sentiment Analysis Labels (Fixed)")
    titles = ("First class", "Second class", "Third class")
    for column, title, value in zip(st.columns(3), titles, _SENTIMENT_LABELS):
        with column:
            st.text_input(title, value, disabled=True)
    return list(_SENTIMENT_LABELS), []


def _render_binary_labels():
    """Collect two distinct, non-empty class names; return (labels, errors)."""
    st.write("### Binary Classification Labels")
    col1, col2 = st.columns(2)
    with col1:
        first = st.text_input("First class", "Positive").strip()
    with col2:
        second = st.text_input("Second class", "Negative").strip()

    errors = []
    if not first:
        errors.append("First class name is required.")
    if not second:
        errors.append("Second class name is required.")
    if first and second and first.lower() == second.lower():
        errors.append("Class names must be different.")
    _show_label_feedback(errors, "Binary class names are valid and unique!")
    return [first, second], errors


def _render_multiclass_labels(domain, slider_label, max_classes, field_name):
    """Collect 3..max_classes Title-Cased, case-insensitively unique labels.

    Inputs are pre-filled from _DEFAULT_LABELS_BY_DOMAIN for *domain*.
    Returns (labels, errors); blank inputs are reported and left out of labels.
    """
    st.write("### Multi-Class Classification Labels")
    num_classes = st.slider(slider_label, 3, max_classes, 3)
    defaults = _DEFAULT_LABELS_BY_DOMAIN.get(domain, [])

    labels, errors = [], []
    columns = st.columns(3)
    for i in range(num_classes):
        with columns[i % 3]:
            default_value = defaults[i] if i < len(defaults) else ""
            label = st.text_input(f"{field_name} {i + 1}", default_value).strip().title()
        if label:
            labels.append(label)
        else:
            errors.append(f"{field_name} {i + 1} name is required.")

    if len({label.lower() for label in labels}) != len(labels):
        errors.append("Label names must be unique (case-insensitive).")
    _show_label_feedback(errors, "All label names are valid and unique!")
    return labels, errors


def _render_ner_labels():
    """Let the user pick built-in and custom entity types; return (labels, errors)."""
    st.write("### Named Entity Recognition (NER) Entities")

    custom_entities = []
    if st.checkbox("Add custom NER entities?"):
        num_custom = st.slider("Number of custom NER entities", 1, 10, 1)
        for i in range(num_custom):
            st.markdown(f"#### Custom Entity {i+1}")
            entity_type = st.text_input(f"Entity type {i+1}", key=f"custom_ner_type_{i}").strip()
            description = st.text_input(f"Description for {entity_type}", key=f"custom_ner_desc_{i}")
            if entity_type and description:
                custom_entities.append(f"{entity_type.upper()} - {description}")

    selected = st.multiselect(
        "Select entities to recognize",
        _NER_ENTITIES + custom_entities,
        default=_NER_ENTITIES,
    )
    labels = [entity.split(" - ")[0].strip() for entity in selected]
    if not labels:
        st.warning("Please select at least one entity type.")
        labels = ["PERSON"]
    return labels, []


def _collect_few_shot(labels, toggle_label, key_prefix, title_prefix):
    """Optional user-written demonstrations as {'content', 'label'} dicts.

    Returns an empty list when the toggle is off or nothing has been typed.
    """
    examples = []
    if st.toggle(toggle_label):
        count = st.slider("Number of few-shot examples", 1, 10, 1)
        for i in range(count):
            with st.expander(f"{title_prefix} {i+1}"):
                content = st.text_area("Content", key=f"{key_prefix}content_{i}")
                label = st.selectbox("Label", labels, key=f"{key_prefix}label_{i}")
            if content and label:
                examples.append({"content": content, "label": label})
    return examples


def _stream_completion(prompt):
    """Send *prompt* as a single system message, stream the reply into the page,
    record both turns in the chat history, and return the full reply text."""
    stream = client.chat.completions.create(
        model=selected_model,
        messages=[{"role": "system", "content": prompt}],
        temperature=temperature,
        stream=True,
        max_tokens=4000,
        top_p=0.9,
    )
    st.session_state.messages.append({"role": "user", "content": prompt})
    response = st.write_stream(stream)
    st.session_state.messages.append({"role": "assistant", "content": response})
    return response


def _render_downloads(prefix, title, csv_widget_key, json_widget_key, show_preview=False):
    """Render CSV/JSON download buttons for results stored under *prefix*.

    Reads st.session_state[f"{prefix}_csv"] / [f"{prefix}_json"] (bytes) and does
    nothing when either is missing, so this is safe to call on every rerun.
    """
    csv_bytes = st.session_state.get(f"{prefix}_csv")
    json_bytes = st.session_state.get(f"{prefix}_json")
    if not csv_bytes or not json_bytes:
        return
    st.download_button(
        f"📥 Download {title} (CSV)", csv_bytes, f"{prefix}.csv", "text/csv",
        key=csv_widget_key,
    )
    st.markdown(_OR_SEPARATOR_HTML, unsafe_allow_html=True)
    st.download_button(
        f"📥 Download {title} (JSON)", json_bytes, f"{prefix}.json", "application/json",
        key=json_widget_key,
    )
    if show_preview:
        st.markdown("##### 📊 Labeled Examples Preview")
        st.dataframe(pd.DataFrame(st.session_state.get(prefix, [])), use_container_width=True)


def _run_generation(system_prompt, used_few_shot):
    """Call the model with *system_prompt* and store the parsed examples."""
    with st.spinner("Generating examples..."):
        try:
            response = _stream_completion(system_prompt)
        except Exception as e:  # UI boundary: show endpoint/auth/model errors instead of crashing
            st.error("An error occurred during generation.")
            st.error(f"Details: {e}")
            return
    st.session_state.response = response

    rows = _parse_labeled_lines(response, {
        'system_prompt': system_prompt,
        'system_role': st.session_state['system_role'],
        'task_type': 'Data Generation',
        'Use few-shot example?': 'Yes' if used_few_shot else 'No',
    })
    st.session_state.generated_examples = rows
    if not rows:
        st.session_state.generated_examples_csv = None
        st.session_state.generated_examples_json = None
        st.warning("No lines in the 'text. Label: label' format were found in the response.")
        return
    st.session_state.generated_examples_csv = pd.DataFrame(rows).to_csv(index=False).encode('utf-8')
    st.session_state.generated_examples_json = json.dumps(rows, indent=2).encode('utf-8')


def _render_data_generation():
    """Data Generation workflow: configure, preview the prompt, generate, download."""
    st.header("📝 Data Generation")

    domain, domain_valid = _select_domain()
    classification_type = st.selectbox(
        "Classification Type",
        ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"],
    )
    if classification_type == "Sentiment Analysis":
        labels, label_errors = _render_sentiment_labels()
    elif classification_type == "Binary Classification":
        labels, label_errors = _render_binary_labels()
    else:
        labels, label_errors = _render_multiclass_labels(domain, "Number of classes", 15, "Class")

    additional_attributes = []
    if st.checkbox("Add additional attributes (optional)"):
        num_attributes = st.slider("Number of attributes to add", 1, 5, 1)
        for i in range(num_attributes):
            st.markdown(f"#### Attribute {i+1}")
            attr_name = st.text_input(f"Name of attribute {i+1}", key=f"attr_name_{i}")
            attr_topics = st.text_input(f"Topics (comma-separated) for {attr_name}", key=f"attr_topics_{i}")
            topics = [topic.strip() for topic in attr_topics.split(",") if topic.strip()]
            if attr_name and topics:
                additional_attributes.append({"attribute": attr_name, "topics": topics})

    col1, col2 = st.columns(2)
    with col1:
        min_words = st.number_input("Min words", 1, 100, 20)
    with col2:
        # The default must not be below min_value, or Streamlit raises as soon
        # as "Min words" is set above 50.
        max_words = st.number_input("Max words", min_words, 100, max(min_words, 50))

    few_shot_examples = _collect_few_shot(labels, "Use few-shot examples", "few_shot_", "Example")

    num_to_generate = st.number_input("Number of examples", 1, 100, 10)

    default_system_role = (
        f"You are a seasoned expert in {classification_type}, specializing in the {domain} domain. "
        "Your primary responsibility is to generate high-quality, diverse, and unique text examples "
        "tailored to this domain. Please ensure that each example adheres to the specified length "
        f"requirements, ranging from {min_words} to {max_words} words, and avoid any repetition in the generated content."
    )
    system_role = st.text_area(
        "Modify System Role (optional)",
        value=default_system_role,
        key="system_role_input",
    )
    st.session_state['system_role'] = system_role or default_system_role

    user_prompt = st.text_area("User Prompt (optional)")

    formatted_attributes = "\n".join(
        f"- {attr['attribute']}: {', '.join(attr['topics'])}" for attr in additional_attributes
    ) or "- None"

    system_prompt = _GENERATION_TEMPLATE.format(
        system_role=st.session_state['system_role'],
        classification_type=classification_type,
        domain=domain,
        num_examples=num_to_generate,
        min_words=min_words,
        max_words=max_words,
        labels=", ".join(labels),
        user_prompt=user_prompt,
        few_shot_examples=_format_label_examples(few_shot_examples),
        additional_attributes=formatted_attributes,
    )
    st.session_state['system_prompt'] = system_prompt

    st.write("System Prompt:")
    st.text_area("Current System Prompt", value=system_prompt, height=400, disabled=True)

    if st.button("🎯 Generate Examples"):
        # Validation must actually block the request, not just warn.
        if not domain_valid:
            st.warning("Custom domain name is required.")
        elif label_errors or not labels:
            st.warning("Please fix the class label errors above before generating.")
        else:
            _run_generation(system_prompt, bool(few_shot_examples))

    _render_downloads(
        "generated_examples", "Generated Examples",
        'download-csv-persistent', 'download-json-persistent',
    )


def _run_labeling(system_prompt, classification_type, domain, labels, used_few_shot):
    """Call the model with *system_prompt* and store the labeled results."""
    with st.spinner("Labeling data..."):
        try:
            response = _stream_completion(system_prompt)
        except Exception as e:  # UI boundary: show endpoint/auth/model errors instead of crashing
            st.error("An error occurred during labeling.")
            st.error(f"Details: {e}")
            return
    st.session_state.response = response

    few_shot_flag = 'Yes' if used_few_shot else 'No'
    if classification_type == _NER_TASK:
        # NER output is free-form grouped text, so keep the whole reply as one row.
        rows = [{
            'ner_output': response.strip(),
            'system_prompt': system_prompt,
            'system_role': st.session_state['system_role'],
            'task_type': _NER_TASK,
            'Use few-shot example?': few_shot_flag,
        }]
    else:
        rows = _parse_labeled_lines(response, {
            'system_prompt': system_prompt,
            'system_role': st.session_state['system_role'],
            'task_type': 'Data Labeling',
            'Use few-shot example?': few_shot_flag,
        })

    st.session_state.labeled_examples = rows
    if not rows:
        st.session_state.labeled_examples_csv = None
        st.session_state.labeled_examples_json = None
        st.warning("No lines in the 'text. Label: label' format were found in the response.")
        return

    st.session_state.labeled_examples_csv = pd.DataFrame(rows).to_csv(index=False).encode('utf-8')
    st.session_state.labeled_examples_json = json.dumps({
        "metadata": {
            "domain": domain,
            "labels": labels,
            "used_few_shot": used_few_shot,
            # Previously hard-coded to NER for every task.
            "task_type": classification_type,
            "timestamp": datetime.now().isoformat(),
        },
        "examples": rows,
    }, indent=2).encode('utf-8')


def _render_data_labeling():
    """Data Labeling workflow: configure, preview the prompt, label, download."""
    st.header("🏷️ Data Labeling")

    domain, domain_valid = _select_domain()
    classification_type = st.selectbox(
        "Classification Type",
        ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification", _NER_TASK],
    )
    is_ner = classification_type == _NER_TASK
    if classification_type == "Sentiment Analysis":
        labels, label_errors = _render_sentiment_labels()
    elif classification_type == "Binary Classification":
        labels, label_errors = _render_binary_labels()
    elif classification_type == "Multi-Class Classification":
        labels, label_errors = _render_multiclass_labels(
            domain, "Select the number of classes (labels)", 10, "Label")
    else:
        labels, label_errors = _render_ner_labels()

    # NER always ships its built-in demonstrations; user examples are appended.
    few_shot_examples = list(_NER_FEW_SHOT) if is_ner else []
    few_shot_examples += _collect_few_shot(
        labels, "Use few-shot examples for labeling", "label_few_shot_", "Few-shot Example")

    num_examples = st.number_input("Number of examples to classify", 1, 100, 1)
    if num_examples <= 10:
        entered = [st.text_area(f"Example {i+1}", key=f"example_{i}") for i in range(num_examples)]
        examples_to_classify = [example for example in entered if example.strip()]
    else:
        examples_text = st.text_area(
            "Enter examples (one per line)",
            height=300,
            help="Enter each example on a new line",
        )
        examples_to_classify = [ex.strip() for ex in examples_text.splitlines() if ex.strip()][:num_examples]

    default_system_role = (
        f"You are a highly skilled {classification_type} expert."
        f" Your task is to accurately classify the provided text examples within the {domain} domain."
        " Ensure that all classifications are precise, context-aware, and aligned with domain-specific standards and best practices."
    )
    system_role = st.text_area(
        "Modify System Role (optional)",
        value=default_system_role,
        key="system_role_input",
    )
    st.session_state['system_role'] = system_role or default_system_role

    user_prompt = st.text_area("User prompt (optional)", key="label_instructions")

    # Examples are numbered and flattened to one line each so the model's
    # one-line "text Label: x" answers can be matched back line by line.
    numbered_examples = "\n".join(
        f"{i}. {_one_line(example)}" for i, example in enumerate(examples_to_classify, start=1)
    )
    template = _NER_LABEL_TEMPLATE if is_ner else _CLASSIFICATION_LABEL_TEMPLATE
    few_shot_text = (_format_ner_examples if is_ner else _format_label_examples)(few_shot_examples)

    # Build the prompt once: the text shown below is exactly what gets sent.
    system_prompt = template.format(
        system_role=st.session_state['system_role'],
        classification_type=classification_type,
        domain=domain,
        labels=", ".join(labels),
        user_prompt=user_prompt,
        few_shot_examples=few_shot_text,
        examples=numbered_examples,
    )
    st.session_state['system_prompt'] = system_prompt

    st.write("System Prompt:")
    st.text_area("System Prompt", value=system_prompt, height=300, disabled=True)

    if st.button("🏷️ Label Data"):
        if not examples_to_classify:
            st.warning("Please enter at least one example to classify.")
        elif not domain_valid:
            st.warning("Custom domain name is required.")
        elif label_errors:
            st.warning("Please fix the label errors above before labeling.")
        else:
            _run_labeling(system_prompt, classification_type, domain, labels, bool(few_shot_examples))

    _render_downloads(
        "labeled_examples", "Labeled Examples",
        'download-labeled-csv', 'download-labeled-json',
        show_preview=True,
    )


if "task_choice" in st.session_state:
    if st.session_state.task_choice == "Data Generation":
        _render_data_generation()
    elif st.session_state.task_choice == "Data Labeling":
        _render_data_labeling()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page footer. The heart emoji was mojibake (UTF-8 read as ISO-8859-7) and is
# restored to ❤️.
st.markdown("---")
st.markdown(
    "<div style='text-align: center'>"
    "<p>Made with ❤️ by Wedyan AlSakran 2025</p>"
    "</div>",
    unsafe_allow_html=True,
)