Siyunb323 commited on
Commit
1d021ff
·
1 Parent(s): 2af82d9
Files changed (1) hide show
  1. app.py +83 -12
app.py CHANGED
@@ -1,5 +1,16 @@
 
1
  import gradio as gr
2
  import pandas as pd
 
 
 
 
 
 
 
 
 
 
3
 
4
  with open("./description.md", "r", encoding="utf-8") as file:
5
  description_text = file.read()
@@ -8,18 +19,78 @@ with open("./input_demo.txt", "r", encoding="utf-8") as file:
8
  demo = file.read()
9
 
10
  # 定义处理函数
11
- def process_data(task_name, model_name, pooling_method, input_text, file=None):
12
- if file:
13
- df = pd.read_csv(file.name)
14
- output = f"Processed {len(df)} rows from uploaded file using task: {task_name}, model: {model_name}, pooling: {pooling_method}."
15
- dataframe_output = df # 返回数据框用于显示
16
- file_output = df # 输出文件框显示相同的数据
17
- else:
18
- lines = input_text.split("\n")
19
- output = f"Processed {len(lines)} rows of text using task: {task_name}, model: {model_name}, pooling: {pooling_method}."
20
- dataframe_output = pd.DataFrame({"Error": ["No file uploaded. DataFrame preview unavailable."]}) # 错误信息
21
- file_output = pd.DataFrame()
22
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  return output, dataframe_output, file_output
24
 
25
  ## 输入组件
 
1
+ import os
2
  import gradio as gr
3
  import pandas as pd
4
+ import tempfile
5
+
6
+ def save_dataframe_to_file(dataframe, file_format="csv"):
7
+ temp_dir = tempfile.gettempdir() # 获取系统临时目录
8
+ file_path = os.path.join(temp_dir, f"output.{file_format}")
9
+ if file_format == "csv":
10
+ dataframe.to_csv(file_path, index=False)
11
+ elif file_format == "xlsx":
12
+ dataframe.to_excel(file_path, index=False)
13
+ return file_path
14
 
15
  with open("./description.md", "r", encoding="utf-8") as file:
16
  description_text = file.read()
 
19
  demo = file.read()
20
 
21
  # 定义处理函数
22
+ import pandas as pd
23
+
24
+ def process_data(task_name, model_name, pooling_method, input_text=None, file=None):
25
+ output = ""
26
+ dataframe_output = pd.DataFrame()
27
+ file_output = pd.DataFrame()
28
+
29
+ # 情况 1: file 和 input_text 都为 None
30
+ if file is None and (input_text is None or input_text.strip() == ""):
31
+ output = "No valid input detected. Please check your input and ensure it follows the expected format."
32
+
33
+ # 情况 2: file 和 input_text 都不为 None
34
+ elif file is not None and input_text is not None:
35
+ output = "Detected both text and file input. Prioritizing file input."
36
+ # 检查文件类型
37
+ if not (file.name.endswith('.csv') or file.name.endswith('.xlsx')):
38
+ output += " File format must be xlsx or csv."
39
+ else:
40
+ # 读取文件
41
+ if file.name.endswith('.csv'):
42
+ df = pd.read_csv(file)
43
+ else:
44
+ df = pd.read_excel(file)
45
+ # 检查第一行是否为 "prompt" 和 "response"
46
+ if list(df.columns) == ['prompt', 'response']:
47
+ dataframe_output = df
48
+ else:
49
+ df_values = [list(df.columns)] + df.values.tolist()
50
+ dataframe_output = pd.DataFrame(df_values, columns=['prompt', 'response'])
51
+ file_output = save_dataframe_to_file(dataframe_output, file_format="csv")
52
+
53
+ # 情况 3: 只有 file
54
+ elif file is not None:
55
+ # 检查文件类型
56
+ if not (file.name.endswith('.csv') or file.name.endswith('.xlsx')):
57
+ output = "File format must be xlsx or csv."
58
+ else:
59
+ # 读取文件
60
+ if file.name.endswith('.csv'):
61
+ df = pd.read_csv(file)
62
+ else:
63
+ df = pd.read_excel(file)
64
+ # 检查第一行是否为 "prompt" 和 "response"
65
+ if list(df.columns) == ['prompt', 'response']:
66
+ dataframe_output = df
67
+ else:
68
+ df_values = [list(df.columns)] + df.values.tolist()
69
+ dataframe_output = pd.DataFrame(df_values, columns=['prompt', 'response'])
70
+ file_output = save_dataframe_to_file(dataframe_output, file_format="csv")
71
+ output = f"Processed {len(dataframe_output)} rows from uploaded file using task: {task_name}, model: {model_name}, pooling: {pooling_method}."
72
+
73
+ # 情况 4: 只有 input_text
74
+ elif input_text is not None:
75
+ lines = input_text.strip().split("\n")
76
+ rows = []
77
+ for line in lines:
78
+ try:
79
+ split_line = line.split(",", maxsplit=1)
80
+ if len(split_line) == 2:
81
+ rows.append(split_line)
82
+ except Exception as e:
83
+ output = f"Error processing line: {line}"
84
+ break
85
+
86
+ if output == "":
87
+ if rows[0] == ['prompt', 'response']:
88
+ dataframe_output = pd.DataFrame(rows[1:], columns=['prompt', 'response'])
89
+ else:
90
+ dataframe_output = pd.DataFrame(rows, columns=['prompt', 'response'])
91
+ file_output = save_dataframe_to_file(dataframe_output, file_format="csv")
92
+ output = f"Processed {len(dataframe_output)} rows of text using task: {task_name}, model: {model_name}, pooling: {pooling_method}."
93
+
94
  return output, dataframe_output, file_output
95
 
96
  ## 输入组件