Wedyan2023 committed on
Commit
bc09f8c
·
verified ·
1 Parent(s): 2626d68

Create app111.py

Browse files
Files changed (1) hide show
  1. app111.py +1618 -0
app111.py ADDED
@@ -0,0 +1,1618 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import os
4
+ import json
5
+ import base64
6
+ import random
7
+ from streamlit_pdf_viewer import pdf_viewer
8
+ from langchain.prompts import PromptTemplate
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+ from openai import OpenAI
12
+ from dotenv import load_dotenv
13
+ import warnings
14
+
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer
16
+ import torch
17
+
18
+ warnings.filterwarnings('ignore')
19
+
20
+ os.getenv("OAUTH_CLIENT_ID")
21
+
22
+
23
+ # # Load environment variables and initialize the OpenAI client to use Hugging Face Inference API.
24
+ # load_dotenv()
25
+ # client = OpenAI(
26
+ # base_url="https://api-inference.huggingface.co/v1",
27
+ # #api_key=os.environ.get('TOKEN2') # Hugging Face API token
28
+ # api_key=os.environ.get('LLM')
29
+ # )
30
+ #######
31
+ #from openai import OpenAI
32
+
33
# Shared OpenAI-compatible client pointed at the Hugging Face inference router.
# The access token is read from the `LLM` environment variable; never hard-code it here.
# NOTE(review): the base URL embeds the Llama-3.3-70B-Instruct model path, while requests
# elsewhere in this script pass a user-selected model name — confirm that the router
# honours the `model` argument for other models with this URL.
client = OpenAI(
    base_url="https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.3-70B-Instruct/v1",
    api_key=os.environ.get('LLM'),
)

# Streamlit re-executes this script from top to bottom on every widget interaction.
# An unconditional request here would therefore hit the remote endpoint on each rerun
# and abort the whole app whenever that request fails. The connectivity check is kept,
# but only runs when explicitly requested via LLM_SMOKE_TEST=1.
completion = None
if os.environ.get("LLM_SMOKE_TEST") == "1":
    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.3-70B-Instruct",
        messages=[
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ],
    )
    print(completion.choices[0].message)
50
+ #######
51
+ #####
52
+ # from openai import OpenAI
53
+
54
+ # client = OpenAI(
55
+ # base_url="https://router.huggingface.co/together/v1",
56
+ # #api_key="hf_XXXXX",
57
+ # api_key=os.environ.get('TOKEN2'), # Hugging Face API token
58
+ # )
59
+ # #meta-llama/Meta-Llama-3-8B-Instruct
60
+ # completion = client.chat.completions.create(
61
+ # #model="meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
62
+ # model="meta-llama/Meta-Llama-3-8B-Instruct",
63
+ # messages=[
64
+ # {
65
+ # "role": "user",
66
+ # "content": "What is the capital of France?"
67
+ # }
68
+ # ],
69
+ # )
70
+
71
+ #print(completion.choices[0].message)
72
+ #####
73
+ ##########################################################3
74
+ # import streamlit as st
75
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
76
+ # import torch
77
+
78
+ # # Model selection dropdown
79
+ # selected_model = st.selectbox(
80
+ # "Select Model",
81
+ # ["meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
82
+ # "meta-llama/Llama-3.3-70B-Instruct",
83
+ # "meta-llama/Llama-3.2-3B-Instruct",
84
+ # "meta-llama/Llama-4-Scout-17B-16E-Instruct",
85
+ # "meta-llama/Meta-Llama-3-8B-Instruct",
86
+ # "meta-llama/Llama-3.1-70B-Instruct"],
87
+ # key='model_select'
88
+ # )
89
+
90
+ # @st.cache_resource # Cache the model to prevent reloading
91
+ # def load_model(model_name):
92
+ # try:
93
+ # # Optimized model loading configuration
94
+ # model = AutoModelForCausalLM.from_pretrained(
95
+ # model_name,
96
+ # torch_dtype=torch.float16, # Use half precision
97
+ # device_map="auto", # Automatic device mapping
98
+ # load_in_8bit=True, # Enable 8-bit quantization
99
+ # low_cpu_mem_usage=True, # Optimize CPU memory usage
100
+ # max_memory={0: "10GB"} # Limit GPU memory usage
101
+ # )
102
+
103
+ # tokenizer = AutoTokenizer.from_pretrained(
104
+ # model_name,
105
+ # padding_side="left",
106
+ # truncation_side="left"
107
+ # )
108
+
109
+ # return model, tokenizer
110
+
111
+ # except Exception as e:
112
+ # st.error(f"Error loading model: {str(e)}")
113
+ # return None, None
114
+
115
+ # # Load the selected model with optimizations
116
+ # if selected_model:
117
+ # model, tokenizer = load_model(selected_model)
118
+
119
+ # # Check if model loaded successfully
120
+ # if model is not None:
121
+ # st.success(f"Successfully loaded {selected_model}")
122
+ # else:
123
+ # st.warning("Please select a different model or check your hardware capabilities")
124
+
125
+ # # Function to generate text
126
+ # def generate_response(prompt, model, tokenizer):
127
+ # try:
128
+ # inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
129
+
130
+ # with torch.no_grad():
131
+ # outputs = model.generate(
132
+ # inputs["input_ids"],
133
+ # max_length=256,
134
+ # num_return_sequences=1,
135
+ # temperature=0.7,
136
+ # do_sample=True,
137
+ # pad_token_id=tokenizer.pad_token_id
138
+ # )
139
+
140
+ # response = tokenizer.decode(outputs[0], skip_special_tokens=True)
141
+ # return response
142
+
143
+ # except Exception as e:
144
+ # return f"Error generating response: {str(e)}"
145
+ ############################################################
146
+
147
+ ####new
148
+ # from openai import OpenAI
149
+
150
+ # client = OpenAI(
151
+ # base_url="https://router.huggingface.co/together/v1",
152
+ # api_key=os.environ.get('TOKEN2'),
153
+ # )
154
+
155
+ # completion = client.chat.completions.create(
156
+ # model="meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
157
+ # messages=[
158
+ # {
159
+ # "role": "user",
160
+ # "content": "What is the capital of France?"
161
+ # }
162
+ # ],
163
+ # max_tokens=512,
164
+ # )
165
+
166
+ # print(completion.choices[0].message)
167
+ #####
168
+
169
+ # Create necessary directories
170
# Make sure the working directories exist before anything tries to write into them.
# exist_ok=True replaces the separate existence check, which could race with another
# process (or a concurrent script run) creating the directory in between.
for dir_name in ['data', 'feedback']:
    os.makedirs(dir_name, exist_ok=True)
173
+
174
+ # Custom CSS
175
+ st.markdown("""
176
+ <style>
177
+ .stButton > button {
178
+ width: 100%;
179
+ margin-bottom: 10px;
180
+ background-color: #4CAF50;
181
+ color: white;
182
+ border: none;
183
+ padding: 10px;
184
+ border-radius: 5px;
185
+ }
186
+ .task-button {
187
+ background-color: #2196F3 !important;
188
+ }
189
+ .stSelectbox {
190
+ margin-bottom: 20px;
191
+ }
192
+ .output-container {
193
+ padding: 20px;
194
+ border-radius: 5px;
195
+ border: 1px solid #ddd;
196
+ margin: 10px 0;
197
+ }
198
+ .status-container {
199
+ padding: 10px;
200
+ border-radius: 5px;
201
+ margin: 10px 0;
202
+ }
203
+ .sidebar-info {
204
+ padding: 10px;
205
+ background-color: #f0f2f6;
206
+ border-radius: 5px;
207
+ margin: 10px 0;
208
+ }
209
+ .feedback-button {
210
+ background-color: #ff9800 !important;
211
+ }
212
+ .feedback-container {
213
+ padding: 15px;
214
+ background-color: #f5f5f5;
215
+ border-radius: 5px;
216
+ margin: 15px 0;
217
+ }
218
+ </style>
219
+ """, unsafe_allow_html=True)
220
+
221
+ # Helper functions
222
def read_csv_with_encoding(file):
    """Read a CSV into a DataFrame, trying several text encodings in turn.

    Args:
        file: A filesystem path or a file-like object accepted by
            ``pandas.read_csv``.

    Returns:
        The ``pandas.DataFrame`` parsed with the first encoding that decodes
        the data without error.

    Raises:
        UnicodeDecodeError: If every candidate encoding fails to decode.
    """
    encodings = ('utf-8', 'latin1', 'iso-8859-1', 'cp1252')

    # Remember where a seekable stream started: a failed attempt consumes the
    # stream, so each retry must rewind or it would parse an empty remainder.
    start = None
    seekable = getattr(file, 'seekable', None)
    if callable(seekable) and seekable():
        start = file.tell()

    last_error = None
    for encoding in encodings:
        try:
            return pd.read_csv(file, encoding=encoding)
        except UnicodeDecodeError as exc:
            last_error = exc
            if start is not None:
                file.seek(start)

    # UnicodeDecodeError takes (encoding, object, start, end, reason); passing
    # only a message would raise TypeError instead of the intended error.
    raise UnicodeDecodeError(
        encodings[-1], b"", 0, 1,
        "Failed to read file with any supported encoding",
    ) from last_error
