# Scraped page header (not part of the program):
#   author: Siyunb323 | commit 9b462b6 "update some files" | size 7.29 kB
import torch
import gradio as gr
import pandas as pd
from utils import save_dataframe_to_file, tokenize_Df
from model import load_model
# Read the Markdown text that is later passed to gr.Interface as `description`.
with open("./description.md", "r", encoding="utf-8") as file:
    description_text = file.read()
# Read the sample text that is later used as the initial value of the text box.
with open("./input_demo.txt", "r", encoding="utf-8") as file:
    demo = file.read()
# Column layout every input is normalised to before scoring.
_PROMPT_RESPONSE = ['prompt', 'response']
_ONE_PHASE_MODEL = "One-phase Fine-tuned BERT"
_MSG_NO_INPUT = "No valid input detected. Please check your input and ensure it follows the expected format."
_MSG_BOTH_INPUTS = "Detected both text and file input. Prioritizing file input."
_MSG_BAD_FORMAT = "File format must be xlsx or csv."
_MSG_UNSUPPORTED = "One-phase Fine-tuned BERT model does not support Appropriateness task."


def _as_path(file):
    """Return the filesystem path for *file* as a plain ``str``.

    Accepts a ``str``/``os.PathLike`` path directly, or any object that
    exposes the path through a ``.name`` attribute (the form the original
    code relied on).
    """
    import os

    if isinstance(file, (str, os.PathLike)):
        return str(os.fspath(file))
    return str(file.name)


def _read_table(path):
    """Load a csv/xlsx file as a prompt/response frame, or ``None`` if unusable.

    ``None`` is returned when the file is empty, cannot be parsed as CSV, or
    does not have exactly two columns.
    """
    try:
        if path.lower().endswith('.csv'):
            df = pd.read_csv(path)
        else:
            df = pd.read_excel(path)
    except (pd.errors.EmptyDataError, pd.errors.ParserError):
        return None
    # The file already carries the expected header row.
    if list(df.columns) == _PROMPT_RESPONSE:
        return df
    # Anything other than two columns cannot be mapped onto prompt/response.
    if df.shape[1] != len(_PROMPT_RESPONSE):
        return None
    # No header: pandas consumed the first data row as column labels, so put
    # it back as the first record.
    records = [list(df.columns)] + df.values.tolist()
    return pd.DataFrame(records, columns=_PROMPT_RESPONSE)


def _parse_text(input_text):
    """Parse ``prompt,response`` lines into a frame (possibly empty).

    Each line is split on its first comma only, so the response may itself
    contain commas. Lines without a comma are skipped. A leading
    ``prompt,response`` header line is dropped.
    """
    rows = []
    for line in input_text.strip().split("\n"):
        # Tolerate Windows line endings pasted into the text box.
        parts = line.rstrip("\r").split(",", maxsplit=1)
        if len(parts) == 2:
            rows.append(parts)
    if rows and rows[0] == _PROMPT_RESPONSE:
        rows = rows[1:]
    return pd.DataFrame(rows, columns=_PROMPT_RESPONSE)


def _evaluate(frame, task_name, model_name, pooling_method):
    """Score *frame* in place (adds an ``evaluation`` column) and save it as csv.

    Returns whatever ``save_dataframe_to_file`` returns.
    """
    loaded_net = load_model(model_name, pooling_method)
    example = tokenize_Df(frame)
    with torch.no_grad():
        score = loaded_net(example)
    if model_name != _ONE_PHASE_MODEL:
        # For every other model the output is indexed: [0] for the
        # Creativity task, [1] otherwise.
        score = score[0] if task_name == 'Creativity' else score[1]
    frame['evaluation'] = score.numpy()
    return save_dataframe_to_file(frame, file_format="csv")


