File size: 7,290 Bytes
e1af10a
c0035d4
49b4797
e1af10a
 
c0035d4
f57196e
 
 
f08229e
 
f57196e
1d021ff
 
 
06ae167
1d021ff
 
 
 
 
 
 
 
 
 
 
06ae167
 
1d021ff
 
e1af10a
1d021ff
 
 
 
 
 
e1af10a
 
 
 
 
 
 
 
 
 
 
1d021ff
e1af10a
1d021ff
 
 
 
 
 
06ae167
e1af10a
1d021ff
 
e1af10a
 
1d021ff
 
 
 
 
 
e1af10a
 
 
 
 
 
 
 
 
 
 
1d021ff
 
 
 
 
06ae167
e1af10a
06ae167
 
 
 
 
 
 
 
 
 
 
 
 
e1af10a
 
 
 
 
 
 
 
 
 
 
 
06ae167
 
1d021ff
f57196e
49b4797
2af82d9
 
49b4797
 
9b462b6
2af82d9
 
49b4797
 
 
 
2af82d9
 
 
49b4797
 
2af82d9
 
 
49b4797
 
2af82d9
 
 
49b4797
 
2af82d9
 
 
 
 
 
 
 
f57196e
2af82d9
 
 
 
 
031ca4e
 
cf32cd6
2af82d9
 
 
49b4797
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import torch
import gradio as gr
import pandas as pd
from utils import save_dataframe_to_file, tokenize_Df
from model import load_model

def _read_utf8(path):
    """Return the entire contents of the text file at *path*, decoded as UTF-8."""
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()


# Markdown description passed to the Gradio interface.
description_text = _read_utf8("./description.md")
# Sample input used as the text box's initial value.
demo = _read_utf8("./input_demo.txt")

# Column layout expected by the scoring pipeline.
_COLUMNS = ["prompt", "response"]
# Upload extensions accepted by process_data (compared case-insensitively).
_SUPPORTED_EXTENSIONS = (".csv", ".xlsx")
_ONE_PHASE_MODEL = "One-phase Fine-tuned BERT"
_NO_INPUT_MSG = "No valid input detected. Please check your input and ensure it follows the expected format."
_BAD_FORMAT_MSG = "File format must be xlsx or csv."
_UNSUPPORTED_TASK_MSG = "One-phase Fine-tuned BERT model does not support Appropriateness task."
_BOTH_INPUTS_MSG = "Detected both text and file input. Prioritizing file input. "
_NO_TEXT_ROWS_MSG = (
    "No 'prompt,response' lines found in the text input. "
    "Each line must contain a prompt and a response separated by a comma."
)


def _is_unsupported_combination(task_name, model_name):
    """Return True for the task/model pair the one-phase model cannot score."""
    return task_name == "Appropriateness" and model_name == _ONE_PHASE_MODEL


def _file_path(file):
    """Return the filesystem path carried by an uploaded-file value.

    The upload component is configured with ``type="filepath"``, so the value
    may be a plain path string; objects exposing the path via ``.name`` are
    accepted as well.
    """
    return str(getattr(file, "name", file))


def _read_prompt_response_file(path):
    """Load an uploaded CSV/XLSX file as a ``prompt``/``response`` DataFrame.

    The file is read without a header so that a first row holding data (rather
    than a ``prompt,response`` header) is kept verbatim instead of being turned
    into, and possibly mangled as, column names.

    Returns:
        ``(dataframe, None)`` on success, or ``(None, message)`` describing why
        the file cannot be scored.
    """
    reader = pd.read_csv if path.lower().endswith(".csv") else pd.read_excel
    try:
        raw = reader(path, header=None)
    except pd.errors.EmptyDataError:
        return None, "Uploaded file is empty."
    if raw.shape[1] != len(_COLUMNS):
        return None, "File must contain exactly two columns: prompt and response."
    first_row = [str(value).strip() for value in raw.iloc[0]] if len(raw) else []
    if first_row == _COLUMNS:
        # Drop the header row; everything after it is data.
        raw = raw.iloc[1:]
    if raw.empty:
        return None, "Uploaded file contains no prompt/response rows."
    frame = raw.reset_index(drop=True)
    frame.columns = list(_COLUMNS)
    return frame, None


def _parse_text_rows(input_text):
    """Turn ``prompt,response`` lines from the text box into a DataFrame.

    Each line is split on its first comma only, so responses may themselves
    contain commas; lines without a comma are skipped. A leading
    ``prompt,response`` line is treated as a header and dropped. The result
    may be empty.
    """
    rows = [
        parts
        for parts in (line.split(",", maxsplit=1) for line in input_text.strip().splitlines())
        if len(parts) == 2
    ]
    if rows and [part.strip() for part in rows[0]] == _COLUMNS:
        rows = rows[1:]
    return pd.DataFrame(rows, columns=_COLUMNS)


