File size: 2,987 Bytes
dfe37be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import json
import os
import csv
# Compute, for every evaluated model, the mean pass@1 score (as a percentage)
# of each problem category in the dataset and write one CSV row per model.

# Raw string: the Windows path contains backslashes that must not be parsed
# as escape sequences (the value is identical to the original literal).
input_dir = r'E:\python-testn\pythonProject3\hh_1\evaluate_result'

# JSON file holding one entry per problem; each entry's "answer" field is
# searched for the category labels below.
CATEGORY_FILE = "humaneval_with_cata.json"

# CSV file the per-model results are written to (created in the working dir).
OUTPUT_CSV = "cata_result.csv"

# (CSV column header, text looked up in a problem's "answer" field).
# NOTE(review): the "Hash Table" column is matched with the spelling
# "Hash table" (lowercase t), exactly as before -- confirm against the data.
CATEGORIES = (
    ("String", "String"),
    ("Math", "Math"),
    ("Array", "Array"),
    ("Sorting", "Sorting"),
    ("Hash Table", "Hash table"),
    ("Stack", "Stack"),
    ("Search", "Search"),
    ("Matrix", "Matrix"),
)


def extract_model_name(file_name):
    """Return the model name embedded in a result file name.

    The model name is the text after the first ``_`` and before the last
    ``-``.  When the name has no ``_`` the result starts at the beginning of
    the name; when it has no ``-`` the result runs to the end of the name
    (previously the last character was silently dropped in that case).
    """
    start = file_name.find('_') + 1
    end = file_name.rfind('-')
    if end == -1:
        end = len(file_name)
    return file_name[start:end]


def category_means(scores, problems, categories=CATEGORIES):
    """Aggregate per-problem scores into per-category counts and means.

    Args:
        scores: iterable of ``(index, value)`` pairs, one per problem.
            Presumably ``value`` is a pass@1 fraction in [0, 1] -- verify
            against the evaluation output.
        problems: iterable of mappings with an ``"answer"`` entry, paired
            positionally with ``scores`` (extra items on either side are
            ignored, as ``zip`` stops at the shorter input).
        categories: sequence of ``(header, label)`` pairs; a problem counts
            towards a category when ``label in problem["answer"]``.  A
            problem may count towards several categories.

    Returns:
        ``(counts, means)``: two lists aligned with ``categories``.  Each
        mean is ``round(total / count * 100, 2)``, or ``None`` when no
        problem matched the category (instead of raising ZeroDivisionError).
    """
    totals = [0] * len(categories)
    counts = [0] * len(categories)
    for (_, value), problem in zip(scores, problems):
        answer = problem["answer"]
        for position, (_, label) in enumerate(categories):
            if label in answer:
                totals[position] += value
                counts[position] += 1
    means = [
        round(total / count * 100, 2) if count else None
        for total, count in zip(totals, counts)
    ]
    return counts, means


def main():
    """Evaluate every result file in ``input_dir`` and write ``OUTPUT_CSV``."""
    # Sorted so the CSV row order does not depend on the directory listing.
    file_names = sorted(os.listdir(input_dir))

    # The category annotations are the same for every model: load them once.
    with open(CATEGORY_FILE, "r", encoding="utf-8") as file:
        problems = json.load(file)

    # A single handle (and a single encoding) for the header and all rows.
    with open(OUTPUT_CSV, "w", newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Model"] + [header for header, _ in CATEGORIES])

        for file_name in file_names:
            input_file_path = os.path.join(input_dir, file_name)
            # Sub-directories and other non-file entries cannot be parsed.
            if not os.path.isfile(input_file_path):
                continue

            model_name = extract_model_name(file_name)
            print(model_name)
            with open(input_file_path, "r", encoding="utf-8") as file:
                results = json.load(file)

            counts, means = category_means(
                results["humaneval"]["pass@1"], problems
            )
            print(*counts)
            print(*means)
            # csv.writer renders None (category without problems) as "".
            writer.writerow([model_name] + means)


if __name__ == "__main__":
    main()