Spaces:
Build error
Build error
from copy import deepcopy | |
from langchain.callbacks import StreamlitCallbackHandler | |
import streamlit as st | |
import pandas as pd | |
from io import StringIO | |
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering | |
import numpy as np | |
import weaviate | |
from weaviate.embedded import EmbeddedOptions | |
from weaviate import Client | |
from weaviate.util import generate_uuid5 | |
import logging | |
# Make sure the debug flag exists in the session before anything reads it.
if not ("debug" in st.session_state):
    st.session_state["debug"] = False

# Callback handler bound to a fresh container. At this point the name still
# refers to the class imported from langchain.callbacks.
callback_container = st.container()
st_callback = StreamlitCallbackHandler(callback_container)
class StreamlitCallbackHandler(logging.Handler):
    """Logging handler that renders each formatted record in the Streamlit app.

    NOTE(review): this definition rebinds the name imported from
    ``langchain.callbacks`` at the top of the file; from here on the name
    refers to this logging handler.
    """

    def emit(self, record):
        """Format ``record`` and write the resulting text with ``st.write``.

        Failures while formatting or rendering are routed through
        ``handleError`` so that a problem in debug output does not propagate
        into the code that issued the log call.
        """
        try:
            log_entry = self.format(record)
            st.write(log_entry)
        except Exception:
            self.handleError(record)
# Initialize TAPAS model and tokenizer | |
#tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq") | |
#model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq") | |
# Initialize Weaviate client for the embedded instance | |
#client = weaviate.Client( | |
# embedded_options=EmbeddedOptions() | |
#) | |
# Debug messages recorded during the current script run; the
# "Show Debugging Information" section of the page iterates over this list.
DEBUG_LOGS = []


def log_debug_info(message):
    """Record and display ``message`` when debug mode is enabled.

    Does nothing unless ``st.session_state.debug`` is truthy. Otherwise the
    message is appended to ``DEBUG_LOGS`` and emitted at DEBUG level through
    the module logger, which gets a ``StreamlitCallbackHandler`` attached if
    it does not already have one.
    """
    if not st.session_state.debug:
        return

    # Nothing else in the file fills DEBUG_LOGS, so record the message here;
    # otherwise the debugging section of the page has nothing to show.
    DEBUG_LOGS.append(message)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    # The logger lives in the logging module's registry and outlives a single
    # execution of this script, while the handler class is defined anew each
    # time the script body runs. An isinstance() check therefore does not
    # recognise handlers attached by an earlier run, so compare by class name
    # to avoid stacking duplicate handlers.
    handler_name = StreamlitCallbackHandler.__name__
    already_attached = any(
        type(handler).__name__ == handler_name for handler in logger.handlers
    )
    if not already_attached:
        logger.addHandler(StreamlitCallbackHandler())

    logger.debug(message)
# Function to check if a class already exists in Weaviate | |
#def class_exists(class_name): | |
# try: | |
# client.schema.get_class(class_name) | |
# return True | |
# except: | |
# return False | |
#def map_dtype_to_weaviate(dtype): | |
## """ | |
# Map pandas data types to Weaviate data types. | |
# """ | |
# if "int" in str(dtype): | |
# return "int" | |
# elif "float" in str(dtype): | |
# return "number" | |
# elif "bool" in str(dtype): | |
# return "boolean" | |
# else: | |
# return "string" | |
# def ingest_data_to_weaviate(dataframe, class_name, class_description): | |
# # Create class schema | |
# class_schema = { | |
# "class": class_name, | |
# "description": class_description, | |
# "properties": [] # Start with an empty properties list | |
# } | |
# | |
# # Try to create the class without properties first | |
# try: | |
# client.schema.create({"classes": [class_schema]}) | |
# except weaviate.exceptions.SchemaValidationException: | |
# # Class might already exist, so we can continue | |
# pass# | |
# # Now, let's add properties to the class | |
# for column_name, data_type in zip(dataframe.columns, dataframe.dtypes): | |
# property_schema = { | |
# "name": column_name, | |
# "description": f"Property for {column_name}", | |
# "dataType": [map_dtype_to_weaviate(data_type)] | |
# } | |
# try: | |
# client.schema.property.create(class_name, property_schema) | |
# except weaviate.exceptions.SchemaValidationException: | |
# # Property might already exist, so we can continue | |
# pass | |
# | |
# # Ingest data | |
# for index, row in dataframe.iterrows(): | |
# obj = { | |
# "class": class_name, | |
# "id": str(index), | |
# "properties": row.to_dict() | |
# } | |
# client.data_object.create(obj) | |
# Log data ingestion | |
# log_debug_info(f"Data ingested into Weaviate for class: {class_name}") | |
def query_weaviate(question):
    """Run a nearText search for ``question`` against the selected class.

    Relies on the module-level ``client`` and ``class_name``. The property
    names to return are taken from the class schema reported by
    ``get_class_schema``; when no schema is found, no properties are passed.

    Args:
        question: Free-text question used as the single nearText concept.

    Returns:
        Whatever the query builder's ``do()`` call returns.
    """
    # This is a basic example; adapt the query based on the question.
    schema = get_class_schema(class_name)
    if schema:
        properties = [prop["name"] for prop in schema.get("properties", [])]
    else:
        properties = None

    # The nearText filter takes a dict with a "concepts" list rather than a
    # bare string (per the Weaviate Python client v3 API - confirm against
    # the installed client version).
    near_text = {"concepts": [question]}
    return client.query.get(class_name, properties).with_near_text(near_text).do()
