Spaces:
Running
Running
Commit
·
05dae02
1
Parent(s):
da6fc76
group by model
Browse files
app.py
CHANGED
@@ -61,9 +61,12 @@ def response(num_responses, model, correct, prompt_ids):
|
|
61 |
responses = responses[responses['correct'].isin(correct)]
|
62 |
if prompt_ids:
|
63 |
responses = responses[responses['prompt'].isin(prompt_ids)]
|
64 |
-
if num_responses > len(responses):
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
67 |
|
68 |
|
69 |
def barplot_for_prompt_id(prompt_ids, models):
|
|
|
61 |
responses = responses[responses['correct'].isin(correct)]
|
62 |
if prompt_ids:
|
63 |
responses = responses[responses['prompt'].isin(prompt_ids)]
|
64 |
+
# if num_responses > len(responses):
|
65 |
+
# num_responses = len(responses)
|
66 |
+
|
67 |
+
# sample num_responses for each model
|
68 |
+
responses = responses.groupby('model').apply(lambda x: x.sample(num_responses) if num_responses < len(x) else x).reset_index()
|
69 |
+
return responses[['model', 'prompt', 'model_response', 'correct']]
|
70 |
|
71 |
|
72 |
def barplot_for_prompt_id(prompt_ids, models):
|