Spaces:
Paused
Paused
import json

import pandas as pd
# Source CSV files, one per medical department, that are converted below.
data_path = [
    "./data/Chinese-medical-dialogue-data-master/Data_数据/IM_内科/内科5000-33000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Oncology_肿瘤科/肿瘤科5-10000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Pediatric_儿科/儿科5-14000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Surgical_外科/外科5-14000.csv",
]

# Output files (one JSON object per line).
train_json_path = "./data/train.json"
val_json_path = "./data/val.json"

# Take 10000 rows from each source file for training.
train_size = 10000
# Take 2000 rows from each source file for validation.
val_size = 2000
def main(csv_paths=None, train_path=None, val_path=None,
         train_limit=None, val_limit=None, encoding="gb18030"):
    """Convert the dialogue CSV files into JSON-lines train/validation files.

    For every CSV file, the leading ``train_limit`` rows are written to the
    training file and the following ``val_limit`` rows to the validation
    file; any remaining rows are ignored.  Each output line is a JSON object
    with the keys ``"question"`` (taken from the ``ask`` column) and
    ``"answer"`` (taken from the ``answer`` column).

    Both output files are opened in append mode, so running the conversion
    again adds to whatever the files already contain.

    Args:
        csv_paths: Iterable of CSV file paths.  Defaults to the module-level
            ``data_path``.
        train_path: Training output file.  Defaults to ``train_json_path``.
        val_path: Validation output file.  Defaults to ``val_json_path``.
        train_limit: Rows per CSV file written to the training file.
            Defaults to ``train_size``.
        val_limit: Rows per CSV file written to the validation file.
            Defaults to ``val_size``.
        encoding: Text encoding used to read the CSV files.  The default
            replaces the codec name ``'ANSI'``, which Python only accepts on
            Windows.  NOTE(review): assumes the source files use a
            GBK-family encoding -- confirm against the dataset.
    """
    # Resolve defaults lazily so a bare ``main()`` keeps using the
    # module-level settings.
    if csv_paths is None:
        csv_paths = data_path
    if train_path is None:
        train_path = train_json_path
    if val_path is None:
        val_path = val_json_path
    if train_limit is None:
        train_limit = train_size
    if val_limit is None:
        val_limit = val_size

    # Row positions [0, train_limit) go to training; the next ``val_limit``
    # rows go to validation.
    val_end = max(train_limit, 0) + max(val_limit, 0)

    # The context manager closes both files even if reading or serialising
    # a row raises.
    with open(train_path, "a", encoding="utf-8") as train_f, \
            open(val_path, "a", encoding="utf-8") as val_f:
        for path in csv_paths:
            data = pd.read_csv(path, encoding=encoding)
            pairs = zip(data["ask"], data["answer"])
            for position, (question, answer) in enumerate(pairs):
                if position < train_limit:
                    target = train_f
                elif position < val_end:
                    target = val_f
                else:
                    # Both quotas for this file are filled.
                    break
                record = {"question": question, "answer": answer}
                # ensure_ascii=False keeps the Chinese text readable.
                target.write(json.dumps(record, ensure_ascii=False) + "\n")

    print("数据处理完毕!")
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()