def process_data(task_name, model_name, pooling_method, input_text=None, file=None):
    """Score prompt/response pairs supplied as pasted text or as a csv/xlsx file.

    Args:
        task_name: "Creativity" or "Appropriateness".
        model_name: Name handed to ``load_model``; the one-phase model is
            rejected for the Appropriateness task.
        pooling_method: Pooling option handed to ``load_model``.
        input_text: Optional text, one ``prompt,response`` pair per line.
            Blank or whitespace-only text counts as absent.
        file: Optional uploaded file (a path, or an object with a ``.name``
            path). Takes priority over ``input_text`` when both are given.

    Returns:
        A ``(message, dataframe, file_output)`` tuple. On any validation
        failure the dataframe is empty and ``file_output`` is ``None``.
    """
    has_text = input_text is not None and input_text.strip() != ""

    # Case 1: neither a file nor any usable text.
    if file is None and not has_text:
        return _MSG_NO_INPUT, pd.DataFrame(), None

    unsupported = task_name == "Appropriateness" and model_name == _ONE_PHASE_MODEL

    # Cases 2 and 3: a file was uploaded (with or without accompanying text).
    if file is not None:
        # Only announce "both inputs" when the text box really has content.
        notes = [_MSG_BOTH_INPUTS] if has_text else []
        path = _as_path(file)
        if not path.lower().endswith(('.csv', '.xlsx')):
            notes.append(_MSG_BAD_FORMAT)
        elif unsupported:
            notes.append(_MSG_UNSUPPORTED)
        else:
            frame = _read_table(path)
            if frame is None or frame.empty:
                notes.append(_MSG_NO_INPUT)
            else:
                saved = _evaluate(frame, task_name, model_name, pooling_method)
                notes.append(
                    f"Processed {len(frame)} rows from uploaded file using task: {task_name}, model: {model_name}, pooling: {pooling_method}."
                )
                return " ".join(notes), frame, saved
        return " ".join(notes), pd.DataFrame(), None

    # Case 4: text only.
    if unsupported:
        return _MSG_UNSUPPORTED, pd.DataFrame(), None
    frame = _parse_text(input_text)
    if frame.empty:
        # No line contained a comma (or only the header was present).
        return _MSG_NO_INPUT, pd.DataFrame(), None
    saved = _evaluate(frame, task_name, model_name, pooling_method)
    message = f"Processed {len(frame)} rows of text using task: {task_name}, model: {model_name}, pooling: {pooling_method}."
    return message, frame, saved
## Input components
# Each widget below is listed, in this order, in the `inputs` of the
# gr.Interface built further down.
# Which score to report.
task_dropdown = gr.Dropdown(
    label="Task Name",
    choices=["Creativity", "Appropriateness"],
    value="Appropriateness")
# Which fine-tuned model to load.
model_dropdown = gr.Dropdown(
    label="Model Name",
    choices=[
        "One-phase Fine-tuned BERT",
        "Two-phase Fine-tuned BERT"],
    value="Two-phase Fine-tuned BERT")
# Pooling option forwarded to the model loader.
pooling_dropdown = gr.Dropdown(
    label="Pooling",
    choices=["mean", "cls"],
    value="cls")
# Free-text input, pre-filled with the demo text loaded at import time.
text_input = gr.Textbox(
    label="Text Input",
    lines=10,
    value=demo)
# Optional upload, restricted to .csv/.xlsx.
# NOTE(review): type="filepath" presumably hands the callback a path-like
# value rather than file contents - confirm against the installed Gradio version.
file_input = gr.File(
    label="Input File",
    type="filepath",
    file_types=[".csv", ".xlsx"])
## Output components
# Listed, in this order, in the `outputs` of the gr.Interface built below,
# matching the (message, dataframe, file) tuple returned by process_data.
output_box = gr.Textbox(label="Output", lines=5, interactive=False)
dataframe_output = gr.Dataframe(label="DataFrame", interactive=False)
file_output = gr.File(label="Output File", interactive=False)
# Build the Gradio interface: process_data is the callback, wired to the
# input and output components defined above.
interface = gr.Interface(
    fn=process_data,
    inputs=[task_dropdown, model_dropdown, pooling_dropdown, text_input, file_input],
    outputs=[output_box, dataframe_output, file_output],
    # Custom CSS: hides elements with class "file-download" and centres <h1>.
    css=(""".file-download {display: none !important;}
h1 {text-align: center;}"""),
    title="TwoPhaseBERT-CreativityAutoEvaluation",
    description=description_text,
    theme=gr.themes.Soft(),
)
# Launch the interface (runs whenever this module is executed or imported).
interface.launch()