import streamlit as st
import pandas as pd
import numpy as np
from io import StringIO
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering

# Initialize the TAPAS model and tokenizer (the WTQ fine-tune supports aggregation questions)
tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")


def ask_llm_chunk(chunk, questions):
    # TAPAS requires every table cell to be a string
    chunk = chunk.astype(str)
    try:
        inputs = tokenizer(
            table=chunk,
            queries=questions,
            padding="max_length",
            return_tensors="pt",
        )
        # TAPAS has a hard 512-token limit; skip chunks whose encoding exceeds it
        if inputs["input_ids"].shape[1] > 512:
            return ["Token limit exceeded for this chunk"] * len(questions)

        outputs = model(**inputs)
        predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
            inputs,
            outputs.logits.detach(),
            outputs.logits_aggregation.detach(),
        )
    except Exception as e:
        st.write(f"An error occurred: {e}")
        return ["Error processing this chunk"] * len(questions)

    # Map each question's predicted (row, column) coordinates back to cell values
    answers = []
    for coordinates in predicted_answer_coordinates:
        if len(coordinates) == 1:
            answers.append(chunk.iat[coordinates[0]])
        else:
            cell_values = [chunk.iat[coordinate] for coordinate in coordinates]
            answers.append(", ".join(cell_values))
    return answers


MAX_ROWS_PER_CHUNK = 50  # reduced chunk size to stay under the 512-token limit


def summarize_map_reduce(data, questions):
    try:
        dataframe = pd.read_csv(StringIO(data))
    except Exception as e:
        st.write(f"Error reading the CSV file: {e}")
        return []

    # Map step: split the table into row chunks and ask every question of each chunk
    num_chunks = len(dataframe) // MAX_ROWS_PER_CHUNK + 1
    dataframe_chunks = np.array_split(dataframe, num_chunks)

    all_answers = []
    for chunk in dataframe_chunks:
        chunk_answers = ask_llm_chunk(chunk, questions)
        all_answers.extend(chunk_answers)
    return all_answers


st.title("TAPAS Table Question Answering")

# Upload CSV data
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])

if csv_file is not None:
    data = csv_file.read().decode("utf-8")
    st.write("CSV Data Preview:")
    st.write(pd.read_csv(StringIO(data)).head())

    # Input for questions, one per line; drop blank lines
    questions = st.text_area("Enter your questions (one per line)")
    questions = [q.strip() for q in questions.split("\n") if q.strip()]

    if st.button("Submit") and questions:
        try:
            answers = summarize_map_reduce(data, questions)
            # Each chunk contributes len(questions) answers, in question order,
            # so index i answers question i % len(questions) for chunk i // len(questions)
            st.write("Answers:")
            for i, a in enumerate(answers):
                st.write(f"Question: {questions[i % len(questions)]} (chunk {i // len(questions) + 1})")
                st.write(f"Answer: {a}")
        except Exception as e:
            st.write(f"An error occurred: {e}")
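
One caveat with the script as written: Streamlit re-executes the entire file on every widget interaction, so the large TAPAS checkpoint is reloaded on each rerun. A minimal sketch of caching the load instead, assuming Streamlit 1.18 or newer, where st.cache_resource is available:

import streamlit as st
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering

@st.cache_resource  # one shared instance across reruns and sessions
def load_tapas():
    # Same checkpoint as above; downloaded and loaded only on the first call
    tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
    model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
    return tokenizer, model

tokenizer, model = load_tapas()

On older Streamlit releases, st.cache(allow_output_mutation=True) played the same role.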
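
Note also that ask_llm_chunk computes predicted_aggregation_indices but never uses them, even though the WTQ checkpoint predicts an aggregation operator (none, SUM, AVERAGE, or COUNT) for each question. A sketch of how those predictions could be surfaced; the index-to-operator mapping follows the convention in the Hugging Face TAPAS documentation, and label_answers is a hypothetical helper rather than part of the app above:

# Index-to-operator mapping used in the Hugging Face TAPAS docs
id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}

def label_answers(chunk, predicted_answer_coordinates, predicted_aggregation_indices):
    """Pair each question's cell values with its predicted aggregation operator."""
    labeled = []
    for coordinates, agg_idx in zip(predicted_answer_coordinates, predicted_aggregation_indices):
        cells = [chunk.iat[coordinate] for coordinate in coordinates]
        operator = id2aggregation[agg_idx]
        if operator == "NONE":
            labeled.append(", ".join(cells))
        else:
            # e.g. "SUM > 3, 7, 12"; the caller can apply the operator to the cells
            labeled.append(f"{operator} > " + ", ".join(cells))
    return labeled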
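
Finally, np.array_split accepts a DataFrame but goes through DataFrame.swapaxes internally, which newer pandas versions deprecate with a FutureWarning. A pandas-native alternative for the chunking step in summarize_map_reduce, assuming the same MAX_ROWS_PER_CHUNK:

# Slice the table into row chunks with .iloc instead of np.array_split
dataframe_chunks = [
    dataframe.iloc[start:start + MAX_ROWS_PER_CHUNK]
    for start in range(0, len(dataframe), MAX_ROWS_PER_CHUNK)
]

Unlike np.array_split, this yields chunks of exactly MAX_ROWS_PER_CHUNK rows (except possibly the last), so the num_chunks computation can be dropped.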