def _score_dataframe(dataframe, task_name, model_name, pooling_method):
    """Run the selected model over *dataframe* and store an ``evaluation`` column.

    Mutates and returns *dataframe*.
    """
    loaded_net = load_model(model_name, pooling_method)
    example = tokenize_Df(dataframe)
    with torch.no_grad():  # inference only; no autograd graph needed
        score = loaded_net(example)
    if model_name == _ONE_PHASE_MODEL:
        # The one-phase output is used directly (presumably a single score tensor).
        dataframe["evaluation"] = score.numpy()
    else:
        # The two-phase output is indexed per task: 0 -> Creativity, 1 -> Appropriateness.
        dataframe["evaluation"] = score[0].numpy() if task_name == "Creativity" else score[1].numpy()
    return dataframe


def process_data(task_name, model_name, pooling_method, input_text=None, file=None):
    """Score prompt/response pairs from an uploaded file or pasted text.

    An uploaded file takes priority over the text box. Text input is one
    ``prompt,response`` pair per line; file input is a two-column CSV/XLSX,
    optionally starting with a ``prompt,response`` header row.

    Args:
        task_name: "Creativity" or "Appropriateness".
        model_name: display name of the model, forwarded to ``load_model``.
        pooling_method: pooling option, forwarded to ``load_model``.
        input_text: contents of the text box (may be ``None`` or blank).
        file: uploaded file as a path string or an object with ``.name``.

    Returns:
        ``(message, dataframe, file_output)``: a status message for the user,
        the scored DataFrame (empty on error), and the value returned by
        ``save_dataframe_to_file`` (``None`` on error).
    """
    has_file = file is not None
    has_text = input_text is not None and input_text.strip() != ""

    # Case 1: nothing usable was provided.
    if not has_file and not has_text:
        return _NO_INPUT_MSG, pd.DataFrame(), None

    # Only announce the file-over-text priority when text was actually entered.
    prefix = _BOTH_INPUTS_MSG if has_file and has_text else ""

    if has_file:
        # Cases 2 and 3: a file was uploaded (with or without text).
        path = _file_path(file)
        if not path.lower().endswith(_SUPPORTED_EXTENSIONS):
            return prefix + _BAD_FORMAT_MSG, pd.DataFrame(), None
        if _is_unsupported_combination(task_name, model_name):
            return prefix + _UNSUPPORTED_TASK_MSG, pd.DataFrame(), None
        dataframe, error = _read_prompt_response_file(path)
        if error is not None:
            return prefix + error, pd.DataFrame(), None
        source = "from uploaded file"
    else:
        # Case 4: text input only.
        if _is_unsupported_combination(task_name, model_name):
            return _UNSUPPORTED_TASK_MSG, pd.DataFrame(), None
        dataframe = _parse_text_rows(input_text)
        if dataframe.empty:
            return _NO_TEXT_ROWS_MSG, pd.DataFrame(), None
        source = "of text"

    _score_dataframe(dataframe, task_name, model_name, pooling_method)
    file_output = save_dataframe_to_file(dataframe, file_format="csv")
    summary = (
        f"Processed {len(dataframe)} rows {source} using task: {task_name}, "
        f"model: {model_name}, pooling: {pooling_method}."
    )
    return prefix + summary, dataframe, file_output

## Input components
_TASK_CHOICES = ["Creativity", "Appropriateness"]
_MODEL_CHOICES = ["One-phase Fine-tuned BERT", "Two-phase Fine-tuned BERT"]
_POOLING_CHOICES = ["mean", "cls"]

# Task whose score is reported (maps to process_data's task_name).
task_dropdown = gr.Dropdown(choices=_TASK_CHOICES, value="Appropriateness", label="Task Name")
# Model variant (maps to process_data's model_name).
model_dropdown = gr.Dropdown(choices=_MODEL_CHOICES, value="Two-phase Fine-tuned BERT", label="Model Name")
# Pooling option forwarded to load_model (maps to process_data's pooling_method).
pooling_dropdown = gr.Dropdown(choices=_POOLING_CHOICES, value="cls", label="Pooling")
# Free-text input, pre-filled with the demo sample.
text_input = gr.Textbox(value=demo, lines=10, label="Text Input")
# Optional CSV/XLSX upload, delivered to process_data as a file path.
file_input = gr.File(file_types=[".csv", ".xlsx"], type="filepath", label="Input File")

## Output components
# Read-only status message.
output_box = gr.Textbox(interactive=False, lines=5, label="Output")
# Read-only table of scored rows.
dataframe_output = gr.Dataframe(interactive=False, label="DataFrame")
# Read-only file holding the scored results.
file_output = gr.File(interactive=False, label="Output File")

# Build the Gradio interface: the five inputs map positionally onto
# process_data's parameters and the three outputs onto its return tuple.
interface = gr.Interface(
    fn=process_data,
    inputs=[task_dropdown, model_dropdown, pooling_dropdown, text_input, file_input],
    outputs=[output_box, dataframe_output, file_output],
    # Hides elements with the .file-download class and centres <h1> headings.
    css=(""".file-download {display: none !important;}
            h1 {text-align: center;}"""),
    title="TwoPhaseLLMs-CreativityAutoEvaluation",
    description=description_text,
    theme=gr.themes.Soft(),
)

# Launch the server only when executed as a script, so importing this module
# (e.g. from tests or another app) builds the interface without starting it.
if __name__ == "__main__":
    interface.launch()