def ask_llm_chunk(chunk, questions):
    """Answer ``questions`` against one table chunk with the TAPAS model.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        chunk: DataFrame with the rows of this chunk; it is cast to ``str``
            before tokenization.
        questions: List of question strings.

    Returns:
        A list of strings. On a tokenization failure or when the token limit
        check trips, one placeholder message per question. Otherwise one
        entry per predicted coordinate list, holding the selected cell values
        joined by ", " (an empty string when no cell could be read), so the
        entries stay aligned with the questions.
    """
    chunk = chunk.astype(str)
    try:
        inputs = tokenizer(
            table=chunk,
            queries=questions,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
    except Exception as e:
        log_debug_info(f"Tokenization error: {e}")
        st.write(f"An error occurred: {e}")
        return ["Error occurred while tokenizing"] * len(questions)

    # NOTE(review): with truncation enabled this limit may never be exceeded;
    # confirm whether the check is still needed.
    if inputs["input_ids"].shape[1] > 512:
        log_debug_info("Token limit exceeded for chunk")
        st.warning("Token limit exceeded for chunk")
        return ["Token limit exceeded for chunk"] * len(questions)

    outputs = model(**inputs)
    # Only the cell coordinates are used; the aggregation predictions are
    # discarded.
    predicted_answer_coordinates, _ = tokenizer.convert_logits_to_predictions(
        inputs,
        outputs.logits.detach(),
        outputs.logits_aggregation.detach(),
    )

    answers = []
    for coordinates in predicted_answer_coordinates:
        cell_values = []
        for row, col in coordinates:
            try:
                value = chunk.iloc[row, col]
            except Exception as e:
                log_debug_info(f"Error accessing value for row {row}, col {col}: {e}")
                st.write(f"An error occurred: {e}")
                continue
            if len(coordinates) == 1:
                log_debug_info(f"Accessed value for row {row}, col {col}: {value}")
            cell_values.append(value)
        # Always append exactly one entry per question. Skipping the append
        # after a failed lookup would shift every later answer onto the
        # wrong question.
        answers.append(", ".join(map(str, cell_values)))
    return answers
# Upper bound on the number of table rows handed to the model in one call.
MAX_ROWS_PER_CHUNK = 200


def summarize_map_reduce(data, questions):
    """Answer ``questions`` over a CSV table by querying it chunk by chunk.

    The CSV text is parsed into a DataFrame and cut into consecutive slices
    of at most ``MAX_ROWS_PER_CHUNK`` rows. Each slice is passed to
    ``ask_llm_chunk`` (map); the non-empty answers a question received from
    the individual slices are then joined with "; " (reduce).

    Args:
        data: CSV content as a single string.
        questions: List of question strings.

    Returns:
        A list with exactly one answer string per question, in the same
        order as ``questions``.
    """
    dataframe = pd.read_csv(StringIO(data))

    # max(..., 1) keeps one (empty) chunk for a header-only CSV, so
    # ask_llm_chunk is still consulted as before.
    chunk_starts = range(0, max(len(dataframe), 1), MAX_ROWS_PER_CHUNK)

    # One bucket per question. Extending a single flat list would leave
    # answers from later chunks without a matching question for the caller.
    answers_by_question = [[] for _ in questions]
    for start in chunk_starts:
        chunk = dataframe.iloc[start:start + MAX_ROWS_PER_CHUNK].copy()
        chunk_answers = ask_llm_chunk(chunk, questions)
        for bucket, answer in zip(answers_by_question, chunk_answers):
            if answer != "":
                bucket.append(str(answer))

    return ["; ".join(bucket) for bucket in answers_by_question]
def get_class_schema(class_name):
    """Look up a class definition in the Weaviate schema by name.

    Returns the first entry of the schema's "classes" list whose "class"
    field equals ``class_name``, or ``None`` when there is no such entry.
    """
    schema = client.schema.get()
    matches = (entry for entry in schema["classes"] if entry["class"] == class_name)
    return next(matches, None)
st.title("TAPAS Table Question Answering with Weaviate")

# Get existing classes from Weaviate.
# NOTE(review): the only `client` initialisation visible in this file is
# commented out - confirm it is created before this point.
existing_classes = [cls["class"] for cls in client.schema.get()["classes"]]
class_options = existing_classes + ["New Class"]
selected_class = st.selectbox("Select a class or create a new one:", class_options)
if selected_class == "New Class":
    class_name = st.text_input("Enter the new class name:")
    class_description = st.text_input("Enter a description for the class:")
else:
    class_name = selected_class
    class_description = ""  # We can fetch the description from Weaviate if needed

# Upload CSV data
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])

