from copy import deepcopy
from io import StringIO
import logging

import numpy as np
import pandas as pd
import streamlit as st
import weaviate
from langchain.callbacks import StreamlitCallbackHandler
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5

# Initialize session state attributes
if "debug" not in st.session_state:
    st.session_state.debug = False

# LangChain callback handler (currently unused by the rest of the script)
st_callback = StreamlitCallbackHandler(st.container())


class StreamlitLogHandler(logging.Handler):
    """Logging handler that writes log records into the Streamlit app.

    Renamed from StreamlitCallbackHandler so it no longer collides with the
    LangChain class imported above.
    """

    def emit(self, record):
        st.write(self.format(record))


# Initialize TAPAS model and tokenizer (cached so Streamlit reruns don't reload them)
@st.cache_resource
def load_tapas():
    tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
    model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
    return tokenizer, model


tokenizer, model = load_tapas()


# Initialize Weaviate client for the embedded instance (cached for the same reason)
@st.cache_resource
def get_weaviate_client():
    return weaviate.Client(embedded_options=EmbeddedOptions())


client = get_weaviate_client()

# Global list to store debugging information
DEBUG_LOGS = []


def log_debug_info(message):
    if st.session_state.debug:
        DEBUG_LOGS.append(message)
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.DEBUG)
        # Only attach the handler once to avoid duplicate log lines
        if not any(isinstance(handler, StreamlitLogHandler) for handler in logger.handlers):
            logger.addHandler(StreamlitLogHandler())
        logger.debug(message)


def class_exists(class_name):
    """Check whether a class already exists in the Weaviate schema."""
    try:
        classes = client.schema.get().get("classes", [])
        return any(cls["class"] == class_name for cls in classes)
    except Exception:
        return False


def map_dtype_to_weaviate(dtype):
    """Map pandas data types to Weaviate data types."""
    if "int" in str(dtype):
        return "int"
    elif "float" in str(dtype):
        return "number"
    elif "bool" in str(dtype):
        return "boolean"
    else:
        return "text"  # "string" is deprecated in newer Weaviate versions


def ingest_data_to_weaviate(dataframe, class_name, class_description):
    # Create the class, with properties derived from the dataframe columns,
    # if it does not exist yet
    if not class_exists(class_name):
        class_schema = {
            "class": class_name,
            "description": class_description,
            "properties": [
                {
                    "name": column_name,
                    "description": f"Property for {column_name}",
                    "dataType": [map_dtype_to_weaviate(data_type)],
                }
                for column_name, data_type in zip(dataframe.columns, dataframe.dtypes)
            ],
        }
        client.schema.create_class(class_schema)

    # Ingest rows; generate_uuid5 gives each row a deterministic id, so
    # re-ingesting the same row is rejected rather than duplicated
    for _, row in dataframe.iterrows():
        # .item() unwraps numpy scalars so the payload is JSON-serializable
        properties = {k: (v.item() if hasattr(v, "item") else v) for k, v in row.items()}
        client.data_object.create(properties, class_name, uuid=generate_uuid5(properties))

    log_debug_info(f"Data ingested into Weaviate for class: {class_name}")


def query_weaviate(question, class_name, properties):
    # Basic nearText query; adapt it to the question. Requires a vectorizer
    # module to be enabled on the Weaviate instance. The v3 client expects the
    # concepts wrapped in a dict.
    return (
        client.query.get(class_name, properties)
        .with_near_text({"concepts": [question]})
        .do()
    )

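# Hedged usage sketch for query_weaviate (the class and property names below
# are placeholders, not part of this app). With the v3 GraphQL client, the
# response arrives nested under data -> Get -> <class name>:
#
#   results = query_weaviate("rows about solar power", "MyTable", ["city", "output_mw"])
#   # -> {"data": {"Get": {"MyTable": [{"city": ..., "output_mw": ...}, ...]}}}
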
def ask_llm_chunk(chunk, questions):
    # TAPAS expects every cell to be a string; reset the index so the
    # positional coordinates it returns line up with .iloc below
    chunk = chunk.astype(str).reset_index(drop=True)
    try:
        inputs = tokenizer(
            table=chunk,
            queries=questions,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
    except Exception as e:
        log_debug_info(f"Tokenization error: {e}")
        st.write(f"An error occurred: {e}")
        return ["Error occurred while tokenizing"] * len(questions)

    if inputs["input_ids"].shape[1] > 512:
        log_debug_info("Token limit exceeded for chunk")
        st.warning("Token limit exceeded for chunk")
        return ["Token limit exceeded for chunk"] * len(questions)

    outputs = model(**inputs)
    predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs, outputs.logits.detach(), outputs.logits_aggregation.detach()
    )

    answers = []
    for coordinates in predicted_answer_coordinates:
        if len(coordinates) == 1:
            # Single-cell answer
            row, col = coordinates[0]
            try:
                value = chunk.iloc[row, col]
                log_debug_info(f"Accessed value for row {row}, col {col}: {value}")
                answers.append(value)
            except Exception as e:
                log_debug_info(f"Error accessing value for row {row}, col {col}: {e}")
                st.write(f"An error occurred: {e}")
                # Keep answers aligned with questions even on failure
                answers.append(f"Error accessing cell ({row}, {col})")
        else:
            # Multi-cell answer: join the selected cells
            cell_values = []
            for row, col in coordinates:
                try:
                    cell_values.append(chunk.iloc[row, col])
                except Exception as e:
                    log_debug_info(f"Error accessing value for row {row}, col {col}: {e}")
                    st.write(f"An error occurred: {e}")
            answers.append(", ".join(map(str, cell_values)))
    return answers


MAX_ROWS_PER_CHUNK = 200


def summarize_map_reduce(data, questions):
    # Split the CSV into chunks small enough for TAPAS's 512-token limit,
    # ask every question of every chunk, and collect the answers
    dataframe = pd.read_csv(StringIO(data))
    num_chunks = len(dataframe) // MAX_ROWS_PER_CHUNK + 1
    dataframe_chunks = [deepcopy(chunk) for chunk in np.array_split(dataframe, num_chunks)]

    all_answers = []
    for chunk in dataframe_chunks:
        all_answers.extend(ask_llm_chunk(chunk, questions))
    return all_answers

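# Hedged sketch: ask_llm_chunk above discards `predicted_aggregation_indices`,
# so questions like "how many ..." or "what is the total ..." lose the operator
# TAPAS predicted. For WTQ-finetuned checkpoints the indices map to
# NONE/SUM/AVERAGE/COUNT (also exposed as model.config.aggregation_labels on
# recent transformers versions). A helper like this could surface them:
AGGREGATION_LABELS = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}


def describe_prediction(cell_values, aggregation_index):
    """Prefix the selected cells with the aggregation operator TAPAS predicted."""
    operator = AGGREGATION_LABELS.get(aggregation_index, "NONE")
    joined = ", ".join(map(str, cell_values))
    return joined if operator == "NONE" else f"{operator} of {joined}"
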
def get_class_schema(class_name):
    """Get the schema for a specific class, or None if it does not exist."""
    all_classes = client.schema.get().get("classes", [])
    for cls in all_classes:
        if cls["class"] == class_name:
            return cls
    return None


st.title("TAPAS Table Question Answering with Weaviate")

# Toggle for the debug logging used by log_debug_info
st.session_state.debug = st.sidebar.checkbox("Debug mode", value=st.session_state.debug)

# Get existing classes from Weaviate
existing_classes = [cls["class"] for cls in client.schema.get().get("classes", [])]
class_options = existing_classes + ["New Class"]
selected_class = st.selectbox("Select a class or create a new one:", class_options)

if selected_class == "New Class":
    class_name = st.text_input("Enter the new class name:")
    class_description = st.text_input("Enter a description for the class:")
else:
    class_name = selected_class
    class_description = ""  # The description could also be fetched from Weaviate

# Upload CSV data
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])

# Display the schema if an existing class is selected
class_schema = None
if selected_class != "New Class":
    st.write(f"Schema for {selected_class}:")
    class_schema = get_class_schema(selected_class)
    if class_schema:
        schema_df = pd.DataFrame(class_schema["properties"])
        st.table(schema_df[["name", "dataType"]])  # Show only name and dataType

# Before ingesting data into Weaviate, check that the CSV columns match the class schema
data = None
if csv_file is not None:
    data = csv_file.read().decode("utf-8")
    dataframe = pd.read_csv(StringIO(data))
    log_debug_info(f"CSV uploaded with shape: {dataframe.shape}")

    # Display the uploaded CSV data
    st.write("Uploaded CSV Data:")
    st.write(dataframe)

    if class_schema:
        schema_columns = [prop["name"] for prop in class_schema["properties"]]
        if set(dataframe.columns) != set(schema_columns):
            st.error(
                "The columns in the uploaded CSV do not match the schema of the "
                "selected class. Please check and upload the correct CSV or "
                "create a new class."
            )
        else:
            # Ingest data into Weaviate
            ingest_data_to_weaviate(dataframe, class_name, class_description)
    elif class_name:
        # New class: create it from the CSV columns and ingest
        ingest_data_to_weaviate(dataframe, class_name, class_description)

# Input for questions, one per line; drop empty lines
questions = st.text_area("Enter your questions (one per line)")
questions = [q for q in questions.split("\n") if q]

if st.button("Submit"):
    if data and questions:
        answers = summarize_map_reduce(data, questions)
        st.write("Answers:")
        for q, a in zip(questions, answers):
            st.write(f"Question: {q}")
            st.write(f"Answer: {a}")
    else:
        st.warning("Please upload a CSV file and enter at least one question.")

# Display debugging information
if st.checkbox("Show Debugging Information"):
    st.write("Debugging Logs:")
    for log in DEBUG_LOGS:
        st.write(log)

# Add Ctrl+Enter functionality for submitting the questions
st.markdown(""" """, unsafe_allow_html=True)
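# To try the app locally (assuming this script is saved as app.py and the
# dependencies are installed):
#   pip install streamlit transformers torch pandas numpy weaviate-client langchain
#   streamlit run app.py
# The first run downloads the TAPAS checkpoint and starts the embedded
# Weaviate instance, so it can take a while.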