230
+
231
+ #def save_feedback(feedback_data):
232
+ #feedback_file = 'feedback/user_feedback.csv'
233
+ #feedback_df = pd.DataFrame([feedback_data])
234
+
235
+ #if os.path.exists(feedback_file):
236
+ #feedback_df.to_csv(feedback_file, mode='a', header=False, index=False)
237
+ #else:
238
+ #feedback_df.to_csv(feedback_file, index=False)
239
+
240
def reset_conversation():
    """Clear the chat history and return the app to the task-selection screen.

    Empties the ``conversation`` and ``messages`` entries in the Streamlit
    session state and removes ``task_choice`` if it is set. Used as the
    ``on_click`` callback of the "New Conversation" button.

    Returns:
        None
    """
    state = st.session_state
    for history_key in ("conversation", "messages"):
        state[history_key] = []
    if "task_choice" in state:
        del state["task_choice"]
    return None
246
+ #new 24 March
247
+ #user_input = st.text_input("Enter your prompt:")
248
+ ###########33
249
+
250
+ # Initialize session state variables
251
# Seed the session-state keys the rest of the script reads.
# setdefault only writes when the key is absent, so values already stored
# for this session survive reruns.
st.session_state.setdefault("messages", [])
st.session_state.setdefault("examples_to_classify", [])
st.session_state.setdefault("system_role", "")
257
+
258
+
259
+
260
+ # Main app title
261
+ st.title("🤖🦙 Text Data Labeling and Generation App")
262
+ # def embed_pdf_sidebar(pdf_path):
263
+ # with open(pdf_path, "rb") as f:
264
+ # base64_pdf = base64.b64encode(f.read()).decode('utf-8')
265
+ # pdf_display = f"""
266
+ # <iframe src="data:application/pdf;base64,{base64_pdf}"
267
+ # width="100%" height="400" type="application/pdf"></iframe>
268
+ # """
269
+ # st.markdown(pdf_display, unsafe_allow_html=True)
270
+ #
271
+
272
+
273
+ # Sidebar settings
274
+ with st.sidebar:
275
+ st.title("⚙️ Settings")
276
+
277
+
278
+ #this last code works
279
with st.sidebar:
    st.markdown("### 📘Data Generation and Labeling Instructions")
    # The PDF is expected next to the app. Read it up front so a missing or
    # unreadable file shows a notice instead of aborting the whole script run.
    try:
        instructions_bytes = Path("User instructions.pdf").read_bytes()
    except OSError:
        instructions_bytes = None

    if instructions_bytes is not None:
        st.download_button(
            label="📄 Download Instructions PDF",
            data=instructions_bytes,
            file_name="User instructions.pdf",
            mime="application/pdf"
        )
    else:
        st.info("The instructions PDF is not available in this deployment.")
290
+
291
+ selected_model = st.selectbox(
292
+ "Select Model",
293
+ ["meta-llama/Llama-Prompt-Guard-2-86M","mistralai/Mistral-7B-Instruct-v0.2", "meta-llama/Llama-3.2-11B-Vision-Instruct","meta-llama/Meta-Llama-3-8B-Instruct-Turbo", "meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.2-3B-Instruct","meta-llama/Llama-4-Scout-17B-16E-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct",
294
+ "meta-llama/Llama-3.1-70B-Instruct"],
295
+ key='model_select'
296
+ )
297
+
298
+ #################new oooo
299
+
300
+ # # Model selection dropdown
301
+ # selected_model = st.selectbox(
302
+ # "Select Model",
303
+ # [#"meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
304
+ # "meta-llama/Llama-3.2-3B-Instruct",
305
+ # "meta-llama/Llama-3.3-70B-Instruct",
306
+ # "meta-llama/Llama-3.2-3B-Instruct",
307
+ # "meta-llama/Llama-4-Scout-17B-16E-Instruct",
308
+ # "meta-llama/Meta-Llama-3-8B-Instruct",
309
+ # "meta-llama/Llama-3.1-70B-Instruct"],
310
+ # key='model_select'
311
+ # )
312
+
313
+ # @st.cache_resource # Cache the model to prevent reloading
314
+ # def load_model(model_name):
315
+ # try:
316
+ # # Optimized model loading configuration
317
+ # model = AutoModelForCausalLM.from_pretrained(
318
+ # model_name,
319
+ # torch_dtype=torch.float16, # Use half precision
320
+ # device_map="auto", # Automatic device mapping
321
+ # load_in_8bit=True, # Enable 8-bit quantization
322
+ # low_cpu_mem_usage=True, # Optimize CPU memory usage
323
+ # max_memory={0: "10GB"} # Limit GPU memory usage
324
+ # )
325
+
326
+ # tokenizer = AutoTokenizer.from_pretrained(
327
+ # model_name,
328
+ # padding_side="left",
329
+ # truncation_side="left"
330
+ # )
331
+
332
+ # return model, tokenizer
333
+
334
+ # except Exception as e:
335
+ # st.error(f"Error loading model: {str(e)}")
336
+ # return None, None
337
+
338
+ # # Load the selected model with optimizations
339
+ # if selected_model:
340
+ # model, tokenizer = load_model(selected_model)
341
+
342
+ # # Check if model loaded successfully
343
+ # if model is not None:
344
+ # st.success(f"Successfully loaded {selected_model}")
345
+ # else:
346
+ # st.warning("Please select a different model or check your hardware capabilities")
347
+
348
+ # # Function to generate text
349
+ # def generate_response(prompt, model, tokenizer):
350
+ # try:
351
+ # inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
352
+
353
+ # with torch.no_grad():
354
+ # outputs = model.generate(
355
+ # inputs["input_ids"],
356
+ # max_length=256,
357
+ # num_return_sequences=1,
358
+ # temperature=0.7,
359
+ # do_sample=True,
360
+ # pad_token_id=tokenizer.pad_token_id
361
+ # )
362
+
363
+ # response = tokenizer.decode(outputs[0], skip_special_tokens=True)
364
+ # return response
365
+
366
+ # except Exception as e:
367
+ # return f"Error generating response: {str(e)}"
368
+ # ################
369
+
370
+ # model = AutoModelForCausalLM.from_pretrained(
371
+ # "meta-llama/Meta-Llama-3-8B-Instruct",
372
+ # torch_dtype=torch.float16, # Use half precision
373
+ # device_map="auto", # Automatic device mapping
374
+ # load_in_8bit=True # Load in 8-bit precision
375
+ # )
376
+ temperature = st.slider(
377
+ "Temperature",
378
+ 0.0, 1.0, 0.7,
379
+ help="Controls randomness in generation"
380
+ )
381
+
382
+ st.button("🔄 New Conversation", on_click=reset_conversation)
383
+ with st.container():
384
+ st.markdown(f"""
385
+ <div class="sidebar-info">
386
+ <h4>Current Model: {selected_model}</h4>
387
+ <p><em>Note: Generated content may be inaccurate or false. Check important info.</em></p>
388
+ </div>
389
+ """, unsafe_allow_html=True)
390
+
391
+ feedback_url = "https://docs.google.com/forms/d/e/1FAIpQLSdZ_5mwW-pjqXHgxR0xriyVeRhqdQKgb5c-foXlYAV55Rilsg/viewform?usp=header"
392
+ st.sidebar.markdown(
393
+ f'<a href="{feedback_url}" target="_blank"><button style="width: 100%;">Feedback Form</button></a>',
394
+ unsafe_allow_html=True
395
+ )
396
+
397
+ # Display conversation
398
+ for message in st.session_state.messages:
399
+ with st.chat_message(message["role"]):
400
+ st.markdown(message["content"])
401
+
402
+ # Main content
403
+ if 'task_choice' not in st.session_state:
404
+ col1, col2 = st.columns(2)
405
+ with col1:
406
+ if st.button("📝 Data Generation", key="gen_button", help="Generate new data"):
407
+ st.session_state.task_choice = "Data Generation"
408
+ with col2:
409
+ if st.button("🏷️ Data Labeling", key="label_button", help="Label existing data"):
410
+ st.session_state.task_choice = "Data Labeling"
411
+
412
+ if "task_choice" in st.session_state:
413
+ if st.session_state.task_choice == "Data Generation":
414
+ st.header("📝 Data Generation")
415
+
416
+ # 1. Domain selection
417
+ domain_selection = st.selectbox("Domain", [
418
+ "Restaurant reviews", "E-Commerce reviews", "News", "AG News", "Tourism", "Custom"
419
+ ])
420
+
421
+ # 2. Handle custom domain input
422
+ custom_domain_valid = True # Assume valid until proven otherwise
423
+
424
+ if domain_selection == "Custom":
425
+ domain = st.text_input("Specify custom domain")
426
+ if not domain.strip():
427
+ st.error("Please specify a domain name.")
428
+ custom_domain_valid = False
429
+ else:
430
+ domain = domain_selection
431
+
432
+ # Classification type selection
433
+ classification_type = st.selectbox(
434
+ "Classification Type",
435
+ ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"]
436
+ )
437
+ # Labels setup based on classification type
438
+ #labels = []
439
+ labels = []
440
+ labels_valid = False
441
+ errors = []
442
+
443
def validate_binary_labels(labels):
    """Validate the class names entered for binary classification.

    Args:
        labels: Sequence of label strings; the first two entries are the
            class names being checked.

    Returns:
        A list of error messages. It is empty when both names are non-blank
        and differ after trimming and lower-casing.
    """
    trimmed = [name.strip() for name in labels]
    folded = [name.lower() for name in trimmed]

    required_messages = (
        "First class name is required.",
        "Second class name is required.",
    )
    problems = [
        message
        for position, message in enumerate(required_messages)
        if not trimmed[position]
    ]

    # Only report a clash when every label is non-blank; blank names are
    # already covered by the "required" messages above.
    if all(folded) and folded[0] == folded[1]:
        problems.append("Class names must be different.")
    return problems
