jskinner215 commited on
Commit
b0e4f45
·
1 Parent(s): e2061d5

Initial commit

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from io import StringIO
4
+ from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
5
+ import numpy as np
6
+
7
+ # Initialize TAPAS model and tokenizer
8
+ tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
9
+ model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
10
+
11
+ def ask_llm_chunk(chunk, questions):
12
+ chunk = chunk.astype(str)
13
+ inputs = tokenizer(table=chunk, queries=questions, padding="max_length", return_tensors="pt")
14
+ if inputs["input_ids"].shape[1] > 1024:
15
+ return ["Token limit exceeded for chunk"] * len(questions)
16
+ outputs = model(**inputs)
17
+ predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
18
+ inputs,
19
+ outputs.logits.detach(),
20
+ outputs.logits_aggregation.detach()
21
+ )
22
+ answers = []
23
+ for coordinates in predicted_answer_coordinates:
24
+ if len(coordinates) == 1:
25
+ answers.append(chunk.iat[coordinates[0]])
26
+ else:
27
+ cell_values = []
28
+ for coordinate in coordinates:
29
+ cell_values.append(chunk.iat[coordinate])
30
+ answers.append(", ".join(cell_values))
31
+ return answers
32
+
33
+ MAX_ROWS_PER_CHUNK = 200
34
+
35
+ def summarize_map_reduce(data, questions):
36
+ dataframe = pd.read_csv(StringIO(data))
37
+ num_chunks = len(dataframe) // MAX_ROWS_PER_CHUNK + 1
38
+ dataframe_chunks = np.array_split(dataframe, num_chunks)
39
+ all_answers = []
40
+ for chunk in dataframe_chunks:
41
+ chunk_answers = ask_llm_chunk(chunk, questions)
42
+ all_answers.extend(chunk_answers)
43
+ aggregated_answers = all_answers
44
+ return aggregated_answers
45
+
46
+ st.title("TAPAS Table Question Answering")
47
+
48
+ # Upload CSV data
49
+ csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
50
+ if csv_file is not None:
51
+ data = csv_file.read().decode("utf-8")
52
+ st.write("CSV Data Preview:")
53
+ st.write(pd.read_csv(StringIO(data)).head())
54
+
55
+ # Input for questions
56
+ questions = st.text_area("Enter your questions (one per line)")
57
+ questions = questions.split("\n") # split questions by line
58
+ questions = [q for q in questions if q] # remove empty strings
59
+
60
+ if st.button("Submit"):
61
+ if data and questions:
62
+ answers = summarize_map_reduce(data, questions)
63
+ st.write("Answers:")
64
+ for q, a in zip(questions, answers):
65
+ st.write(f"Question: {q}")
66
+ st.write(f"Answer: {a}")