nhosseini commited on
Commit
b9f8ef1
·
verified ·
1 Parent(s): 8890d1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -91
app.py CHANGED
@@ -1,91 +1,106 @@
1
- import gradio as gr
2
- import pandas as pd
3
- from transformers import TapexTokenizer, BartForConditionalGeneration, pipeline
4
-
5
- # Initialize TAPEX (Microsoft) model and tokenizer
6
- tokenizer_tapex = TapexTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
7
- model_tapex = BartForConditionalGeneration.from_pretrained("microsoft/tapex-large-finetuned-wtq")
8
-
9
- # Initialize TAPAS (Google) models and pipelines
10
- pipe_tapas = pipeline(task="table-question-answering", model="google/tapas-large-finetuned-wtq")
11
- pipe_tapas2 = pipeline(task="table-question-answering", model="google/tapas-large-finetuned-wikisql-supervised")
12
-
13
- def process_table_query(query, table_data):
14
- """
15
- Process a query and CSV data using TAPEX.
16
- """
17
- # Convert all columns in the table to strings for TAPEX compatibility
18
- table_data = table_data.astype(str)
19
-
20
- # Microsoft TAPEX model (using TAPEX tokenizer and model)
21
- encoding = tokenizer_tapex(table=table_data, query=query, return_tensors="pt", max_length=1024, truncation=True)
22
- outputs = model_tapex.generate(**encoding)
23
- result_tapex = tokenizer_tapex.batch_decode(outputs, skip_special_tokens=True)[0]
24
-
25
- return result_tapex
26
-
27
- # Gradio interface
28
- def answer_query_from_csv(query, file):
29
- """
30
- Function to handle file input and return model results.
31
- """
32
- # Read the file into a DataFrame
33
- table_data = pd.read_csv(file)
34
-
35
- # Convert object-type columns to lowercase (if they are valid strings)
36
- for column in table_data.columns:
37
- if table_data[column].dtype == 'object':
38
- table_data[column] = table_data[column].apply(lambda x: x.lower() if isinstance(x, str) else x)
39
-
40
- # Convert all table cells to strings for TAPEX compatibility
41
- table_data = table_data.astype(str)
42
-
43
- # Extract year, month, day, and time components for datetime columns
44
- for column in table_data.columns:
45
- if pd.api.types.is_datetime64_any_dtype(table_data[column]):
46
- table_data[f'{column}_year'] = table_data[column].dt.year
47
- table_data[f'{column}_month'] = table_data[column].dt.month
48
- table_data[f'{column}_day'] = table_data[column].dt.day
49
- table_data[f'{column}_time'] = table_data[column].dt.strftime('%H:%M:%S')
50
-
51
- # Process the CSV file and query
52
- result_tapex = process_table_query(query, table_data)
53
-
54
- # Process the query using TAPAS pipelines
55
- result_tapas = pipe_tapas(table=table_data, query=query)['cells'][0]
56
- result_tapas2 = pipe_tapas2(table=table_data, query=query)['cells'][0]
57
-
58
- return result_tapex, result_tapas, result_tapas2
59
-
60
- # Create Gradio interface
61
- with gr.Blocks() as interface:
62
- gr.Markdown("# Table Question Answering with TAPEX and TAPAS Models")
63
-
64
- # Add a notice about the token limit
65
- gr.Markdown("### Note: Only the first 1024 tokens (query + table data) will be considered. If your table is too large, it will be truncated to fit within this limit.")
66
-
67
- # Two-column layout (input on the left, output on the right)
68
- with gr.Row():
69
- with gr.Column():
70
- # Input fields for the query and file
71
- query_input = gr.Textbox(label="Enter your query:")
72
- csv_input = gr.File(label="Upload your CSV file")
73
-
74
- with gr.Column():
75
- # Output textboxes for the answers
76
- result_tapex = gr.Textbox(label="TAPEX Answer")
77
- result_tapas = gr.Textbox(label="TAPAS (WikiTableQuestions) Answer")
78
- result_tapas2 = gr.Textbox(label="TAPAS (WikiSQL) Answer")
79
-
80
- # Submit button
81
- submit_btn = gr.Button("Submit")
82
-
83
- # Action when submit button is clicked
84
- submit_btn.click(
85
- fn=answer_query_from_csv,
86
- inputs=[query_input, csv_input],
87
- outputs=[result_tapex, result_tapas, result_tapas2]
88
- )
89
-
90
- # Launch the Gradio interface
91
- interface.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from transformers import TapexTokenizer, BartForConditionalGeneration, pipeline
4
+
5
+ # Initialize TAPEX (Microsoft) model and tokenizer
6
+ tokenizer_tapex = TapexTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
7
+ model_tapex = BartForConditionalGeneration.from_pretrained("microsoft/tapex-large-finetuned-wtq")
8
+
9
+ # Initialize TAPAS (Google) models and pipelines
10
+ pipe_tapas = pipeline(task="table-question-answering", model="google/tapas-large-finetuned-wtq")
11
+ pipe_tapas2 = pipeline(task="table-question-answering", model="google/tapas-large-finetuned-wikisql-supervised")
12
+
13
+ def chunk_dataframe(df, max_tokens=1024):
14
+ """
15
+ Chunk a large dataframe into smaller pieces that fit within the token limit.
16
+ For simplicity, we're assuming the number of rows determines the token count.
17
+ """
18
+ chunk_size = max_tokens // len(df.columns) # Approximate number of rows that fit
19
+ return [df[i:i+chunk_size] for i in range(0, len(df), chunk_size)]
20
+
21
+ def process_table_query(query, table_data):
22
+ """
23
+ Process a query and CSV data using TAPEX.
24
+ """
25
+ # Convert all columns in the table to strings for TAPEX compatibility
26
+ table_data = table_data.astype(str)
27
+
28
+ # Chunk the table if it's too large
29
+ chunks = chunk_dataframe(table_data)
30
+
31
+ results = []
32
+ for chunk in chunks:
33
+ # Microsoft TAPEX model (using TAPEX tokenizer and model)
34
+ encoding = tokenizer_tapex(table=chunk, query=query, return_tensors="pt", max_length=1024, truncation=True)
35
+ outputs = model_tapex.generate(**encoding)
36
+ result_tapex = tokenizer_tapex.batch_decode(outputs, skip_special_tokens=True)[0]
37
+ results.append(result_tapex)
38
+
39
+ # Aggregate results
40
+ return ' '.join(results)
41
+
42
+ # Gradio interface
43
+ def answer_query_from_csv(query, file):
44
+ """
45
+ Function to handle file input and return model results.
46
+ """
47
+ # Read the file into a DataFrame
48
+ table_data = pd.read_csv(file)
49
+
50
+ # Convert object-type columns to lowercase (if they are valid strings)
51
+ for column in table_data.columns:
52
+ if table_data[column].dtype == 'object':
53
+ table_data[column] = table_data[column].apply(lambda x: x.lower() if isinstance(x, str) else x)
54
+
55
+ # Convert all table cells to strings for TAPEX compatibility
56
+ table_data = table_data.astype(str)
57
+
58
+ # Extract year, month, day, and time components for datetime columns
59
+ for column in table_data.columns:
60
+ if pd.api.types.is_datetime64_any_dtype(table_data[column]):
61
+ table_data[f'{column}_year'] = table_data[column].dt.year
62
+ table_data[f'{column}_month'] = table_data[column].dt.month
63
+ table_data[f'{column}_day'] = table_data[column].dt.day
64
+ table_data[f'{column}_time'] = table_data[column].dt.strftime('%H:%M:%S')
65
+
66
+ # Process the CSV file and query using TAPEX
67
+ result_tapex = process_table_query(query, table_data)
68
+
69
+ # Process the query using TAPAS pipelines
70
+ result_tapas = pipe_tapas(table=table_data, query=query)['cells'][0]
71
+ result_tapas2 = pipe_tapas2(table=table_data, query=query)['cells'][0]
72
+
73
+ return result_tapex, result_tapas, result_tapas2
74
+
75
+ # Create Gradio interface
76
+ with gr.Blocks() as interface:
77
+ gr.Markdown("# Table Question Answering with TAPEX and TAPAS Models")
78
+
79
+ # Add a notice about the token limit
80
+ gr.Markdown("### Note: Only the first 1024 tokens (query + table data) will be considered per chunk. If your table is too large, it will be chunked and processed separately.")
81
+
82
+ # Two-column layout (input on the left, output on the right)
83
+ with gr.Row():
84
+ with gr.Column():
85
+ # Input fields for the query and file
86
+ query_input = gr.Textbox(label="Enter your query:")
87
+ csv_input = gr.File(label="Upload your CSV file")
88
+
89
+ with gr.Column():
90
+ # Output textboxes for the answers
91
+ result_tapex = gr.Textbox(label="TAPEX Answer")
92
+ result_tapas = gr.Textbox(label="TAPAS (WikiTableQuestions) Answer")
93
+ result_tapas2 = gr.Textbox(label="TAPAS (WikiSQL) Answer")
94
+
95
+ # Submit button
96
+ submit_btn = gr.Button("Submit")
97
+
98
+ # Action when submit button is clicked
99
+ submit_btn.click(
100
+ fn=answer_query_from_csv,
101
+ inputs=[query_input, csv_input],
102
+ outputs=[result_tapex, result_tapas, result_tapas2]
103
+ )
104
+
105
+ # Launch the Gradio interface
106
+ interface.launch(share=True)