454
+
455
+ if classification_type == "Sentiment Analysis":
456
+ st.write("### Sentiment Analysis Labels (Fixed)")
457
+ col1, col2, col3 = st.columns(3)
458
+ with col1:
459
+ st.text_input("First class", "Positive", disabled=True)
460
+ with col2:
461
+ st.text_input("Second class", "Negative", disabled=True)
462
+ with col3:
463
+ st.text_input("Third class", "Neutral", disabled=True)
464
+ labels = ["Positive", "Negative", "Neutral"]
465
+
466
+ elif classification_type == "Binary Classification":
467
+ st.write("### Binary Classification Labels")
468
+ col1, col2 = st.columns(2)
469
+ with col1:
470
+ label_1 = st.text_input("First class", "Positive")
471
+ with col2:
472
+ label_2 = st.text_input("Second class", "Negative")
473
+
474
+ labels = [label_1, label_2]
475
+ errors = validate_binary_labels(labels)
476
+
477
+ if errors:
478
+ st.error("\n".join(errors))
479
+ else:
480
+ st.success("Binary class names are valid and unique!")
481
+
482
+
483
+ elif classification_type == "Multi-Class Classification":
484
+ st.write("### Multi-Class Classification Labels")
485
+
486
+ default_labels_by_domain = {
487
+ "News": ["Political", "Sports", "Entertainment", "Technology", "Business"],
488
+ "AG News": ["World", "Sports", "Business", "Sci/Tech"],
489
+ "Tourism": ["Accommodation", "Transportation", "Tourist Attractions",
490
+ "Food & Dining", "Local Experience", "Adventure Activities",
491
+ "Wellness & Spa", "Eco-Friendly Practices", "Family-Friendly",
492
+ "Luxury Tourism"],
493
+ "Restaurant reviews": ["Italian", "French", "American"],
494
+ "E-Commerce reviews": ["Mobile Phones & Accessories", "Laptops & Computers","Kitchen & Dining",
495
+ "Beauty & Personal Care", "Home & Furniture", "Clothing & Fashion",
496
+ "Shoes & Handbags", "Health & Wellness", "Electronics & Gadgets",
497
+ "Books & Stationery","Toys & Games", "Sports & Fitness",
498
+ "Grocery & Gourmet Food","Watches & Accessories", "Baby Products"]
499
+ }
500
+
501
+ num_classes = st.slider("Number of classes", 3, 15, 3)
502
+
503
+ # Get defaults for selected domain, or empty list
504
+ defaults = default_labels_by_domain.get(domain, [])
505
+
506
+ labels = []
507
+ errors = []
508
+ cols = st.columns(3)
509
+
510
+ for i in range(num_classes):
511
+ with cols[i % 3]:
512
+ default_value = defaults[i] if i < len(defaults) else ""
513
+ label_input = st.text_input(f"Class {i+1}", default_value)
514
+ normalized_label = label_input.strip().title()
515
+
516
+ if not normalized_label:
517
+ errors.append(f"Class {i+1} name is required.")
518
+ else:
519
+ labels.append(normalized_label)
520
+
521
+ # Check for duplicates (case-insensitive)
522
+ if len(labels) != len(set(labels)):
523
+ errors.append("Labels names must be unique (case-insensitive, normalized to Title Case).")
524
+
525
+ # Show validation results
526
+ if errors:
527
+ for error in errors:
528
+ st.error(error)
529
+ else:
530
+ st.success("All Labels names are valid and unique!")
531
+ labels_valid = not errors # Will be True only if there are no label errors
532
+
533
+ ##############
534
+ #new 22/4/2025
535
+ # add additional attributes
536
+ add_attributes = st.checkbox("Add additional attributes (optional)")
537
+ additional_attributes = []
538
+
539
+ if add_attributes:
540
+ num_attributes = st.slider("Number of attributes to add", 1, 5, 1)
541
+ for i in range(num_attributes):
542
+ st.markdown(f"#### Attribute {i+1}")
543
+ attr_name = st.text_input(f"Name of attribute {i+1}", key=f"attr_name_{i}")
544
+ attr_topics = st.text_input(f"Topics (comma-separated) for {attr_name}", key=f"attr_topics_{i}")
545
+ if attr_name and attr_topics:
546
+ topics_list = [topic.strip() for topic in attr_topics.split(",") if topic.strip()]
547
+ additional_attributes.append({"attribute": attr_name, "topics": topics_list})
548
+
549
+ ################
550
+
551
+ # Generation parameters
552
+ col1, col2 = st.columns(2)
553
+ with col1:
554
+ min_words = st.number_input("Min words", 1, 100, 20)
555
+ with col2:
556
+ max_words = st.number_input("Max words", min_words, 100, 50)
557
+
558
+ # Few-shot examples
559
+ use_few_shot = st.toggle("Use few-shot examples")
560
+ few_shot_examples = []
561
+ if use_few_shot:
562
+ num_examples = st.slider("Number of few-shot examples", 1, 10, 1)
563
+ for i in range(num_examples):
564
+ with st.expander(f"Example {i+1}"):
565
+ content = st.text_area(f"Content", key=f"few_shot_content_{i}")
566
+ label = st.selectbox(f"Label", labels, key=f"few_shot_label_{i}")
567
+ if content and label:
568
+ few_shot_examples.append({"content": content, "label": label})
569
+
570
+ num_to_generate = st.number_input("Number of examples", 1, 100, 10)
571
+ #system role after
572
+ # System role customization
573
+ #default_system_role = f"You are a professional {classification_type} expert, your role is to generate text examples for {domain} domain. Always generate unique diverse examples and do not repeat the generated data. The generated text should be between {min_words} to {max_words} words long."
574
+ # System role customization
575
+ default_system_role = (
576
+ f"You are a seasoned expert in {classification_type}, specializing in the {domain} domain. "
577
+ f" Your primary responsibility is to generate high-quality, diverse, and unique text examples "
578
+ f"tailored to this domain. Please ensure that each example adheres to the specified length "
579
+ f"requirements, ranging from {min_words} to {max_words} words, and avoid any repetition in the generated content."
580
+ )
581
+ system_role = st.text_area("Modify System Role (optional)",
582
+ value=default_system_role,
583
+ key="system_role_input")
584
+ st.session_state['system_role'] = system_role if system_role else default_system_role
585
+ # Labels initialization
586
+ #labels = []
587
+
588
+
589
+ user_prompt = st.text_area("User Prompt (optional)")
590
+
591
+ # Updated prompt template including system role
592
+ prompt_template = PromptTemplate(
593
+ input_variables=["system_role", "classification_type", "domain", "num_examples",
594
+ "min_words", "max_words", "labels", "user_prompt", "few_shot_examples", "additional_attributes"],
595
+ template=(
596
+ "{system_role}\n"
597
+ "- Use the following parameters:\n"
598
+ "- Generate {num_examples} examples\n"
599
+ "- Each example should be between {min_words} to {max_words} words long\n"
600
+ "- Use these labels: {labels}.\n"
601
+ "- Use the following additional attributes:\n"
602
+ "- {additional_attributes}\n"
603
+ "- Generate the examples in this format: 'Example text. Label: label'\n"
604
+ "- Do not include word counts or any additional information\n"
605
+ "- Always use your creativity and intelligence to generate unique and diverse text data\n"
606
+ "- In sentiment analysis, ensure that the sentiment classification is clearly identified as Positive, Negative, or Neutral. Do not leave the sentiment ambiguous.\n"
607
+ "- In binary sentiment analysis, classify text strictly as either Positive or Negative. Do not include or imply Neutral as an option.\n"
608
+ "- Write unique examples every time.\n"
609
+ "- DO NOT REPEAT your gnerated text. \n"
610
+ "- For each Output, describe it once and move to the next.\n"
611
+ "- List each Output only once, and avoid repeating details.\n"
612
+ "- Additional instructions: {user_prompt}\n\n"
613
+ "- Use the following examples as a reference in the generation process\n\n {few_shot_examples}. \n"
614
+ "- Think step by step, generate numbered examples, and check each newly generated example to ensure it has not been generated before. If it has, modify it"
615
+
616
+ )
617
+ )
618
+ # template=(
619
+ # "{system_role}\n"
620
+ # "- Use the following parameters:\n"
621
+ # "- Generate {num_examples} examples\n"
622
+ # "- Each example should be between {min_words} to {max_words} words long\n"
623
+ # "- Use these labels: {labels}.\n"
624
+ # "- Use the following additional attributes:\n"
625
+ # "{additional_attributes}\n"
626
+ # #"- Format each example like this: 'Example text. Label: [label]. Attribute1: [topic1]. Attribute2: [topic2]'\n"
627
+ # "- Generate the examples in this format: 'Example text. Label: label'\n"
628
+ # "- Additional instructions: {user_prompt}\n"
629
+ # "- Use these few-shot examples if provided:\n{few_shot_examples}\n"
630
+ # "- Think step by step and ensure examples are unique and not repeated."
631
+ # )
632
+ # )
633
+ ##########new 22/4/2025
634
+ formatted_attributes = "\n".join([
635
+ f"- {attr['attribute']}: {', '.join(attr['topics'])}" for attr in additional_attributes
636
+ ])
637
+ #######################
638
+
639
+ # Generate system prompt
640
+ system_prompt = prompt_template.format(
641
+ system_role=st.session_state['system_role'],
642
+ classification_type=classification_type,
643
+ domain=domain,
644
+ num_examples=num_to_generate,
645
+ min_words=min_words,
646
+ max_words=max_words,
647
+ labels=", ".join(labels),
648
+ user_prompt=user_prompt,
649
+ few_shot_examples="\n".join([f"{ex['content']}\nLabel: {ex['label']}" for ex in few_shot_examples]) if few_shot_examples else "",
650
+ additional_attributes=formatted_attributes
651
+ )
652
+
653
+
654
+ # Store system prompt in session state
655
+ st.session_state['system_prompt'] = system_prompt
656
+
657
+ # Display system prompt
658
+ st.write("System Prompt:")
659
+ st.text_area("Current System Prompt", value=st.session_state['system_prompt'],
660
+ height=400, disabled=True)
661
+
662
+
663
+ if st.button("🎯 Generate Examples"):
664
+ #
665
+ errors = []
666
+ if domain_selection == "Custom" and not domain.strip():
667
+ st.warning("Custom domain name is required.")
668
+ elif len(labels) != len(set(labels)):
669
+ st.warning("Class names must be unique.")
670
+ elif any(not lbl.strip() for lbl in labels):
671
+ st.warning("All class labels must be filled in.")
672
+ #else:
673
+ #st.success("Generating examples for domain: {domain}")
674
+
675
+ #if not custom_domain_valid:
676
+ #st.warning("Custom domain name is required.")
677
+ #elif not labels_valid:
678
+ #st.warning("Please fix the label errors before generating examples.")
679
+ #else:
680
+ # Proceed to generate examples
681
+ #st.success(f"Generating examples for domain: {domain}")
682
+
683
+ with st.spinner("Generating examples..."):
684
+ try:
685
+ stream = client.chat.completions.create(
686
+ model=selected_model,
687
+ messages=[{"role": "system", "content": st.session_state['system_prompt']}],
688
+ temperature=temperature,
689
+ stream=True,
690
+ #max_tokens=80000,
691
+ max_tokens=4000,
692
+ top_p=0.9,
693
+ # repetition_penalty=1.2,
694
+ #frequency_penalty=0.5, # Discourages frequent words
695
+ #presence_penalty=0.6,
696
+ )
697
+ #st.session_state['system_prompt'] = system_prompt
698
+ #new 24 march
699
+ st.session_state.messages.append({"role": "user", "content": system_prompt})
700
+ # # ####################
701
+ response = st.write_stream(stream)
702
+ st.session_state.messages.append({"role": "assistant", "content": response})
703
+ # Initialize session state variables if they don't exist
704
+ if 'system_prompt' not in st.session_state:
705
+ st.session_state.system_prompt = system_prompt
706
+
707
+ if 'response' not in st.session_state:
708
+ st.session_state.response = response
709
+
710
+ if 'generated_examples' not in st.session_state:
711
+ st.session_state.generated_examples = []
712
+
713
+ if 'generated_examples_csv' not in st.session_state:
714
+ st.session_state.generated_examples_csv = None
715
+
716
+ if 'generated_examples_json' not in st.session_state:
717
+ st.session_state.generated_examples_json = None
718
+
719
+ # Parse response and generate examples list
720
+ examples_list = []
721
+ for line in response.split('\n'):
722
+ if line.strip():
723
+ parts = line.rsplit('Label:', 1)
724
+ if len(parts) == 2:
725
+ text = parts[0].strip()
726
+ label = parts[1].strip()
727
+ if text and label:
728
+ examples_list.append({
729
+ 'text': text,
730
+ 'label': label,
731
+ 'system_prompt': st.session_state.system_prompt,
732
+ 'system_role': st.session_state.system_role,
733
+ 'task_type': 'Data Generation',
734
+ 'Use few-shot example?': 'Yes' if use_few_shot else 'No',
735
+ })
736
+
737
+ # example_dict = {
738
+ # 'text': text,
739
+ # 'label': label,
740
+ # 'system_prompt': st.session_state.system_prompt,
741
+ # 'system_role': st.session_state.system_role,
742
+ # 'task_type': 'Data Generation',
743
+ # 'Use few-shot example?': 'Yes' if use_few_shot else 'No',
744
+ # }
745
+ # for attr in additional_attributes:
746
+ # example_dict[attr['attribute']] = random.choice(attr['topics'])
747
+
748
+ # examples_list.append(example_dict)
749
+
750
+
751
+ if examples_list:
752
+ # Update session state with new data
753
+ st.session_state.generated_examples = examples_list
754
+
755
+ # Generate CSV and JSON data
756
+ df = pd.DataFrame(examples_list)
757
+ st.session_state.generated_examples_csv = df.to_csv(index=False).encode('utf-8')
758
+ st.session_state.generated_examples_json = json.dumps(examples_list, indent=2).encode('utf-8')
759
+
760
+ # Vertical layout with centered "or" between buttons
761
+ st.download_button(
762
+ "📥 Download Generated Examples (CSV)",
763
+ st.session_state.generated_examples_csv,
764
+ "generated_examples.csv",
765
+ "text/csv",
766
+ key='download-csv-persistent'
767
+ )
768
+
769
+ # Add space and center the "or"
770
+ st.markdown("""
771
+ <div style='text-align: left; margin:15px 0; font-weight: 600; color: #666;'>. . . . . . or</div>
772
+ """, unsafe_allow_html=True)
773
+
774
+ st.download_button(
775
+ "📥 Download Generated Examples (JSON)",
776
+ st.session_state.generated_examples_json,
777
+ "generated_examples.json",
778
+ "application/json",
779
+ key='download-json-persistent'
780
+ )
781
+ # # Display the labeled examples
782
+ # st.markdown("##### 📋 Labeled Examples Preview")
783
+ # st.dataframe(df, use_container_width=True)
784
+
785
# "Continue" applies the follow-up action chosen by the user.
# st.rerun() replaces st.experimental_rerun(): this script already relies on
# newer Streamlit features (st.write_stream), and the experimental alias has
# been removed from recent Streamlit releases.
if st.button("Continue"):
    if follow_up == "Generate more examples":
        # Stay on the generation task; just restart the script run.
        st.rerun()
    elif follow_up == "Data Labeling":
        # Switch the active task before rerunning so the labeling UI renders.
        st.session_state.task_choice = "Data Labeling"
        st.rerun()
