Spaces: Build error
import streamlit as st
import pandas as pd
import numpy as np
from io import StringIO
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
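# Note: transformers alone is not enough here; TAPAS needs the PyTorch backend,
# so `torch` must be installed in the Space as well (alongside streamlit, pandas
# and numpy). A missing backend is a common cause of Spaces build errors.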
# Initialize TAPAS model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
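# Streamlit reruns this whole script on every interaction, so the large TAPAS
# checkpoint above is reloaded each time. A hedged sketch of one way to avoid
# that (assumes Streamlit >= 1.18, which provides st.cache_resource):
#
# @st.cache_resource
# def load_tapas():
#     tok = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
#     mdl = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
#     return tok, mdl
#
# tokenizer, model = load_tapas()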
def ask_llm_chunk(chunk, questions):
    # TAPAS expects every cell as a string
    chunk = chunk.astype(str)
    try:
        inputs = tokenizer(table=chunk, queries=questions, padding="max_length", return_tensors="pt")
        # TAPAS can only encode 512 tokens; skip chunks that do not fit
        if inputs["input_ids"].shape[1] > 512:
            return ["Token limit exceeded for this chunk"] * len(questions)
        outputs = model(**inputs)
        predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
            inputs,
            outputs.logits.detach(),
            outputs.logits_aggregation.detach(),
        )
    except Exception as e:
        st.write(f"An error occurred: {e}")
        return ["Error processing this chunk"] * len(questions)

    # Map the predicted (row, column) coordinates back to cell values
    answers = []
    for coordinates in predicted_answer_coordinates:
        if len(coordinates) == 1:
            answers.append(chunk.iat[coordinates[0]])
        else:
            cell_values = [chunk.iat[coordinate] for coordinate in coordinates]
            answers.append(", ".join(cell_values))
    return answers
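# Note: ask_llm_chunk() computes predicted_aggregation_indices but never uses
# them, so SUM / AVERAGE / COUNT style questions only return the selected cells.
# A hedged sketch of surfacing the predicted operator (the id-to-label mapping
# follows the Transformers TAPAS docs for WTQ checkpoints):
#
# aggregation_labels = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}
# operators = [aggregation_labels[int(idx)] for idx in predicted_aggregation_indices]
# # operators[i] is the predicted operator for questions[i] and could be shown
# # alongside answers[i] in the Streamlit output.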
MAX_ROWS_PER_CHUNK = 50  # Reduced chunk size so the flattened table stays within TAPAS's 512-token limit
def summarize_map_reduce(data, questions):
    try:
        dataframe = pd.read_csv(StringIO(data))
    except Exception as e:
        st.write(f"Error reading the CSV file: {e}")
        return []
    # Map: split the table into row chunks and query each one separately
    num_chunks = len(dataframe) // MAX_ROWS_PER_CHUNK + 1
    dataframe_chunks = np.array_split(dataframe, num_chunks)
    # Reduce: collect each chunk's answer per question so the result has
    # exactly one entry per question (otherwise answers from all chunks after
    # the first would be dropped by the zip() in the display loop below)
    answers_per_question = [[] for _ in questions]
    for chunk in dataframe_chunks:
        chunk_answers = ask_llm_chunk(chunk, questions)
        for i, answer in enumerate(chunk_answers):
            answers_per_question[i].append(answer)
    return ["; ".join(per_chunk) for per_chunk in answers_per_question]
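# Caveat: each chunk only sees MAX_ROWS_PER_CHUNK rows, so questions that
# aggregate over the whole table (totals, averages, counts) are answered per
# chunk and then merely concatenated here; turning them into a single global
# figure would need an extra reduce step on top of these per-chunk answers.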
st.title("TAPAS Table Question Answering") | |
# Upload CSV data | |
csv_file = st.file_uploader("Upload a CSV file", type=["csv"]) | |
if csv_file is not None: | |
data = csv_file.read().decode("utf-8") | |
st.write("CSV Data Preview:") | |
st.write(pd.read_csv(StringIO(data)).head()) | |
# Input for questions | |
questions = st.text_area("Enter your questions (one per line)") | |
questions = questions.split("\n") # split questions by line | |
questions = [q for q in questions if q] # remove empty strings | |
if st.button("Submit"): | |
if data and questions: | |
try: | |
answers = summarize_map_reduce(data, questions) | |
st.write("Answers:") | |
for q, a in zip(questions, answers): | |
st.write(f"Question: {q}") | |
st.write(f"Answer: {a}") | |
except Exception as e: | |
st.write(f"An error occurred: {e}") | |