from datasets import load_dataset, Dataset
import json
import csv
import openai
import anthropic
import requests
import os
import logging
from tqdm import tqdm
import time
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def prepare_dataset(dataset_source, dataset_path, tokenizer, hf_token=None):
"""
Prepare a dataset for fine-tuning, either from Hugging Face or a local file.
Args:
dataset_source (str): 'huggingface' or 'local'
dataset_path (str): Path or identifier of the dataset
tokenizer: The tokenizer associated with the model
hf_token (str, optional): Hugging Face token for accessing datasets
Returns:
Dataset: Prepared dataset ready for fine-tuning
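    Example (illustrative sketch only; the tokenizer repo id and file path below are
    placeholders and assume a tokenizer that ships a chat template):
        >>> from transformers import AutoTokenizer
        >>> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
        >>> dataset = prepare_dataset("local", "train.json", tokenizer)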
"""
if dataset_source == 'huggingface':
try:
            dataset = load_dataset(dataset_path, split="train", token=hf_token)
        except (TypeError, ValueError):
            # Older versions of `datasets` expect `use_auth_token` instead of `token`
            dataset = load_dataset(dataset_path, split="train", use_auth_token=hf_token)
elif dataset_source == 'local':
if not os.path.exists(dataset_path):
raise FileNotFoundError(f"File not found: {dataset_path}")
if dataset_path.endswith('.json'):
with open(dataset_path, 'r') as f:
data = json.load(f)
if isinstance(data, list):
dataset = Dataset.from_list(data)
elif isinstance(data, dict):
dataset = Dataset.from_dict(data)
else:
raise ValueError("JSON file must contain either a list or a dictionary.")
elif dataset_path.endswith('.csv'):
with open(dataset_path, 'r') as f:
reader = csv.DictReader(f)
data = list(reader)
dataset = Dataset.from_list(data)
else:
raise ValueError("Unsupported file format. Please use JSON or CSV.")
else:
raise ValueError("Invalid dataset source. Use 'huggingface' or 'local'.")
# Check if 'conversations' column exists, if not, try to create it
if 'conversations' not in dataset.column_names:
if 'text' in dataset.column_names:
dataset = dataset.map(lambda example: {'conversations': [{'human': example['text'], 'assistant': ''}]})
else:
raise ValueError("Dataset does not contain 'conversations' or 'text' column. Please check your dataset structure.")
# Only apply standardize_sharegpt if 'conversations' column exists
if 'conversations' in dataset.column_names:
dataset = standardize_sharegpt(dataset)
def formatting_prompts_func(examples):
if tokenizer is None:
raise ValueError("Tokenizer is not properly initialized. Please load the model and tokenizer before preparing the dataset.")
convos = examples["conversations"]
texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
return {"text": texts}
dataset = dataset.map(formatting_prompts_func, batched=True)
if 'text' not in dataset.column_names:
def format_conversation(example):
formatted_text = ""
for turn in example['conversations']:
formatted_text += f"{turn['role']}: {turn['content']}\n"
return {"text": formatted_text.strip()}
dataset = dataset.map(format_conversation)
return dataset
def standardize_sharegpt(dataset):
# This is a simplified version. You might need to adjust it based on your specific needs.
def process_conversation(conversation):
standardized = []
for turn in conversation:
if 'human' in turn:
standardized.append({'role': 'user', 'content': turn['human']})
if 'assistant' in turn:
standardized.append({'role': 'assistant', 'content': turn['assistant']})
return standardized
return dataset.map(lambda x: {'conversations': process_conversation(x['conversations'])})
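# Illustration of the mapping performed by standardize_sharegpt (shape only; the
# strings below are made-up placeholders, not project data):
#   [{"human": "What is the capital of France?", "assistant": "Paris."}]
#     -> [{"role": "user", "content": "What is the capital of France?"},
#         {"role": "assistant", "content": "Paris."}]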
def create_synthetic_dataset(examples, expected_structure, num_samples, ai_provider, api_key, model_name=None):
"""
Create a synthetic dataset based on example conversations and expected structure.
Args:
examples (str): Example conversations to base the synthetic data on.
expected_structure (str): Description of the expected dataset structure.
num_samples (int): Number of synthetic samples to generate.
ai_provider (str): AI provider to use for generation ('OpenAI', 'Anthropic', or 'Ollama').
api_key (str): API key for the chosen AI provider.
model_name (str, optional): Model name for Ollama (if applicable).
Returns:
Dataset: Synthetic dataset ready for fine-tuning.
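    Example (illustrative only; the provider, API key, and strings are placeholders):
        >>> ds = create_synthetic_dataset(
        ...     examples='[{"human": "Hi", "assistant": "Hello!"}]',
        ...     expected_structure="A JSON list of turns with 'human' and 'assistant' keys.",
        ...     num_samples=2,
        ...     ai_provider="OpenAI",
        ...     api_key="sk-...",
        ... )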
"""
synthetic_data = []
prompt = f"""
You are an AI assistant creating training dataset for finetuning a model.
You are provided an one-shot or few-shot output example of output that application expects from the AI model. You are also provided the
expected structure that the to-be trained AI model expects during training process.
Examples:
{examples}
Expected structure:
{expected_structure}
Please help Generate a new dataset in the provided same style and expected structure. Do not produce any extra output except the dataset in the training needed structure:
"""
if ai_provider == "OpenAI":
client = openai.OpenAI(api_key=api_key)
for _ in tqdm(range(num_samples), desc="Generating samples"):
try:
response = client.chat.completions.create(
model="gpt-4-0125-preview",
messages=[{"role": "user", "content": prompt}],
timeout=30 # 30 seconds timeout
)
conversation = response.choices[0].message.content
synthetic_data.append({"conversations": json.loads(conversation)})
except json.JSONDecodeError:
logger.warning(f"Failed to decode response as JSON: {response.choices[0].message.content}")
except openai.APITimeoutError:
logger.warning("OpenAI API request timed out")
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
time.sleep(1) # Rate limiting
elif ai_provider == "Anthropic":
client = anthropic.Anthropic(api_key=api_key)
for _ in tqdm(range(num_samples), desc="Generating samples"):
try:
                # Claude 3 models require the Messages API rather than the legacy Completions API
                response = client.messages.create(
                    model="claude-3-opus-20240229",
                    max_tokens=1000,
                    messages=[{"role": "user", "content": prompt}],
                    timeout=30  # 30 seconds timeout
                )
                completion_text = response.content[0].text
                synthetic_data.append({"conversations": json.loads(completion_text)})
            except json.JSONDecodeError:
                logger.warning(f"Failed to decode response as JSON: {completion_text}")
except anthropic.APITimeoutError:
logger.warning("Anthropic API request timed out")
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
time.sleep(1) # Rate limiting
elif ai_provider == "Ollama":
for _ in tqdm(range(num_samples), desc="Generating samples"):
try:
response = requests.post('http://localhost:11434/api/generate',
json={
"model": model_name,
"prompt": prompt,
"stream": False
},
timeout=30) # 30 seconds timeout
response.raise_for_status()
synthetic_data.append({"conversations": json.loads(response.json()["response"])})
except json.JSONDecodeError:
logger.warning(f"Failed to decode response as JSON: {response.json()['response']}")
except requests.Timeout:
logger.warning("Ollama API request timed out")
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
time.sleep(1) # Rate limiting
    # Guard against an empty result (e.g. every generation failed to parse as JSON)
    if not synthetic_data:
        raise ValueError("No synthetic samples were generated; check the provider settings and model output.")
    dataset = Dataset.from_list(synthetic_data)
dataset = standardize_sharegpt(dataset)
if 'text' not in dataset.column_names:
def format_conversation(example):
formatted_text = ""
for turn in example['conversations']:
formatted_text += f"{turn['role']}: {turn['content']}\n"
return {"text": formatted_text.strip()}
dataset = dataset.map(format_conversation)
    return dataset
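

# Minimal manual smoke test (illustrative sketch, not part of the original pipeline;
# it assumes an OpenAI key in the OPENAI_API_KEY environment variable and that the
# example/expected-structure strings below match what your fine-tuning setup needs).
if __name__ == "__main__":
    example_convo = json.dumps([
        {"human": "What is the capital of France?", "assistant": "Paris."}
    ])
    structure = "A JSON list of turns, each with 'human' and 'assistant' keys."
    synthetic = create_synthetic_dataset(
        examples=example_convo,
        expected_structure=structure,
        num_samples=2,
        ai_provider="OpenAI",
        api_key=os.environ.get("OPENAI_API_KEY", ""),
    )
    print(synthetic)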