791
+
792
+ except Exception as e:
793
+ st.error("An error occurred during generation.")
794
+ st.error(f"Details: {e}")
795
+
796
+
797
+ # Labeling process
798
+ elif st.session_state.task_choice == "Data Labeling":
799
+ st.header("🏷️ Data Labeling")
800
+
801
+ domain_selection = st.selectbox("Domain", ["Restaurant reviews", "E-Commerce reviews", "News", "AG News", "Tourism", "Custom"])
802
+ # 2. Handle custom domain input
803
+ custom_domain_valid = True # Assume valid until proven otherwise
804
+
805
+ if domain_selection == "Custom":
806
+ domain = st.text_input("Specify custom domain")
807
+ if not domain.strip():
808
+ st.error("Please specify a domain name.")
809
+ custom_domain_valid = False
810
+ else:
811
+ domain = domain_selection
812
+
813
+
814
+ # Classification type selection
815
+ classification_type = st.selectbox(
816
+ "Classification Type",
817
+ ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification", "Named Entity Recognition (NER)"]
818
+ )
819
+ #NNew edit
820
+ # Labels setup based on classification type
821
+ labels = []
822
+ labels_valid = False
823
+ errors = []
824
+
825
+ if classification_type == "Sentiment Analysis":
826
+ st.write("### Sentiment Analysis Labels (Fixed)")
827
+ col1, col2, col3 = st.columns(3)
828
+ with col1:
829
+ label_1 = st.text_input("First class", "Positive", disabled=True)
830
+ with col2:
831
+ label_2 = st.text_input("Second class", "Negative", disabled=True)
832
+ with col3:
833
+ label_3 = st.text_input("Third class", "Neutral", disabled=True)
834
+ labels = ["Positive", "Negative", "Neutral"]
835
+
836
+
837
+ elif classification_type == "Binary Classification":
838
+ st.write("### Binary Classification Labels")
839
+ col1, col2 = st.columns(2)
840
+
841
+ with col1:
842
+ label_1 = st.text_input("First class", "Positive")
843
+ with col2:
844
+ label_2 = st.text_input("Second class", "Negative")
845
+
846
+ errors = []
847
+ labels = [label_1.strip(), label_2.strip()]
848
+
849
+
850
+ # Re-bind the stripped labels for validation (case is only folded in the duplicate check below)
851
+ label_1 = labels[0].strip()
852
+ label_2 = labels[1].strip()
853
+
854
+ # Check for empty class names
855
+ if not label_1:
856
+ errors.append("First class name is required.")
857
+ if not label_2:
858
+ errors.append("Second class name is required.")
859
+
860
+ # Check for duplicates (case insensitive)
861
+ if label_1.lower() == label_2.lower() and label_1 and label_2:
862
+ errors.append("Class names must be different.")
863
+
864
+ # Show errors or success
865
+ if errors:
866
+ for error in errors:
867
+ st.error(error)
868
+ else:
869
+ st.success("Binary class names are valid and unique!")
870
+
871
+
872
+ elif classification_type == "Multi-Class Classification":
873
+ st.write("### Multi-Class Classification Labels")
874
+
875
+ default_labels_by_domain = {
876
+ "News": ["Political", "Sports", "Entertainment", "Technology", "Business"],
877
+ "AG News": ["World", "Sports", "Business", "Sci/Tech"],
878
+ "Tourism": ["Accommodation", "Transportation", "Tourist Attractions",
879
+ "Food & Dining", "Local Experience", "Adventure Activities",
880
+ "Wellness & Spa", "Eco-Friendly Practices", "Family-Friendly",
881
+ "Luxury Tourism"],
882
+ "Restaurant reviews": ["Italian", "French", "American"],
883
+ "E-Commerce reviews": ["Mobile Phones & Accessories", "Laptops & Computers","Kitchen & Dining",
884
+ "Beauty & Personal Care", "Home & Furniture", "Clothing & Fashion",
885
+ "Shoes & Handbags", "Health & Wellness", "Electronics & Gadgets",
886
+ "Books & Stationery","Toys & Games", "Sports & Fitness",
887
+ "Grocery & Gourmet Food","Watches & Accessories", "Baby Products"]
888
+ }
889
+
890
+
891
+
892
+ # Ask user how many classes they want to define
893
+ num_classes = st.slider("Select the number of classes (labels)", min_value=3, max_value=10, value=3)
894
+
895
+ # Use default labels based on selected domain, if available
896
+ defaults = default_labels_by_domain.get(domain, [])
897
+
898
+ labels = []
899
+ errors = []
900
+ cols = st.columns(3) # For nicely arranged label inputs
901
+
902
+ for i in range(num_classes):
903
+ with cols[i % 3]: # Distribute inputs across columns
904
+ default_value = defaults[i] if i < len(defaults) else ""
905
+ label_input = st.text_input(f"Label {i + 1}", default_value)
906
+ normalized_label = label_input.strip().title()
907
+
908
+ if not normalized_label:
909
+ errors.append(f"Label {i + 1} is required.")
910
+ else:
911
+ labels.append(normalized_label)
912
+
913
+ # Check for duplicates (case-insensitive)
914
+ normalized_set = {label.lower() for label in labels}
915
+ if len(labels) != len(normalized_set):
916
+ errors.append("Label names must be unique (case-insensitive).")
917
+
918
+ # Show validation results
919
+ if errors:
920
+ for error in errors:
921
+ st.error(error)
922
+ else:
923
+ st.success("All label names are valid and unique!")
924
+
925
+ labels_valid = not errors # True if no validation errors
926
+
927
+ elif classification_type == "Named Entity Recognition (NER)":
928
+ # # NER entity options
929
+ # ner_entities = [
930
+ # "PERSON - Names of people, fictional characters, historical figures",
931
+ # "ORG - Companies, institutions, agencies, teams",
932
+ # "LOC - Physical locations (mountains, oceans, etc.)",
933
+ # "GPE - Countries, cities, states, political regions",
934
+ # "DATE - Calendar dates, years, centuries",
935
+ # "TIME - Times, durations",
936
+ # "MONEY - Monetary values with currency"
937
+ # ]
938
+ # selected_entities = st.multiselect(
939
+ # "Select entities to recognize",
940
+ # ner_entities,
941
+ # default=["PERSON - Names of people, fictional characters, historical figures",
942
+ # "ORG - Companies, institutions, agencies, teams",
943
+ # "LOC - Physical locations (mountains, oceans, etc.)",
944
+ # "GPE - Countries, cities, states, political regions",
945
+ # "DATE - Calendar dates, years, centuries",
946
+ # "TIME - Times, durations",
947
+ # "MONEY - Monetary values with currency"],
948
+ # key="ner_entity_selection"
949
+ # )
950
+ #new 22/4/2025
951
+ #if classification_type == "Named Entity Recognition (NER)":
952
+ use_few_shot = True
953
+ #new 22/4/2025
954
+ few_shot_examples = [
955
+ {"content": "Mount Everest is the tallest mountain in the world.", "label": "LOC: Mount Everest"},
956
+ {"content": "The President of the United States visited Paris last summer.", "label": "GPE: United States, GPE: Paris"},
957
+ {"content": "Amazon is expanding its offices in Berlin.", "label": "ORG: Amazon, GPE: Berlin"},
958
+ {"content": "J.K. Rowling wrote the Harry Potter books.", "label": "PERSON: J.K. Rowling"},
959
+ {"content": "Apple was founded in California in 1976.", "label": "ORG: Apple, GPE: California, DATE: 1976"},
960
+ {"content": "The Nile is the longest river in Africa.", "label": "LOC: Nile, GPE: Africa"},
961
+ {"content": "He arrived at 3 PM for the meeting.", "label": "TIME: 3 PM"},
962
+ {"content": "She bought the dress for $200.", "label": "MONEY: $200"},
963
+ {"content": "The event is scheduled for July 4th.", "label": "DATE: July 4th"},
964
+ {"content": "The World Health Organization is headquartered in Geneva.", "label": "ORG: World Health Organization, GPE: Geneva"}
965
+ ]
966
+ ###########
967
+
968
+ st.write("### Named Entity Recognition (NER) Entities")
969
+
970
+ # Predefined standard entities
971
+ ner_entities = [
972
+ "PERSON - Names of people, fictional characters, historical figures",
973
+ "ORG - Companies, institutions, agencies, teams",
974
+ "LOC - Physical locations (mountains, oceans, etc.)",
975
+ "GPE - Countries, cities, states, political regions",
976
+ "DATE - Calendar dates, years, centuries",
977
+ "TIME - Times, durations",
978
+ "MONEY - Monetary values with currency"
979
+ ]
980
+
981
+ # User can add custom NER types
982
+ custom_ner_entities = []
983
+ if st.checkbox("Add custom NER entities?"):
984
+ num_custom_ner = st.slider("Number of custom NER entities", 1, 10, 1)
985
+ for i in range(num_custom_ner):
986
+ st.markdown(f"#### Custom Entity {i+1}")
987
+ custom_type = st.text_input(f"Entity type {i+1}", key=f"custom_ner_type_{i}")
988
+ custom_description = st.text_input(f"Description for {custom_type}", key=f"custom_ner_desc_{i}")
989
+ if custom_type and custom_description:
990
+ custom_ner_entities.append(f"{custom_type.upper()} - {custom_description}")
991
+
992
+ # Combine built-in and custom NERs
993
+ all_ner_options = ner_entities + custom_ner_entities
994
+
995
+ selected_entities = st.multiselect(
996
+ "Select entities to recognize",
997
+ all_ner_options,
998
+ default=ner_entities
999
+ )
1000
+
1001
+ # Extract entity type names (before the dash)
1002
+ labels = [entity.split(" - ")[0].strip() for entity in selected_entities]
1003
+
1004
+ if not labels:
1005
+ st.warning("Please select at least one entity type.")
1006
+ labels = ["PERSON"]
1007
+
1008
+ ##########
1009
+
1010
+ # # Extract just the entity type (before the dash)
1011
+ # labels = [entity.split(" - ")[0] for entity in selected_entities]
1012
+
1013
+ # if not labels:
1014
+ # st.warning("Please select at least one entity type")
1015
+ # labels = ["PERSON"] # Default if nothing selected
1016
+
1017
+
1018
+
1019
+
1020
+
1021
+ #NNew edit
1022
+ # elif classification_type == "Multi-Class Classification":
1023
+ # st.write("### Multi-Class Classification Labels")
1024
+
1025
+ # default_labels_by_domain = {
1026
+ # "News": ["Political", "Sports", "Entertainment", "Technology", "Business"],
1027
+ # "AG News": ["World", "Sports", "Business", "Sci/Tech"],
1028
+ # "Tourism": ["Accommodation", "Transportation", "Tourist Attractions",
1029
+ # "Food & Dining", "Local Experience", "Adventure Activities",
1030
+ # "Wellness & Spa", "Eco-Friendly Practices", "Family-Friendly",
1031
+ # "Luxury Tourism"],
1032
+ # "Restaurant reviews": ["Italian", "French", "American"]
1033
+ # }
1034
+ # num_classes = st.slider("Number of classes", 3, 10, 3)
1035
+
1036
+ # # Get defaults for selected domain, or empty list
1037
+ # defaults = default_labels_by_domain.get(domain, [])
1038
+
1039
+ # labels = []
1040
+ # errors = []
1041
+ # cols = st.columns(3)
1042
+
1043
+ # for i in range(num_classes):
1044
+ # with cols[i % 3]:
1045
+ # default_value = defaults[i] if i < len(defaults) else ""
1046
+ # label_input = st.text_input(f"Class {i+1}", default_value)
1047
+ # normalized_label = label_input.strip().title()
1048
+
1049
+ # if not normalized_label:
1050
+ # errors.append(f"Class {i+1} name is required.")
1051
+ # else:
1052
+ # labels.append(normalized_label)
1053
+
1054
+ # # Check for duplicates (case-insensitive)
1055
+ # if len(labels) != len(set(labels)):
1056
+ # errors.append("Labels names must be unique (case-insensitive, normalized to Title Case).")
1057
+
1058
+ # # Show validation results
1059
+ # if errors:
1060
+ # for error in errors:
1061
+ # st.error(error)
1062
+ # else:
1063
+ # st.success("All Labels names are valid and unique!")
1064
+ # labels_valid = not errors # Will be True only if there are no label errors
1065
+
1066
+
1067
+
1068
+
1069
+ # else:
1070
+ # num_classes = st.slider("Number of classes", 3, 23, 3, key="label_num_classes")
1071
+ # labels = []
1072
+ # cols = st.columns(3)
1073
+ # for i in range(num_classes):
1074
+ # with cols[i % 3]:
1075
+ # label = st.text_input(f"Class {i+1}", f"Class_{i+1}", key=f"label_class_{i}")
1076
+ # labels.append(label)
1077
+
1078
+ use_few_shot = st.toggle("Use few-shot examples for labeling")
1079
+ few_shot_examples = []
1080
+ if use_few_shot:
1081
+ num_few_shot = st.slider("Number of few-shot examples", 1, 10, 1)
1082
+ for i in range(num_few_shot):
1083
+ with st.expander(f"Few-shot Example {i+1}"):
1084
+ content = st.text_area(f"Content", key=f"label_few_shot_content_{i}")
1085
+ label = st.selectbox(f"Label", labels, key=f"label_few_shot_label_{i}")
1086
+ if content and label:
1087
+ few_shot_examples.append(f"{content}\nLabel: {label}")
1088
+
1089
+ num_examples = st.number_input("Number of examples to classify", 1, 100, 1)
1090
+
1091
+ examples_to_classify = []
1092
+ if num_examples <= 10:
1093
+ for i in range(num_examples):
1094
+ example = st.text_area(f"Example {i+1}", key=f"example_{i}")
1095
+ if example:
1096
+ examples_to_classify.append(example)
1097
+ else:
1098
+ examples_text = st.text_area(
1099
+ "Enter examples (one per line)",
1100
+ height=300,
1101
+ help="Enter each example on a new line"
1102
+ )
1103
+ if examples_text:
1104
+ examples_to_classify = [ex.strip() for ex in examples_text.split('\n') if ex.strip()]
1105
+ if len(examples_to_classify) > num_examples:
1106
+ examples_to_classify = examples_to_classify[:num_examples]
1107
+
1108
+ #New Wedyan
1109
+ #default_system_role = f"You are a professional {classification_type} expert, your role is to classify the provided text examples for {domain} domain."
1110
+ # System role customization
1111
+ default_system_role = (f"You are a highly skilled {classification_type} expert."
1112
+ f" Your task is to accurately classify the provided text examples within the {domain} domain."
1113
+ f" Ensure that all classifications are precise, context-aware, and aligned with domain-specific standards and best practices."
1114
+ )
1115
+ system_role = st.text_area("Modify System Role (optional)",
1116
+ value=default_system_role,
1117
+ key="system_role_input")
1118
+ st.session_state['system_role'] = system_role if system_role else default_system_role
1119
+ # Labels initialization
1120
+ #labels = []
1121
+ ####
1122
+
1123
+ user_prompt = st.text_area("User prompt (optional)", key="label_instructions")
1124
+
1125
+ few_shot_text = "\n\n".join(few_shot_examples) if few_shot_examples else ""
1126
+ examples_text = "\n".join([f"{i+1}. {ex}" for i, ex in enumerate(examples_to_classify)])
1127
+
1128
+ # Customize prompt template based on classification type
1129
+ if classification_type == "Named Entity Recognition (NER)":
1130
+ # label_prompt_template = PromptTemplate(
1131
+ # input_variables=["system_role", "labels", "few_shot_examples", "examples", "domain", "user_prompt"],
1132
+ # template=(
1133
+ # "{system_role}\n"
1134
+ # #"- You are a professional Named Entity Recognition (NER) expert in {domain} domain. Your role is to identify and extract the following entity types: {labels}.\n"
1135
+ # "- For each text example provided, identify all entities of the requested types.\n"
1136
+ # "- Use the following entities: {labels}.\n"
1137
+ # "- Return each example followed by the entities you found in this format: 'Example text.\n \n Entities:\n [ENTITY_TYPE: entity text\n\n, ENTITY_TYPE: entity text\n\n, ...] or [No entities found]'\n"
1138
+ # "- If no entities of the requested types are found, indicate 'No entities found' in this text.\n"
1139
+ # "- Be precise about entity boundaries - don't include unnecessary words.\n"
1140
+ # "- Do not provide any additional information or explanations.\n"
1141
+ # "- Additional instructions:\n {user_prompt}\n\n"
1142
+ # "- Use user few-shot examples as guidance if provided:\n{few_shot_examples}\n\n"
1143
+ # "- Examples to analyze:\n{examples}\n\n"
1144
+ # "Output:\n"
1145
+ # )
1146
+ # )
1147
+ #new 22/4/2025
1148
+ # label_prompt_template = PromptTemplate(
1149
+ # input_variables=["system_role", "labels", "few_shot_examples", "examples", "domain", "user_prompt"],
1150
+ # template=(
1151
+ # "{system_role}\n"
1152
+ # "- You are performing Named Entity Recognition (NER) in the domain of {domain}.\n"
1153
+ # "- Use the following entity types: {labels}.\n\n"
1154
+ # "### Reasoning Steps:\n"
1155
+ # "1. Read the example carefully.\n"
1156
+ # "2. For each named entity mentioned, determine its meaning and role in the sentence.\n"
1157
+ # "3. Think about the **context**: Is it a physical location (LOC)? A geopolitical region (GPE)? A person (PERSON)?\n"
1158
+ # "4. Based on the definition of each label, assign the most **specific and correct** label.\n\n"
1159
+ # "For example:\n"
1160
+ # "- 'Mount Everest' → LOC (it's a mountain)\n"
1161
+ # "- 'France' → GPE (it's a country)\n"
1162
+ # "- 'Microsoft' → ORG\n"
1163
+ # "- 'John Smith' → PERSON\n\n"
1164
+ # "- Return each example followed by the entities you found in this format:\n"
1165
+ # "'Example text.'\nEntities: [ENTITY_TYPE: entity text, ENTITY_TYPE: entity text, ...] or [No entities found]\n"
1166
+ # "- If no entities of the requested types are found, return 'No entities found'.\n"
1167
+ # "- Be precise about entity boundaries - don't include extra words.\n"
1168
+ # "- Do not explain or justify your answers.\n\n"
1169
+ # "Additional instructions:\n{user_prompt}\n\n"
1170
+ # "Few-shot examples:\n{few_shot_examples}\n\n"
1171
+ # "Examples to label:\n{examples}\n"
1172
+ # "Output:\n"
1173
+ # )
1174
+ #)
1175
+ # label_prompt_template = PromptTemplate(
1176
+ # input_variables=["system_role", "labels", "few_shot_examples", "examples", "domain", "user_prompt"],
1177
+ # template=(
1178
+ # "{system_role}\n"
1179
+ # "- You are an expert at Named Entity Recognition (NER) for domain: {domain}.\n"
1180
+ # "- Use these entity types: {labels}.\n\n"
1181
+ # "### Output Format:\n"
1182
+ # # "Return each example followed by the entities you found in this format: 'Example text.\n Entities:\n [ENTITY_TYPE: entity text\n\"
1183
+ # "Return each example followed by the entities you found in this format: 'Example text.\n 'Entity types:\n "Then group the entities under each label like this:\n" "
1184
+ # #"Then Start with this line exactly: 'Entity types\n'\n"
1185
+ # #"Then group the entities under each label like this:\n"
1186
+ # "\n PERSON – Angela Merkel, John Smith\n\n"
1187
+ # "\ ORG – Google, United Nations\n\n"
1188
+ # "\n DATE – January 1st, 2023\n\n"
1189
+ # "\n ... and so on.\n\n"
1190
+ # "If entity {labels} not found, do not write it in your response\n"
1191
+ # "- Do NOT output them inline after the text.\n"
1192
+ # "- Do NOT repeat the sentence.\n"
1193
+ # "- If no entities are found for a type, skip it.\n"
1194
+ # "- Keep the format consistent.\n\n"
1195
+ # "User Instructions:\n{user_prompt}\n\n"
1196
+ # "Few-shot Examples:\n{few_shot_examples}\n\n"
1197
+ # "Examples to analyze:\n{examples}"
1198
+ # )
1199
+ # )
1200
+
1201
+
1202
+ label_prompt_template = PromptTemplate(
1203
+ input_variables=["system_role", "labels", "few_shot_examples", "examples", "domain", "user_prompt"],
1204
+ template=(
1205
+ "{system_role}\n"
1206
+ "- You are an expert at Named Entity Recognition (NER) for domain: {domain}.\n"
1207
+ "- Use these entity types: {labels}.\n\n"
1208
+ "### Output Format:\n"
1209
+ "Return each example followed by the entities you found in this format:\n"
1210
+ "'Example text.\nEntity types:\n"
1211
+ "Then group the entities under each label like this:\n"
1212
+ "\nPERSON – Angela Merkel, John Smith\n"
1213
+ "ORG – Google, United Nations\n"
1214
+ "DATE – January 1st, 2023\n"
1215
+ "... and so on.\n\n"
1216
+ "Each new entities group should be in a new line.\n"
1217
+ "If entity type {labels} is not found, do not write it in your response.\n"
1218
+ "- Do NOT output them inline after the text.\n"
1219
+ "- Do NOT repeat the sentence.\n"
1220
+ "- If no entities are found for a type, skip it.\n"
1221
+ "- Keep the format consistent.\n\n"
1222
+ "User Instructions:\n{user_prompt}\n\n"
1223
+ "Few-shot Examples:\n{few_shot_examples}\n\n"
1224
+ "Examples to analyze:\n{examples}"
1225
+ )
1226
+ )
1227
+
1228
+ #######
1229
+ else:
1230
+ label_prompt_template = PromptTemplate(
1231
+
1232
+ input_variables=["system_role", "classification_type", "labels", "few_shot_examples", "examples","domain", "user_prompt"],
1233
+ template=(
1234
+ #"- Let'\s think step by step:"
1235
+ "{system_role}\n"
1236
+ # "- You are a professional {classification_type} expert in {domain} domain. Your role is to classify the following examples using these labels: {labels}.\n"
1237
+ "- Use the following instructions:\n"
1238
+ "- Use the following labels: {labels}.\n"
1239
+ "- Return the classified text followed by the label in this format: 'text. Label: [label]'\n"
1240
+ "- Do not provide any additional information or explanations\n"
1241
+ "- User prompt:\n {user_prompt}\n\n"
1242
+ "- Use user provided examples as guidence in the classification process:\n\n {few_shot_examples}\n"
1243
+ "- Examples to classify:\n{examples}\n\n"
1244
+ "- Think step by step then classify the examples"
1245
+ #"Output:\n"
1246
+ ))
1247
+
1248
# Render the few-shot examples into one prompt-ready string.
#
# `few_shot_examples` can be:
#   * an already formatted string,
#   * a list of already formatted strings (what the few-shot expander UI
#     appends: "<content>\nLabel: <label>"), or
#   * a list of {"content": ..., "label": ...} dicts (the built-in NER samples).
#
# The previous code indexed every item with ex['content'] unconditionally,
# which raises TypeError as soon as the list holds plain strings.
if isinstance(few_shot_examples, str):
    formatted_few_shot = few_shot_examples
