|
import json |
|
from datasets import Dataset |
|
|
|
|
|
def prepare_data(file_path): |
|
with open(file_path, 'r', encoding='utf-8') as f: |
|
data = json.load(f) |
|
|
|
processed_data = [] |
|
for query in data['queries']: |
|
words = query['text'].split() |
|
labels = ['O'] * len(words) |
|
for start, end, entity_type, entity_text in query['entities']: |
|
entity_words = entity_text.lower().split() |
|
found = False |
|
for i in range(len(words) - len(entity_words) + 1): |
|
if [w.lower() for w in words[i:i + len(entity_words)]] == entity_words: |
|
for j, word in enumerate(words[i:i + len(entity_words)]): |
|
labels[i + j] = f'B-{entity_type}' if j == 0 else f'I-{entity_type}' |
|
found = True |
|
break |
|
if not found: |
|
print(f"Warning: Entity '{entity_text}' not found in text '{query['text']}'") |
|
processed_data.append({'words': words, 'labels': labels}) |
|
|
|
return Dataset.from_list(processed_data) |
|
|
|
|
|
import os  # local to this script section; only needed for the path override

# Directory holding train_dataset.json and eval_dataset.json. Set the
# TRIPGO_DATA_DIR environment variable to point elsewhere instead of editing
# the machine-specific default below.
DATA_DIR = os.environ.get(
    'TRIPGO_DATA_DIR', '/home/ebk/PycharmProjects/pythonProject/tripgo-hotel'
)

train_dataset = prepare_data(os.path.join(DATA_DIR, 'train_dataset.json'))
eval_dataset = prepare_data(os.path.join(DATA_DIR, 'eval_dataset.json'))
|
|
|
|