acecalisto3 committed on
Commit a96d22b · verified · 1 Parent(s): 40667c5

Update app.py

Files changed (1)
  1. app.py +640 -221
app.py CHANGED
@@ -1,227 +1,646 @@
- import streamlit as st
- from transformers import AutoModelForCausalLM, AutoTokenizer
  import os
- import json
- import logging
- from threading import Lock
- import torch
-
- # Constants with optimized values for Mixtral
- DEFAULT_MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
- MAX_INPUT_TOKENS = 24576  # 24K tokens for input (leaving room for output)
- MAX_NEW_TOKENS = 8192  # 8K tokens for generation
- DEFAULT_CONTEXT_LENGTH = 16384  # 16K default context
- CONFIG_FILE = "chatbot_config.json"
- CACHE_DIR = "model_cache"
-
- class EnhancedChatbot:
-     def __init__(self):
-         self.model = None
-         self.tokenizer = None
-         self.model_lock = Lock()
-
-         # Ensure cache directory exists
-         os.makedirs(CACHE_DIR, exist_ok=True)
-
-         # Initialize configuration with higher limits
-         self.config = self.load_config()
-
-         # Initialize model and tokenizer
-         try:
-             self.load_model()
-         except Exception as e:
-             st.error(f"Error loading model: {str(e)}")
-             logging.error(f"Error loading model: {str(e)}")
-
-     def load_config(self):
-         """Load or create configuration file with optimized settings"""
-         default_config = {
-             "model_name": DEFAULT_MODEL_NAME,
-             "max_new_tokens": MAX_NEW_TOKENS,
-             "context_length": DEFAULT_CONTEXT_LENGTH,
-             "temperature": 0.7,
-             "top_p": 0.95,
-             "top_k": 50,
-             "repetition_penalty": 1.1,
-             "system_message": "You are a helpful AI assistant with high context understanding.",
-             "gpu_layers": "auto"
-         }
-
-         try:
-             if os.path.exists(CONFIG_FILE):
-                 with open(CONFIG_FILE, 'r') as f:
-                     config = json.load(f)
-                 # Update with any missing keys from default_config
-                 for key, value in default_config.items():
-                     if key not in config:
-                         config[key] = value
-             else:
-                 config = default_config
-                 self.save_config(config)
-
-             return config
-
-         except Exception as e:
-             logging.error(f"Error loading config: {str(e)}")
-             return default_config
-
-     def load_model(self):
-         """Load the model and tokenizer with optimized settings"""
-         try:
-             # Clear CUDA cache if using GPU
-             if torch.cuda.is_available():
-                 torch.cuda.empty_cache()
-
-             # Load tokenizer first
-             self.tokenizer = AutoTokenizer.from_pretrained(
-                 self.config["model_name"],
-                 cache_dir=CACHE_DIR,
-                 model_max_length=self.config["context_length"],
-                 padding_side="left"
-             )

-             # Load model with optimized settings
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 self.config["model_name"],
-                 torch_dtype=torch.bfloat16,  # Use bfloat16 for better performance
-                 low_cpu_mem_usage=True,
-                 cache_dir=CACHE_DIR,
-                 device_map="auto",
-                 max_memory={0: "24GiB"},  # Adjust based on your GPU
-                 trust_remote_code=True
-             )
-
-             logging.info(f"Model {self.config['model_name']} loaded successfully")
-
-         except Exception as e:
-             logging.error(f"Error loading model: {str(e)}")
-             raise
-
-     def generate_response(self, message, history):
-         """Generate response with high token limit"""
-         try:
-             with self.model_lock:
-                 # Prepare conversation history
-                 full_prompt = self.prepare_prompt(message, history)
-
-                 # Tokenize with proper handling of long sequences
-                 inputs = self.tokenizer(full_prompt,
-                                         return_tensors="pt",
-                                         truncation=True,
-                                         max_length=MAX_INPUT_TOKENS)
-
-                 # Move to GPU if available
-                 inputs = inputs.to(self.model.device)
-
-                 # Generate with optimized parameters
-                 outputs = self.model.generate(
-                     **inputs,
-                     max_new_tokens=self.config["max_new_tokens"],
-                     temperature=self.config["temperature"],
-                     top_p=self.config["top_p"],
-                     top_k=self.config["top_k"],
-                     repetition_penalty=self.config["repetition_penalty"],
-                     do_sample=True,
-                     pad_token_id=self.tokenizer.eos_token_id
                  )

-                 # Decode response
-                 response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-                 return response.strip()
-
-         except Exception as e:
-             logging.error(f"Error generating response: {str(e)}")
-             return "I apologize, but I encountered an error. Please try again."
-
-     def prepare_prompt(self, message, history):
-         """Prepare prompt with history management"""
-         system_msg = self.config["system_message"]
-         prompt = f"{system_msg}\n\n"
-
-         # Add history with token counting
-         total_tokens = 0
-         for msg in history:
-             tokens = len(self.tokenizer.encode(msg["content"]))
-             if total_tokens + tokens < MAX_INPUT_TOKENS:
-                 prompt += f"{msg['role']}: {msg['content']}\n"
-                 total_tokens += tokens
-             else:
-                 break
-
-         prompt += f"user: {message}\nassistant:"
-         return prompt
-
- # Streamlit UI with advanced settings
- def main():
-     st.title("Enhanced AI Chatbot (High Context)")

-     try:
-         chatbot = EnhancedChatbot()
-
-         # Advanced settings in sidebar
-         with st.sidebar:
-             st.subheader("Model Settings")
-
-             # Context length slider
-             new_context = st.slider(
-                 "Context Length (tokens)",
-                 min_value=1024,
-                 max_value=32768,
-                 value=chatbot.config["context_length"],
-                 step=1024
-             )
-
-             # Generation settings
-             new_max_tokens = st.slider(
-                 "Max New Tokens",
-                 min_value=1024,
-                 max_value=MAX_NEW_TOKENS,
-                 value=chatbot.config["max_new_tokens"],
-                 step=1024
-             )
-
-             temperature = st.slider(
-                 "Temperature",
-                 min_value=0.1,
-                 max_value=2.0,
-                 value=chatbot.config["temperature"]
              )
-
-             # Update settings button
-             if st.button("Update Settings"):
-                 chatbot.config.update({
-                     "context_length": new_context,
-                     "max_new_tokens": new_max_tokens,
-                     "temperature": temperature
-                 })
-                 chatbot.save_config(chatbot.config)
-                 st.experimental_rerun()
-
-         # Chat interface
-         if "messages" not in st.session_state:
-             st.session_state.messages = []
-
-         # Display chat messages
-         for message in st.session_state.messages:
-             with st.chat_message(message["role"]):
-                 st.markdown(message["content"])
-
-         # Chat input
-         if prompt := st.chat_input("What would you like to know?"):
-             st.session_state.messages.append({"role": "user", "content": prompt})
-             with st.chat_message("user"):
-                 st.markdown(prompt)
-
-             with st.chat_message("assistant"):
-                 with st.spinner("Generating response..."):
-                     response = chatbot.generate_response(prompt, st.session_state.messages)
-                     st.markdown(response)
-                     st.session_state.messages.append({"role": "assistant", "content": response})
-
-     except Exception as e:
-         st.error(f"Application Error: {str(e)}")
-         logging.error(f"Application Error: {str(e)}")
-
- if __name__ == "__main__":
-     main()
+ import io
  import os
+ import re
+ import time
+ from itertools import islice
+ from functools import partial
+ from multiprocessing.pool import ThreadPool
+ from queue import Queue, Empty
+ from typing import Callable, Iterable, Iterator, Optional, TypeVar
+
+ import gradio as gr
+ import pandas as pd
+ import requests.exceptions
+ from huggingface_hub import InferenceClient, create_repo, whoami, DatasetCard
+
+
+ model_id = "microsoft/Phi-3-mini-4k-instruct"
+ client = InferenceClient(model_id)
+ save_dataset_hf_token = os.environ.get("SAVE_DATASET_HF_TOKEN")
+
+ MAX_TOTAL_NB_ITEMS = 100  # almost infinite, don't judge me (actually it's because gradio needs a fixed number of components)
+ MAX_NB_ITEMS_PER_GENERATION_CALL = 10
+ NUM_ROWS = 100
+ NUM_VARIANTS = 10
+ NAMESPACE = "infinite-dataset-hub"
+ URL = "https://huggingface.co/spaces/infinite-dataset-hub/infinite-dataset-hub"
+
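+ # Prompt templates for the dataset-name and CSV generation calls; '{...}' placeholders are filled with str.format at call time.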
+ GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
+     "A Machine Learning Practitioner is looking for a dataset that matches '{search_query}'. "
+     f"Generate a list of {MAX_NB_ITEMS_PER_GENERATION_CALL} names of quality datasets that don't exist but sound plausible and would "
+     "be helpful. Feel free to reuse words from the query '{search_query}' to name the datasets. "
+     "Every dataset should be about '{search_query}' and have descriptive tags/keywords including the ML task name associated with the dataset (classification, regression, anomaly detection, etc.). Use the following format:\n1. DatasetName1 (tag1, tag2, tag3)\n2. DatasetName2 (tag1, tag2, tag3)"
+ )
+
+ GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS = (
+     "An ML practitioner is looking for a dataset CSV after the query '{search_query}'. "
+     "Generate the first 5 rows of a plausible and quality CSV for the dataset '{dataset_name}'. "
+     "You can get inspiration from related keywords '{tags}' but most importantly the dataset should correspond to the query '{search_query}'. "
+     "Focus on quality text content and use a 'label' or 'labels' column if it makes sense (invent labels, avoid reusing the keywords, be accurate while labelling texts). "
+     "Reply using a short description of the dataset with title **Dataset Description:** followed by the CSV content in a code block and with title **CSV Content Preview:**."
+ )
+ GENERATE_MORE_ROWS = "Can you give me 10 additional samples in CSV format as well? Use the same CSV header '{csv_header}'."
+ GENERATE_VARIANTS_WITH_RARITY_AND_LABEL = "Focus on generating samples for the label '{label}' and ideally generate {rarity} samples."
+ GENERATE_VARIANTS_WITH_RARITY = "Focus on generating {rarity} samples."
+
+ RARITIES = ["pretty obvious", "common/regular", "unexpected but useful", "uncommon but still plausible", "rare/niche but still plausible"]
+ LONG_RARITIES = [
+     "obvious",
+     "expected",
+     "common",
+     "regular",
+     "unexpected but useful",
+     "original but useful",
+     "specific but not far-fetched",
+     "uncommon but still plausible",
+     "rare but still plausible",
+     "very niche but still plausible",
+ ]
+
+ landing_page_datasets_generated_text = """
+ 1. NewsEventsPredict (classification, media, trend)
+ 2. FinancialForecast (economy, stocks, regression)
+ 3. HealthMonitor (science, real-time, anomaly detection)
+ 4. SportsAnalysis (classification, performance, player tracking)
+ 5. SciLiteracyTools (language modeling, science literacy, text classification)
+ 6. RetailSalesAnalyzer (consumer behavior, sales trend, segmentation)
+ 7. SocialSentimentEcho (social media, emotion analysis, clustering)
+ 8. NewsEventTracker (classification, public awareness, topical clustering)
+ 9. HealthVitalSigns (anomaly detection, biometrics, prediction)
+ 10. GameStockPredict (classification, finance, sports contingency)
+ """
+ default_output = landing_page_datasets_generated_text.strip().split("\n")
+ assert len(default_output) == MAX_NB_ITEMS_PER_GENERATION_CALL
+
+ DATASET_CARD_CONTENT = """
+ ---
+ license: mit
+ tags:
+ - infinite-dataset-hub
+ - synthetic
+ ---
+ {title}
+ _Note: This is an AI-generated dataset so its content may be inaccurate or false_
+ {content}
+ **Source of the data:**
+ The dataset was generated using the [Infinite Dataset Hub]({url}) and {model_id} using the query '{search_query}':
+ - **Dataset Generation Page**: {dataset_url}
+ - **Model**: https://huggingface.co/{model_id}
+ - **More Datasets**: https://huggingface.co/datasets?other=infinite-dataset-hub
+ """
+
+ css = """
+ a {
+     color: var(--body-text-color);
+ }
+ .datasetButton {
+     justify-content: start;
+     justify-content: left;
+ }
+ .tags {
+     font-size: var(--button-small-text-size);
+     color: var(--body-text-color-subdued);
+ }
+ .topButton {
+     justify-content: start;
+     justify-content: left;
+     text-align: left;
+     background: transparent;
+     box-shadow: none;
+     padding-bottom: 0;
+ }
+ .topButton::before {
+     content: url("data:image/svg+xml,%3Csvg style='color: rgb(209 213 219)' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink' aria-hidden='true' focusable='false' role='img' width='1em' height='1em' preserveAspectRatio='xMidYMid meet' viewBox='0 0 25 25'%3E%3Cellipse cx='12.5' cy='5' fill='currentColor' fill-opacity='0.25' rx='7.5' ry='2'%3E%3C/ellipse%3E%3Cpath d='M12.5 15C16.6421 15 20 14.1046 20 13V20C20 21.1046 16.6421 22 12.5 22C8.35786 22 5 21.1046 5 20V13C5 14.1046 8.35786 15 12.5 15Z' fill='currentColor' opacity='0.5'%3E%3C/path%3E%3Cpath d='M12.5 7C16.6421 7 20 6.10457 20 5V11.5C20 12.6046 16.6421 13.5 12.5 13.5C8.35786 13.5 5 12.6046 5 11.5V5C5 6.10457 8.35786 7 12.5 7Z' fill='currentColor' opacity='0.5'%3E%3C/path%3E%3Cpath d='M5.23628 12C5.08204 12.1598 5 12.8273 5 13C5 14.1046 8.35786 15 12.5 15C16.6421 15 20 14.1046 20 13C20 12.8273 19.918 12.1598 19.7637 12C18.9311 12.8626 15.9947 13.5 12.5 13.5C9.0053 13.5 6.06886 12.8626 5.23628 12Z' fill='currentColor'%3E%3C/path%3E%3C/svg%3E");
+     margin-right: .25rem;
+     margin-left: -.125rem;
+     margin-top: .25rem;
+ }
+ .bottomButton {
+     justify-content: start;
+     justify-content: left;
+     text-align: left;
+     background: transparent;
+     box-shadow: none;
+     font-size: var(--button-small-text-size);
+     color: var(--body-text-color-subdued);
+     padding-top: 0;
+     align-items: baseline;
+ }
+ .bottomButton::before {
+     content: 'tags:';
+     margin-right: .25rem;
+ }
+ .buttonsGroup {
+     background: transparent;
+ }
+ .buttonsGroup:hover {
+     background: var(--input-background-fill);
+ }
+ .buttonsGroup div {
+     background: transparent;
+ }
+ .invisibleButtonGroup {
+     display: none;
+ }
+ @keyframes placeHolderShimmer {
+     0% {
+         background-position: -468px 0
+     }
+     100% {
+         background-position: 468px 0
+     }
+ }
+ .linear-background {
+     animation-duration: 1s;
+     animation-fill-mode: forwards;
+     animation-iteration-count: infinite;
+     animation-name: placeHolderShimmer;
+     animation-timing-function: linear;
+     background-image: linear-gradient(to right, var(--body-text-color-subdued) 8%, #dddddd11 18%, var(--body-text-color-subdued) 33%);
+     background-size: 1000px 104px;
+     color: transparent;
+     background-clip: text;
+ }
+ .settings {
+     background: transparent;
+ }
+ .settings button span {
+     color: var(--body-text-color-subdued);
+ }
+ """
+
+
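+ # The UI is one gr.Blocks app: a search page with a fixed pool of MAX_TOTAL_NB_ITEMS placeholder result buttons (revealed and filled as results stream in) and an initially hidden dataset page.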
+ with gr.Blocks(css=css) as demo:
+     generated_texts_state = gr.State((landing_page_datasets_generated_text,))
+     with gr.Column() as search_page:
+         with gr.Row():
+             with gr.Column(scale=10):
+                 gr.Markdown(
+                     "# 🤗 Infinite Dataset Hub ♾️\n\n"
+                     "An endless catalog of datasets, created just for you by an AI model.\n\n"
                  )
+                 with gr.Row():
+                     search_bar = gr.Textbox(max_lines=1, placeholder="Search datasets, get infinite results", show_label=False, container=False, scale=9)
+                     search_button = gr.Button("🔍", variant="primary", scale=1)
+                 button_groups: list[gr.Group] = []
+                 buttons: list[gr.Button] = []
+                 for i in range(MAX_TOTAL_NB_ITEMS):
+                     if i < len(default_output):
+                         line = default_output[i]
+                         dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" (", 1)
+                         group_classes = "buttonsGroup"
+                         dataset_name_classes = "topButton"
+                         tags_classes = "bottomButton"
+                     else:
+                         dataset_name, tags = "⬜⬜⬜⬜⬜⬜", "░░░░, ░░░░, ░░░░"
+                         group_classes = "buttonsGroup invisibleButtonGroup"
+                         dataset_name_classes = "topButton linear-background"
+                         tags_classes = "bottomButton linear-background"
+                     with gr.Group(elem_classes=group_classes) as button_group:
+                         button_groups.append(button_group)
+                         buttons.append(gr.Button(dataset_name, elem_classes=dataset_name_classes))
+                         buttons.append(gr.Button(tags, elem_classes=tags_classes))
+
+                 load_more_datasets = gr.Button("Load more datasets")  # TODO: disable when reaching end of page
+                 gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
+             with gr.Column(scale=4, min_width="200px"):
+                 with gr.Accordion("Settings", open=False, elem_classes="settings"):
+                     gr.Markdown("Save datasets to your account")
+                     gr.LoginButton()
+                     select_namespace_dropdown = gr.Dropdown(choices=[NAMESPACE], value=NAMESPACE, label="Select user or organization", visible=False)
+                     gr.Markdown("Save datasets as public or private datasets")
+                     visibility_radio = gr.Radio(["public", "private"], value="public", container=False, interactive=False)
+     with gr.Column(visible=False) as dataset_page:
+         gr.Markdown(
+             "# 🤗 Infinite Dataset Hub ♾️\n\n"
+             "An endless catalog of datasets, created just for you.\n\n"
+         )
+         dataset_title = gr.Markdown()
+         gr.Markdown("_Note: This is an AI-generated dataset so its content may be inaccurate or false_")
+         dataset_content = gr.Markdown()
+         generate_full_dataset_button = gr.Button("Generate Full Dataset", variant="primary")
+         dataset_dataframe = gr.DataFrame(visible=False, interactive=False, wrap=True)
+         save_dataset_button = gr.Button("💾 Save Dataset", variant="primary", visible=False)
+         open_dataset_message = gr.Markdown("", visible=False)
+         dataset_share_button = gr.Button("Share Dataset URL")
+         dataset_share_textbox = gr.Textbox(visible=False, show_copy_button=True, label="Copy this URL:", interactive=False, show_label=True)
+         back_button = gr.Button("< Back", size="sm")
+
+     ###################################
+     #
+     # Utils
+     #
+     ###################################
+
+     T = TypeVar("T")
+
+     def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
+         it = iter(it)
+         while batch := list(islice(it, n)):
+             yield batch
+
+
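+     # Stream chat-completion tokens from the model; previous generations are replayed as chat history ("Can you generate more ?") and connection errors are retried up to 3 times.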
+     def stream_response(msg: str, generated_texts: tuple[str] = (), max_tokens=500) -> Iterator[str]:
+         messages = [
+             {"role": "user", "content": msg}
+         ] + [
+             item
+             for generated_text in generated_texts
+             for item in [
+                 {"role": "assistant", "content": generated_text},
+                 {"role": "user", "content": "Can you generate more ?"},
+             ]
+         ]
+         for _ in range(3):
+             try:
+                 for message in client.chat_completion(
+                     messages=messages,
+                     max_tokens=max_tokens,
+                     stream=True,
+                     top_p=0.8,
+                     seed=42,
+                 ):
+                     yield message.choices[0].delta.content
+             except requests.exceptions.ConnectionError as e:
+                 print(f"{e}\n\nRetrying in 1sec")
+                 time.sleep(1)
+                 continue
+             break
+
+
+     def gen_datasets_line_by_line(search_query: str, generated_texts: tuple[str] = ()) -> Iterator[str]:
+         search_query = search_query or ""
+         search_query = search_query[:1000] if search_query.strip() else ""
+         generated_text = ""
+         current_line = ""
+         for token in stream_response(
+             GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY.format(search_query=search_query),
+             generated_texts=generated_texts,
+         ):
+             current_line += token
+             if current_line.endswith("\n"):
+                 yield current_line
+                 generated_text += current_line
+                 current_line = ""
+         yield current_line
+         generated_text += current_line
+         print("-----\n\n" + generated_text)
+
+
+     def gen_dataset_content(search_query: str, dataset_name: str, tags: str) -> Iterator[str]:
+         search_query = search_query or ""
+         search_query = search_query[:1000] if search_query.strip() else ""
+         generated_text = ""
+         for token in stream_response(GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
+             search_query=search_query,
+             dataset_name=dataset_name,
+             tags=tags,
+         ), max_tokens=1500):
+             generated_text += token
+             yield generated_text
+         print("-----\n\n" + generated_text)
+
+
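+     # Concurrency helpers: each generator runs in a ThreadPool worker and pushes its items onto a shared Queue, which iflatmap_unordered drains in arrival order.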
+     def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
+         for i, result in enumerate(func(**kwargs)):
+             queue.put(result)
+         return None
+
+
+     def iflatmap_unordered(
+         func: Callable[..., Iterable[T]],
+         *,
+         kwargs_iterable: Iterable[dict],
+     ) -> Iterable[T]:
+         queue = Queue()
+         with ThreadPool() as pool:
+             async_results = [
+                 pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable
+             ]
+             try:
+                 while True:
+                     try:
+                         yield queue.get(timeout=0.05)
+                     except Empty:
+                         if all(async_result.ready() for async_result in async_results) and queue.empty():
+                             break
+             finally:
+                 # we get the result in case there's an error to raise
+                 [async_result.get(timeout=0.05) for async_result in async_results]
+
+
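+     # Ask the model for additional CSV rows for one variant prompt; each parsed row is written into its reserved slot of the shared output list, yielding 1 per row as a progress tick.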
+     def generate_partial_dataset(title: str, content: str, search_query: str, variant: str, csv_header: str, output: list[dict[str, str]], indices_to_generate: list[int], max_tokens=1500) -> Iterator[int]:
+         dataset_name, tags = title.strip("# ").split("\ntags:", 1)
+         dataset_name, tags = dataset_name.strip(), tags.strip()
+         messages = [
+             {
+                 "role": "user",
+                 "content": GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
+                     dataset_name=dataset_name,
+                     tags=tags,
+                     search_query=search_query,
+                 )
+             },
+             {"role": "assistant", "content": title + "\n\n" + content},
+             {"role": "user", "content": GENERATE_MORE_ROWS.format(csv_header=csv_header) + " " + variant},
+         ]
+         for _ in range(3):
+             generated_text = ""
+             generated_csv = ""
+             current_line = ""
+             nb_samples = 0
+             _in_csv = False
+             try:
+                 for message in client.chat_completion(
+                     messages=messages,
+                     max_tokens=max_tokens,
+                     stream=True,
+                     top_p=0.8,
+                     seed=42,
+                 ):
+                     if nb_samples >= len(indices_to_generate):
+                         break
+                     current_line += message.choices[0].delta.content
+                     generated_text += message.choices[0].delta.content
+                     if current_line.endswith("\n"):
+                         _in_csv = _in_csv ^ current_line.lstrip().startswith("```")
+                         if current_line.strip() and _in_csv and not current_line.lstrip().startswith("```"):
+                             generated_csv += current_line
+                             try:
+                                 generated_df = parse_csv_df(generated_csv.strip(), csv_header=csv_header)
+                                 if len(generated_df) > nb_samples:
+                                     output[indices_to_generate[nb_samples]] = generated_df.iloc[-1].to_dict()
+                                     nb_samples += 1
+                                     yield 1
+                             except Exception:
+                                 pass
+                         current_line = ""
+             except requests.exceptions.ConnectionError as e:
+                 print(f"{e}\n\nRetrying in 1sec")
+                 time.sleep(1)
+                 continue
+             break
+         # for debugging
+         # with open(f".output{indices_to_generate[0]}.txt", "w") as f:
+         #     f.write(generated_text)
+
+
+     def generate_variants(preview_df: pd.DataFrame):
+         label_candidate_columns = [column for column in preview_df.columns if "label" in column.lower()]
+         if label_candidate_columns:
+             labels = preview_df[label_candidate_columns[0]].unique()
+             if len(labels) > 1:
+                 return [
+                     GENERATE_VARIANTS_WITH_RARITY_AND_LABEL.format(rarity=rarity, label=label)
+                     for rarity in RARITIES
+                     for label in labels
+                 ]
+         return [
+             GENERATE_VARIANTS_WITH_RARITY.format(rarity=rarity)
+             for rarity in LONG_RARITIES
+         ]
+
+
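+     # CSV parsing helpers: parse_preview_df pulls the CSV lines out of the model's fenced code block; parse_csv_df repairs unquoted lists and missing headers before pandas reads it.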
+     def parse_preview_df(content: str) -> tuple[str, pd.DataFrame]:
+         _in_csv = False
+         csv = "\n".join(
+             line for line in content.split("\n") if line.strip()
+             and (_in_csv := (_in_csv ^ line.lstrip().startswith("```")))
+             and not line.lstrip().startswith("```")
+         )
+         if not csv:
+             raise gr.Error("Failed to parse CSV Preview")
+         return csv.split("\n")[0], parse_csv_df(csv)

+
+     def parse_csv_df(csv: str, csv_header: Optional[str] = None) -> pd.DataFrame:
+         # Fix generation mistake when providing a list that is not in quotes
+         for match in re.finditer(r'''(?!")\[(["'][\w ]+["'][, ]*)+\](?!")''', csv):
+             span = match.string[match.start() : match.end()]
+             csv = csv.replace(span, '"' + span.replace('"', "'") + '"', 1)
+         # Add header if missing
+         if csv_header and csv.strip().split("\n")[0] != csv_header:
+             csv = csv_header + "\n" + csv
+         # Read CSV
+         df = pd.read_csv(io.StringIO(csv), skipinitialspace=True)
+         return df
+
+
+     ###################################
+     #
+     # Buttons
+     #
+     ###################################
+
+
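+     # Reset all result buttons to shimmer placeholders, then fill them in pairs (name button + tags button) as numbered "Name (tags)" lines stream in.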
+     def _search_datasets(search_query):
+         yield {generated_texts_state: []}
+         yield {
+             button_group: gr.Group(elem_classes="buttonsGroup invisibleButtonGroup")
+             for button_group in button_groups[MAX_NB_ITEMS_PER_GENERATION_CALL:]
+         }
+         yield {
+             k: v
+             for dataset_name_button, tags_button in batched(buttons, 2)
+             for k, v in {
+                 dataset_name_button: gr.Button("⬜⬜⬜⬜⬜⬜", elem_classes="topButton linear-background"),
+                 tags_button: gr.Button("░░░░, ░░░░, ░░░░", elem_classes="bottomButton linear-background")
+             }.items()
+         }
+         current_item_idx = 0
+         generated_text = ""
+         for line in gen_datasets_line_by_line(search_query):
+             if "I'm sorry" in line or "against Microsoft's use case policy" in line:
+                 raise gr.Error("Error: inappropriate content")
+             if current_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
+                 return
+             if line.strip() and line.strip().split(".", 1)[0].isnumeric():
+                 try:
+                     dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
+                 except ValueError:
+                     dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)[0], ""
+                 dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
+                 generated_text += line
+                 yield {
+                     buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
+                     buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
+                     generated_texts_state: (generated_text,),
+                 }
+                 current_item_idx += 1
+
+
+     @search_button.click(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state])
+     def search_dataset_from_search_button(search_query):
+         yield from _search_datasets(search_query)
+
+
+     @search_bar.submit(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state])
+     def search_dataset_from_search_bar(search_query):
+         yield from _search_datasets(search_query)
+
+
+     @load_more_datasets.click(inputs=[search_bar, generated_texts_state], outputs=button_groups + buttons + [generated_texts_state])
+     def search_more_datasets(search_query, generated_texts):
+         current_item_idx = initial_item_idx = len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL
+         yield {
+             button_group: gr.Group(elem_classes="buttonsGroup")
+             for button_group in button_groups[len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL:(len(generated_texts) + 1) * MAX_NB_ITEMS_PER_GENERATION_CALL]
+         }
+         generated_text = ""
+         for line in gen_datasets_line_by_line(search_query, generated_texts=generated_texts):
+             if "I'm sorry" in line or "against Microsoft's use case policy" in line:
+                 raise gr.Error("Error: inappropriate content")
+             if current_item_idx - initial_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
+                 return
+             if line.strip() and line.strip().split(".", 1)[0].isnumeric():
+                 try:
+                     dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
+                 except ValueError:
+                     dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)[0], ""
+                 dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
+                 generated_text += line
+                 yield {
+                     buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
+                     buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
+                     generated_texts_state: (*generated_texts, generated_text),
+                 }
+                 current_item_idx += 1
+
+     def _show_dataset(search_query, dataset_name, tags):
+         yield {
+             search_page: gr.Column(visible=False),
+             dataset_page: gr.Column(visible=True),
+             dataset_title: f"# {dataset_name}\n\n tags: {tags}",
+             dataset_share_textbox: gr.Textbox(visible=False),
+             dataset_dataframe: gr.DataFrame(visible=False),
+             generate_full_dataset_button: gr.Button(interactive=True),
+             save_dataset_button: gr.Button(visible=False),
+             open_dataset_message: gr.Markdown(visible=False)
+         }
+         for generated_text in gen_dataset_content(search_query=search_query, dataset_name=dataset_name, tags=tags):
+             yield {dataset_content: generated_text}
+
+
+     show_dataset_inputs = [search_bar, *buttons]
+     show_dataset_outputs = [search_page, dataset_page, dataset_title, dataset_content, generate_full_dataset_button, dataset_dataframe, save_dataset_button, open_dataset_message, dataset_share_textbox]
+     scroll_to_top_js = """
+     function (...args) {
+         console.log(args);
+         if ('parentIFrame' in window) {
+             window.parentIFrame.scrollTo({top: 0, behavior:'smooth'});
+         } else {
+             window.scrollTo({ top: 0 });
+         }
+         return args;
+     }
+     """
+
+     def show_dataset_from_button(search_query, *buttons_values, i):
+         dataset_name, tags = buttons_values[2 * i : 2 * i + 2]
+         yield from _show_dataset(search_query, dataset_name, tags)

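+     # Both buttons of a pair open the same dataset page; partial(show_dataset_from_button, i=i) tells the handler which (name, tags) pair was clicked.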
+     for i, (dataset_name_button, tags_button) in enumerate(batched(buttons, 2)):
+         dataset_name_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)
+         tags_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)
+
+
+     @back_button.click(outputs=[search_page, dataset_page], js=scroll_to_top_js)
+     def show_search_page():
+         return gr.Column(visible=True), gr.Column(visible=False)
+
+
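+     # Parse the 5-row preview, drop auto-increment id-like columns, then fan out NUM_VARIANTS concurrent generation calls; variant i fills the strided row indices len(preview_df)+i, +NUM_VARIANTS, +2*NUM_VARIANTS, ... so the table fills evenly as rows stream in.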
+     @generate_full_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar, select_namespace_dropdown, visibility_radio], outputs=[dataset_dataframe, generate_full_dataset_button, save_dataset_button])
+     def generate_full_dataset(title, content, search_query, namespace, visibility):
+         dataset_name, tags = title.strip("# ").split("\ntags:", 1)
+         dataset_name, tags = dataset_name.strip(), tags.strip()
+         csv_header, preview_df = parse_preview_df(content)
+         # Remove dummy "id" columns
+         for column_name, values in preview_df.to_dict(orient="series").items():
+             try:
+                 if [int(v) for v in values] == list(range(len(preview_df))):
+                     preview_df = preview_df.drop(columns=column_name)
+                 if [int(v) for v in values] == list(range(1, len(preview_df) + 1)):
+                     preview_df = preview_df.drop(columns=column_name)
+             except Exception:
+                 pass
+         columns = list(preview_df)
+         output: list[Optional[dict]] = [None] * NUM_ROWS
+         output[:len(preview_df)] = [{"idx": i, **x} for i, x in enumerate(preview_df.to_dict(orient="records"))]
+         yield {
+             dataset_dataframe: gr.DataFrame(pd.DataFrame([{"idx": i, **x} for i, x in enumerate(output) if x]), visible=True),
+             generate_full_dataset_button: gr.Button(interactive=False),
+             save_dataset_button: gr.Button(f"💾 Save Dataset {namespace}/{dataset_name}" + (" (private)" if visibility != "public" else ""), visible=True, interactive=False)
+         }
+         kwargs_iterable = [
+             {
+                 "title": title,
+                 "content": content,
+                 "search_query": search_query,
+                 "variant": variant,
+                 "csv_header": csv_header,
+                 "output": output,
+                 "indices_to_generate": list(range(len(preview_df) + i, NUM_ROWS, NUM_VARIANTS)),
+             }
+             for i, variant in enumerate(islice(generate_variants(preview_df), NUM_VARIANTS))
+         ]
+         for _ in iflatmap_unordered(generate_partial_dataset, kwargs_iterable=kwargs_iterable):
+             yield {dataset_dataframe: pd.DataFrame([{"idx": i, **{column_name: x.get(column_name) for column_name in columns}} for i, x in enumerate(output) if x])}
+         yield {save_dataset_button: gr.Button(interactive=True)}
+         print(f"Generated {dataset_name}!")
+
+
+     @save_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar, dataset_dataframe, select_namespace_dropdown, visibility_radio], outputs=[save_dataset_button, open_dataset_message])
+     def save_dataset(title: str, content: str, search_query: str, df: pd.DataFrame, namespace: str, visibility: str, oauth_token: Optional[gr.OAuthToken]):
+         dataset_name, tags = title.strip("# ").split("\ntags:", 1)
+         dataset_name, tags = dataset_name.strip(), tags.strip()
+         token = oauth_token.token if oauth_token else save_dataset_hf_token
+         repo_id = f"{namespace}/{dataset_name}"
+         dataset_url = f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}"
+         gr.Info("Saving dataset...")
+         yield {save_dataset_button: gr.Button(interactive=False)}
+         create_repo(repo_id=repo_id, repo_type="dataset", private=visibility != "public", exist_ok=True, token=token)
+         df.to_csv(f"hf://datasets/{repo_id}/data.csv", storage_options={"token": token}, index=False)
+         DatasetCard(DATASET_CARD_CONTENT.format(title=title, content=content, url=URL, dataset_url=dataset_url, model_id=model_id, search_query=search_query)).push_to_hub(repo_id=repo_id, repo_type="dataset", token=token)
+         gr.Info(f"✅ Dataset saved at {repo_id}")
+         additional_message = "PS: You can also save datasets under your account in the Settings ;)"
+         yield {open_dataset_message: gr.Markdown(f"# 🎉 Yay ! Your dataset has been saved to [{repo_id}](https://huggingface.co/datasets/{repo_id}) !\n\nDataset link: [https://huggingface.co/datasets/{repo_id}](https://huggingface.co/datasets/{repo_id})\n\n{additional_message}", visible=True)}
+         print(f"Saved {dataset_name}!")
+
+
+     @dataset_share_button.click(inputs=[dataset_title, search_bar], outputs=[dataset_share_textbox])
+     def show_dataset_url(title, search_query):
+         dataset_name, tags = title.strip("# ").split("\ntags:", 1)
+         dataset_name, tags = dataset_name.strip(), tags.strip()
+         return gr.Textbox(
+             f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}",
+             visible=True,
+         )
+
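+     # On page load: unlock namespace/visibility settings for logged-in users, then restore state from URL query params (?dataset=... opens a dataset page, ?q=... replays a search).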
+     @demo.load(outputs=show_dataset_outputs + button_groups + buttons + [generated_texts_state] + [select_namespace_dropdown, visibility_radio])
+     def load_app(request: gr.Request, oauth_token: Optional[gr.OAuthToken]):
+         if oauth_token:
+             user_info = whoami(oauth_token.token)
+             yield {
+                 select_namespace_dropdown: gr.Dropdown(
+                     choices=[user_info["name"]] + [org_info["name"] for org_info in user_info["orgs"]],
+                     value=user_info["name"],
+                     visible=True,
+                 ),
+                 visibility_radio: gr.Radio(interactive=True),
+             }
+         query_params = dict(request.query_params)
+         if "dataset" in query_params:
+             yield from _show_dataset(
+                 search_query=query_params.get("q", query_params["dataset"]),
+                 dataset_name=query_params["dataset"],
+                 tags=query_params.get("tags", "")
              )
+         elif "q" in query_params:
+             yield {search_bar: query_params["q"]}
+             yield from _search_datasets(query_params["q"])
+         else:
+             yield {search_page: gr.Column(visible=True)}
+
+
+ demo.launch()
+