elif isinstance(few_shot_examples, list):
    rendered_examples = []
    for shot in few_shot_examples:
        if isinstance(shot, str):
            # Already formatted by the UI; use it verbatim.
            rendered_examples.append(shot)
        elif isinstance(shot, dict) and 'content' in shot and 'label' in shot:
            # Same layout the dict-based examples were rendered with before.
            rendered_examples.append(f"{shot['content']}\n\nEntity types\n{shot['label']}")
        # Items of any other shape are skipped rather than crashing the page.
    formatted_few_shot = "\n\n".join(rendered_examples)
else:
    formatted_few_shot = ""

# Build the prompt shown in the preview below. `examples_text` (the numbered
# example list) is what the labeling request itself passes to the template,
# so the preview now uses it as well instead of an un-numbered join.
system_prompt = label_prompt_template.format(
    system_role=st.session_state['system_role'],
    classification_type=classification_type,
    domain=domain,
    examples=examples_text,
    labels=", ".join(labels),
    user_prompt=user_prompt,
    few_shot_examples=formatted_few_shot,
)

# Keep the prompt in session state; it is displayed below and written into
# the exported records.
st.session_state['system_prompt'] = system_prompt
1291
+ #::contentReference[oaicite:0]{index=0}
1292
+ st.write("System Prompt:")
1293
+ #st.code(system_prompt)
1294
+ #st.code(st.session_state['system_prompt'])
1295
+ st.text_area("System Prompt", value=st.session_state['system_prompt'], height=300, max_chars=None, key=None, help=None, disabled=True)
1296
+
1297
+
1298
+
1299
+ if st.button("🏷️ Label Data"):
1300
+ if examples_to_classify:
1301
+ with st.spinner("Labeling data..."):
1302
+ #Generate the system prompt based on classification type
1303
+ if classification_type == "Named Entity Recognition (NER)":
1304
+ system_prompt = label_prompt_template.format(
1305
+ system_role=st.session_state['system_role'],
1306
+ labels=", ".join(labels),
1307
+ domain = domain,
1308
+ few_shot_examples=few_shot_text,
1309
+ examples=examples_text,
1310
+ user_prompt=user_prompt
1311
+ #new
1312
+ #'Use few-shot example?': 'Yes' if use_few_shot else 'No',
1313
+ )
1314
+ # if classification_type == "Named Entity Recognition (NER)":
1315
+ # # Step 1: Split the full response by example
1316
+ # raw_outputs = [block.strip() for block in response.strip().split("Entity types") if block.strip()]
1317
+ # inputs = [ex.strip() for ex in examples_to_classify]
1318
+
1319
+ # # Step 2: Match inputs with NER output blocks
1320
+ # labeled_examples = []
1321
+ # for i, (text, output_block) in enumerate(zip(inputs, raw_outputs)):
1322
+ # labeled_examples.append({
1323
+ # 'text': text,
1324
+ # 'entities': f"Entity types\n{output_block.strip()}",
1325
+ # 'system_prompt': st.session_state.system_prompt,
1326
+ # 'system_role': st.session_state.system_role,
1327
+ # 'task_type': 'Named Entity Recognition (NER)',
1328
+ # 'Use few-shot example?': 'Yes' if use_few_shot else 'No',
1329
+ # })
1330
+
1331
+ # if classification_type == "Named Entity Recognition (NER)":
1332
+ # # Step 1: Split the full response by example
1333
+ # raw_outputs = [block.strip() for block in response.strip().split("Entity types") if block.strip()]
1334
+ # inputs = [ex.strip() for ex in examples_to_classify]
1335
+
1336
+ # # Step 2: Match inputs with NER output blocks
1337
+ # labeled_examples = []
1338
+ # for i, (text, output_block) in enumerate(zip(inputs, raw_outputs)):
1339
+ # labeled_examples.append({
1340
+ # 'text': text,
1341
+ # 'entities': f"Entity types\n{output_block.strip()}",
1342
+ # 'system_prompt': st.session_state.system_prompt,
1343
+ # 'system_role': st.session_state.system_role,
1344
+ # 'task_type': 'Named Entity Recognition (NER)',
1345
+ # 'Use few-shot example?': 'Yes' if use_few_shot else 'No',
1346
+ # })
1347
+
1348
+
1349
+ # import re
1350
+
1351
+ # if classification_type == "Named Entity Recognition (NER)":
1352
+ # # Use regex to split on "Entity types" while keeping it attached to each block
1353
+ # blocks = re.split(r"(Entity types)", response.strip())
1354
+
1355
+ # # Recombine 'Entity types' with each block after splitting
1356
+ # raw_outputs = [
1357
+ # (blocks[i] + blocks[i+1]).strip()
1358
+ # for i in range(1, len(blocks) - 1, 2)
1359
+ # ]
1360
+
1361
+ # inputs = [ex.strip() for ex in examples_to_classify]
1362
+
1363
+ # labeled_examples = []
1364
+ # for i, (text, output_block) in enumerate(zip(inputs, raw_outputs)):
1365
+ # labeled_examples.append({
1366
+ # 'text': text,
1367
+ # 'entities': output_block,
1368
+ # 'system_prompt': st.session_state.system_prompt,
1369
+ # 'system_role': st.session_state.system_role,
1370
+ # 'task_type': 'Named Entity Recognition (NER)',
1371
+ # 'Use few-shot example?': 'Yes' if use_few_shot else 'No',
1372
+ # })
1373
+
1374
+
1375
+ else:
1376
+ system_prompt = label_prompt_template.format(
1377
+ classification_type=classification_type,
1378
+ system_role=st.session_state['system_role'],
1379
+ domain = domain,
1380
+ labels=", ".join(labels),
1381
+ few_shot_examples=few_shot_text,
1382
+ examples=examples_text,
1383
+ user_prompt=user_prompt
1384
+ )
1385
+ try:
1386
+ stream = client.chat.completions.create(
1387
+ model=selected_model,
1388
+ messages=[{"role": "system", "content": system_prompt}],
1389
+ temperature=temperature,
1390
+ stream=True,
1391
+ #max_tokens=20000,
1392
+ max_tokens=4000,
1393
+ top_p = 0.9,
1394
+
1395
+ )
1396
+ #new 24 March
1397
+ # Append user message
1398
+ st.session_state.messages.append({"role": "user", "content": system_prompt})
1399
+ #################
1400
+ response = st.write_stream(stream)
1401
+ st.session_state.messages.append({"role": "assistant", "content": response})
1402
+ # Display the labeled examples
1403
+ # # Optional: If you want to add it as a chat-style message log
1404
+ # preview_str = st.session_state.labeled_preview.to_markdown(index=False)
1405
+ # st.session_state.messages.append({"role": "assistant", "content": f"Here is a preview of the labeled examples:\n\n{preview_str}"})
1406
+
1407
+
1408
+ # # Stream response and append assistant message
1409
+ # #14/4/2024
1410
+ # response = st.write_stream(stream)
1411
+ # st.session_state.messages.append({"role": "assistant", "content": response})
1412
+
1413
# Record the outcome of this labeling run in session state.
#
# The prompt and the response are overwritten on every run. The previous
# "only if missing" guards kept the very first response forever and never
# stored the prompt that was actually sent (the key is always set earlier
# when the preview is built), so exported rows could carry stale values.
st.session_state.system_prompt = system_prompt
st.session_state.response = response

