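# Streamlit app: TAPAS table question answering over an uploaded CSV.
# The table is split into row chunks so each chunk stays within TAPAS's
# 512-token input limit, and every question is asked against each chunk.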
import streamlit as st
import pandas as pd
import numpy as np
from io import StringIO
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
# Initialize TAPAS model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
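# Note (not in the original code): reloading the large TAPAS checkpoint on
# every Streamlit rerun is slow. A common pattern is to cache the load with
# st.cache_resource; a minimal sketch under that assumption:
#
#   @st.cache_resource
#   def load_tapas():
#       tok = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
#       mdl = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
#       return tok, mdl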
def ask_llm_chunk(chunk, questions):
    # TAPAS expects every table cell to be a string.
    chunk = chunk.astype(str)
    try:
        inputs = tokenizer(table=chunk, queries=questions, padding="max_length", return_tensors="pt")
        # padding="max_length" pads but does not truncate, so a chunk whose
        # flattened table exceeds the model's 512-token limit shows up here.
        if inputs["input_ids"].shape[1] > 512:
            return ["Token limit exceeded for this chunk"] * len(questions)
        outputs = model(**inputs)
        predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
            inputs,
            outputs.logits.detach(),
            outputs.logits_aggregation.detach(),
        )
    except Exception as e:
        st.write(f"An error occurred: {e}")
        return ["Error processing this chunk"] * len(questions)

    # Each entry in predicted_answer_coordinates is a list of (row, column)
    # cell coordinates for one question; multi-cell answers are joined with commas.
    answers = []
    for coordinates in predicted_answer_coordinates:
        if len(coordinates) == 1:
            answers.append(chunk.iat[coordinates[0]])
        else:
            cell_values = [chunk.iat[coordinate] for coordinate in coordinates]
            answers.append(", ".join(cell_values))
    return answers
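# Note: predicted_aggregation_indices is computed above but never used. Per
# the Hugging Face TAPAS docs, WTQ-fine-tuned TAPAS predicts an aggregation
# operator per question; a sketch of surfacing it alongside each answer:
#
#   id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}
#   operators = [id2aggregation[idx] for idx in predicted_aggregation_indices]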
MAX_ROWS_PER_CHUNK = 50  # keep each chunk small enough to fit TAPAS's 512-token input limit
def summarize_map_reduce(data, questions):
    try:
        dataframe = pd.read_csv(StringIO(data))
    except Exception as e:
        st.write(f"Error reading the CSV file: {e}")
        return []
    # Split the table into roughly equal row chunks.
    num_chunks = len(dataframe) // MAX_ROWS_PER_CHUNK + 1
    dataframe_chunks = np.array_split(dataframe, num_chunks)
    # "Map" step: ask every question against each chunk; the flat result
    # holds len(questions) answers per chunk, in chunk order.
    all_answers = []
    for chunk in dataframe_chunks:
        chunk_answers = ask_llm_chunk(chunk, questions)
        all_answers.extend(chunk_answers)
    return all_answers
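# Despite the name, only the "map" half lives in summarize_map_reduce; the
# caller must group the flat per-chunk answer list by question
# (answers[i::len(questions)]), as the display loop below does.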
st.title("TAPAS Table Question Answering")
# Upload CSV data
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
if csv_file is not None:
    data = csv_file.read().decode("utf-8")
    st.write("CSV Data Preview:")
    st.write(pd.read_csv(StringIO(data)).head())

    # Input for questions, one per non-empty line
    questions = st.text_area("Enter your questions (one per line)")
    questions = [q.strip() for q in questions.split("\n") if q.strip()]
    if st.button("Submit"):
        if data and questions:
            try:
                answers = summarize_map_reduce(data, questions)
                st.write("Answers:")
                # Answers arrive in blocks of len(questions) per chunk, so
                # stride through the flat list to collect each question's
                # answers across all chunks.
                for i, q in enumerate(questions):
                    per_chunk_answers = answers[i::len(questions)]
                    st.write(f"Question: {q}")
                    st.write(f"Answer: {'; '.join(per_chunk_answers)}")
            except Exception as e:
                st.write(f"An error occurred: {e}")