import json import os import csv #用来计算数据集中不同问题种类对应的pass@k的平均值 input_dir = 'E:\python-testn\pythonProject3\hh_1\evaluate_result' # 获取目录中的所有文件 files = os.listdir(input_dir) with open("cata_result.csv", "w", newline='') as csvfile: writer = csv.writer(csvfile) writer.writerow(["Model", "String", "Math","Array","Sorting","Hash Table","Stack","Search","Matrix"]) for file_name in files: # 构建完整的文件路径 input_file_path = os.path.join(input_dir, file_name) first_underscore_index = file_name.find('_') # 找到最后一个 - 的位置 last_dash_index = file_name.rfind('-') model_name = file_name[first_underscore_index + 1:last_dash_index] print(model_name) with open(input_file_path, "r", encoding="utf-8") as file: data1 = json.load(file) with open("humaneval_with_cata.json","r",encoding="utf-8") as file: data2=json.load(file) sum0=0 count0=0 sum1=0 count1=0 sum2=0 count2=0 sum3=0 count3=0 sum4=0 count4=0 sum5=0 count5=0 sum6=0 count6=0 sum7=0 count7=0 for (item1,item2) in zip(data1["humaneval"]["pass@1"],data2): if "String" in item2["answer"]: index, value = item1 sum0=sum0+value count0=count0+1 if "Math" in item2["answer"]: index, value = item1 sum1=sum1+value count1=count1+1 if "Array" in item2["answer"]: index, value = item1 sum2=sum2+value count2=count2+1 if "Sorting" in item2["answer"]: index, value = item1 sum3=sum3+value count3=count3+1 if "Hash table" in item2["answer"]: index, value = item1 sum4 = sum4 + value count4 = count4 + 1 if "Stack" in item2["answer"]: index, value = item1 sum5=sum5+value count5=count5+1 if "Search" in item2["answer"]: index, value = item1 sum6=sum6+value count6=count6+1 if "Matrix" in item2["answer"]: index, value = item1 sum7=sum7+value count7=count7+1 mean0=round(sum0/count0*100,2) mean1=round(sum1/count1*100,2) mean2=round(sum2/count2*100,2) mean3=round(sum3/count3*100,2) mean4=round(sum4/count4*100,2) mean5=round(sum5/count5*100,2) mean6=round(sum6/count6*100,2) mean7=round(sum7/count7*100,2) print(count0,count1,count2,count3,count4,count5,count6,count7) print(mean0,mean1,mean2,mean3,mean4,mean5,mean6,mean7) with open("cata_result.csv", mode='a', newline='', encoding='utf-8') as file: writer = csv.writer(file) writer.writerow([model_name,mean0,mean1,mean2,mean3,mean4,mean5,mean6,mean7])