# Keys owned by the data-generation flow only get a default when absent, so
# previously generated examples are not discarded by a labeling run.
for state_key, default_value in (
    ('generated_examples', []),
    ('generated_examples_csv', None),
    ('generated_examples_json', None),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value
1428
+
1429
+
1430
+
1431
+
1432
+ # Save labeled examples to CSV
1433
+ #new 14/4/2025
1434
+ #labeled_examples = []
1435
+ # if classification_type == "Named Entity Recognition (NER)":
1436
+ # labeled_examples = []
1437
+ # for line in response.split('\n'):
1438
+ # if line.strip():
1439
+ # parts = line.rsplit('Entities:', 1)
1440
+ # if len(parts) == 2:
1441
+ # text = parts[0].strip()
1442
+ # entities = parts[1].strip()
1443
+ # if text and entities:
1444
+ # labeled_examples.append({
1445
+ # 'text': text,
1446
+ # 'entities': entities,
1447
+ # 'system_prompt': st.session_state.system_prompt,
1448
+ # 'system_role': st.session_state.system_role,
1449
+ # 'task_type': 'Named Entity Recognition (NER)',
1450
+ # 'Use few-shot example?': 'Yes' if use_few_shot else 'No',
1451
+ # })
1452
+
1453
+ #new 22/4/2025
1454
# Convert the raw model response into export-ready records.
labeled_examples = []
if classification_type != "Named Entity Recognition (NER)":
    # Classification output: keep only lines that contain a 'Label:' marker
    # with non-empty text before it and a non-empty label after it.
    for line in response.split('\n'):
        if not line.strip():
            continue
        parts = line.rsplit('Label:', 1)
        if len(parts) != 2:
            continue
        text, label = (piece.strip() for piece in parts)
        if not (text and label):
            continue
        labeled_examples.append({
            'text': text,
            'label': label,
            'system_prompt': st.session_state.system_prompt,
            'system_role': st.session_state.system_role,
            'task_type': 'Data Labeling',
            'Use few-shot example?': 'Yes' if use_few_shot else 'No',
        })
else:
    # NER output is not split per example: the whole response is stored as a
    # single record.
    labeled_examples.append({
        'ner_output': response.strip(),
        'system_prompt': st.session_state.system_prompt,
        'system_role': st.session_state.system_role,
        'task_type': 'Named Entity Recognition (NER)',
        'Use few-shot example?': 'Yes' if use_few_shot else 'No',
    })
1484
+ # Save and provide download options
1485
+ if labeled_examples:
1486
+ # Update session state
1487
+ st.session_state.labeled_examples = labeled_examples
1488
+
1489
+ # Convert to CSV and JSON
1490
+ df = pd.DataFrame(labeled_examples)
1491
+ #new 22/4/2025
1492
+ # CSV
1493
+ st.session_state.labeled_examples_csv = df.to_csv(index=False).encode('utf-8')
1494
+
1495
+ # JSON
1496
# Serialize the labeled examples together with run metadata as UTF-8 encoded
# JSON bytes for the JSON download button.
st.session_state.labeled_examples_json = json.dumps({
    "metadata": {
        "domain": domain,
        "labels": labels,
        "used_few_shot": use_few_shot,
        # Previously hard-coded to the NER value for every run; now mirrors
        # the 'task_type' written into the individual example records.
        "task_type": (
            "Named Entity Recognition (NER)"
            if classification_type == "Named Entity Recognition (NER)"
            else "Data Labeling"
        ),
        "timestamp": datetime.now().isoformat()
    },
    "examples": labeled_examples
}, indent=2).encode('utf-8')
1506
+
1507
+ ############
1508
+ # CSV
1509
+ # st.session_state.labeled_examples_csv = df.to_csv(index=False).encode('utf-8')
1510
+
1511
+ # # JSON
1512
+ # st.session_state.labeled_examples_json = json.dumps({
1513
+ # "metadata": {
1514
+ # "domain": domain,
1515
+ # "labels": labels,
1516
+ # "used_few_shot": use_few_shot,
1517
+ # "task_type": "Named Entity Recognition (NER)",
1518
+ # "timestamp": datetime.now().isoformat()
1519
+ # },
1520
+ # "examples": labeled_examples
1521
+ # }, indent=2).encode('utf-8')
1522
+
1523
+ ########
1524
+ # st.session_state.labeled_examples_csv = df.to_csv(index=False).encode('utf-8')
1525
+ # st.session_state.labeled_examples_json = json.dumps(labeled_examples, indent=2).encode('utf-8')
1526
+
1527
+ # Download buttons
1528
+ st.download_button(
1529
+ "📥 Download Labeled Examples (CSV)",
1530
+ st.session_state.labeled_examples_csv,
1531
+ "labeled_examples.csv",
1532
+ "text/csv",
1533
+ key='download-labeled-csv'
1534
+ )
1535
+
1536
+ st.markdown("""
1537
+ <div style='text-align: left; margin:15px 0; font-weight: 600; color: #666;'>. . . . . . or</div>
1538
+ """, unsafe_allow_html=True)
1539
+
1540
+ st.download_button(
1541
+ "📥 Download Labeled Examples (JSON)",
1542
+ st.session_state.labeled_examples_json,
1543
+ "labeled_examples.json",
1544
+ "application/json",
1545
+ key='download-labeled-json'
1546
+ )
1547
+ # Display the labeled examples
1548
+ st.markdown("##### 📋 Labeled Examples Preview")
1549
+ st.dataframe(df, use_container_width=True)
1550
+ # Display section
1551
+ #st.markdown("### 📋 Labeled Examples Preview")
1552
+ #st.dataframe(st.session_state.labeled_preview, use_container_width=True)
1553
+
1554
+
1555
+
1556
+ # if labeled_examples:
1557
+ # df = pd.DataFrame(labeled_examples)
1558
+ # csv = df.to_csv(index=False).encode('utf-8')
1559
+ # st.download_button(
1560
+ # "📥 Download Labeled Examples",
1561
+ # csv,
1562
+ # "labeled_examples.csv",
1563
+ # "text/csv",
1564
+ # key='download-labeled-csv'
1565
+ # )
1566
+ # # Add space and center the "or"
1567
+ # st.markdown("""
1568
+ # <div style='text-align: left; margin:15px 0; font-weight: 600; color: #666;'>. . . . . . or</div>
1569
+ # """, unsafe_allow_html=True)
1570
+
1571
+ # if labeled_examples:
1572
+ # df = pd.DataFrame(labeled_examples)
1573
+ # csv = df.to_csv(index=False).encode('utf-8')
1574
+ # st.download_button(
1575
+ # "📥 Download Labeled Examples",
1576
+ # csv,
1577
+ # "labeled_examples.json",
1578
+ # "text/json",
1579
+ # key='download-labeled-JSON'
1580
+ # )
1581
+
1582
# Follow-up interaction: let the user choose what to do after labeling.
# The radio must exist because the "Continue" handler below reads
# `follow_up`; with the radio commented out the click raised NameError.
st.markdown("---")
follow_up = st.radio(
    "What would you like to do next?",
    ["Label more data", "Data Generation"],
    key="labeling_follow_up"
)

if st.button("Continue"):
    # `st.experimental_rerun` no longer exists in recent Streamlit releases;
    # prefer `st.rerun` and fall back to the old name on older versions.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    if follow_up == "Label more data":
        # Clear the pending examples so the user starts from an empty list.
        st.session_state.examples_to_classify = []
        rerun()
    elif follow_up == "Data Generation":
        # NOTE(review): this previously set "Data Labeling", which would just
        # reload the labeling view; confirm "Data Generation" matches the
        # value the task selector elsewhere in the script expects.
        st.session_state.task_choice = "Data Generation"
        rerun()
1597
+
1598
+ except Exception as e:
1599
+ st.error("An error occurred during labeling.")
1600
+ st.error(f"Details: {e}")
1601
+ else:
1602
+ st.warning("Please enter at least one example to classify.")
1603
+
1604
+ #st.session_state.messages.append({"role": "assistant", "content": response})
1605
+
1606
+
1607
+
1608
+
1609
# Footer: a horizontal rule followed by a centred attribution line.
_footer_html = """
    <div style='text-align: center'>
        <p>Made with ❤️ by Wedyan AlSakran 2025</p>
    </div>
    """
st.markdown("---")
# Raw HTML is needed for the centring, hence unsafe_allow_html.
st.markdown(_footer_html, unsafe_allow_html=True)