# Display the schema if an existing class is selected
class_schema = None
if selected_class != "New Class":
    st.write(f"Schema for {selected_class}:")
    class_schema = get_class_schema(selected_class)
    if class_schema:
        properties = class_schema.get("properties", [])
        # A class without properties yields a frame without the "name" and
        # "dataType" columns; skip the table instead of raising KeyError.
        if properties:
            schema_df = pd.DataFrame(properties)
            st.table(schema_df[["name", "dataType"]])  # Display only the name and dataType columns

# Raw CSV text of the upload. It stays None until a file is provided, so the
# Submit handler below can test it without hitting an unbound name.
data = None

# Before ingesting data into Weaviate, check if CSV columns match the class schema
if csv_file is not None:
    data = csv_file.read().decode("utf-8")
    dataframe = pd.read_csv(StringIO(data))

    # Log CSV upload information
    log_debug_info(f"CSV uploaded with shape: {dataframe.shape}")

    # Display the uploaded CSV data
    st.write("Uploaded CSV Data:")
    st.write(dataframe)

    # Check if columns match
    # NOTE(review): the original indentation was lost; the `else` is attached
    # to the column comparison here. Confirm whether data for a new class
    # (no schema yet) should be ingested as well.
    if class_schema:  # Ensure class_schema is not None
        schema_columns = [prop["name"] for prop in class_schema.get("properties", [])]
        if set(dataframe.columns) != set(schema_columns):
            st.error("The columns in the uploaded CSV do not match the schema of the selected class. Please check and upload the correct CSV or create a new class.")
        else:
            # Ingest data into Weaviate
            # NOTE(review): the only definition of ingest_data_to_weaviate
            # visible in this file is commented out.
            ingest_data_to_weaviate(dataframe, class_name, class_description)

# Input for questions: one question per line, blank lines dropped
question_text = st.text_area("Enter your questions (one per line)")
questions = [line for line in question_text.split("\n") if line]

if st.button("Submit"):
    if data and questions:
        answers = summarize_map_reduce(data, questions)
        st.write("Answers:")
        for q, a in zip(questions, answers):
            st.write(f"Question: {q}")
            st.write(f"Answer: {a}")
    else:
        st.warning("Please upload a CSV file and enter at least one question.")

# Display debugging information
if st.checkbox("Show Debugging Information"):
    st.write("Debugging Logs:")
    for log in DEBUG_LOGS:
        st.write(log)

# Add Ctrl+Enter functionality for submitting the questions
# NOTE(review): confirm that a <script> tag injected through st.markdown is
# actually executed by the browser; if not, this shortcut never fires.
st.markdown("""
<script>
document.addEventListener("DOMContentLoaded", function(event) {
document.addEventListener("keydown", function(event) {
if (event.ctrlKey && event.key === "Enter") {
document.querySelector(".stButton button").click();
}
});
});
</script>
""", unsafe_allow_html=True)