jskinner215 committed
Commit d304ae4 · 1 Parent(s): 862e59b

Update app.py

Files changed (1)
  1. app.py +15 -168
app.py CHANGED
@@ -6,48 +6,19 @@ from weaviate_utils import *
 from tapas_utils import *
 from ui_utils import *
 
-# ...
-selected_class = ui_utils.display_class_dropdown(client)
-ui_utils.handle_new_class_selection(selected_class)
-ui_utils.csv_upload_and_ingestion(selected_class)
-ui_utils.display_query_input()
-# ...
-
 # Initialize Weaviate client
 client = initialize_weaviate_client()
 
 # Initialize TAPAS
 tokenizer, model = initialize_tapas()
 
-# UI components
-display_initial_buttons()
-selected_class = display_class_dropdown(client)
-handle_new_class_selection()
-csv_upload_and_ingestion()
-display_query_input()
-
-# Initialize session state attributes
-if "debug" not in st.session_state:
-    st.session_state.debug = False
-
-st_callback = StreamlitCallbackHandler(st.container())
+# Global list to store debugging information
+DEBUG_LOGS = []
 
 class StreamlitCallbackHandler(logging.Handler):
     def emit(self, record):
         log_entry = self.format(record)
         st.write(log_entry)
-
-# Initialize TAPAS model and tokenizer
-#tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
-#model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
-
-# Initialize Weaviate client for the embedded instance
-#client = weaviate.Client(
-#    embedded_options=EmbeddedOptions()
-#)
-
-# Global list to store debugging information
-DEBUG_LOGS = []
 
 def log_debug_info(message):
     if st.session_state.debug:
@@ -61,140 +32,16 @@ def log_debug_info(message):
 
     logger.debug(message)
 
+# UI components
+ui_utils.display_initial_buttons()
+selected_class = ui_utils.display_class_dropdown(client)
+ui_utils.handle_new_class_selection(client, selected_class)
+ui_utils.csv_upload_and_ingestion(client, selected_class)
+ui_utils.display_query_input()
 
-# Function to check if a class already exists in Weaviate
-#def class_exists(class_name):
-#    try:
-#        client.schema.get_class(class_name)
-#        return True
-#    except:
-#        return False
-
-#def map_dtype_to_weaviate(dtype):
-##    """
-#    Map pandas data types to Weaviate data types.
-#    """
-#    if "int" in str(dtype):
-#        return "int"
-#    elif "float" in str(dtype):
-#        return "number"
-#    elif "bool" in str(dtype):
-#        return "boolean"
-#    else:
-#        return "string"
-
-# def ingest_data_to_weaviate(dataframe, class_name, class_description):
-#     # Create class schema
-#     class_schema = {
-#         "class": class_name,
-#         "description": class_description,
-#         "properties": []  # Start with an empty properties list
-#     }
-#
-#     # Try to create the class without properties first
-#     try:
-#         client.schema.create({"classes": [class_schema]})
-#     except weaviate.exceptions.SchemaValidationException:
-#         # Class might already exist, so we can continue
-#         pass#
-
-#     # Now, let's add properties to the class
-#     for column_name, data_type in zip(dataframe.columns, dataframe.dtypes):
-#         property_schema = {
-#             "name": column_name,
-#             "description": f"Property for {column_name}",
-#             "dataType": [map_dtype_to_weaviate(data_type)]
-#         }
-#         try:
-#             client.schema.property.create(class_name, property_schema)
-#         except weaviate.exceptions.SchemaValidationException:
-#             # Property might already exist, so we can continue
-#             pass
-#
-#     # Ingest data
-#     for index, row in dataframe.iterrows():
-#         obj = {
-#             "class": class_name,
-#             "id": str(index),
-#             "properties": row.to_dict()
-#         }
-#         client.data_object.create(obj)
-
-    # Log data ingestion
-    # log_debug_info(f"Data ingested into Weaviate for class: {class_name}")
-
-def query_weaviate(question):
-    # This is a basic example; adapt the query based on the question
-    results = client.query.get(class_name).with_near_text(question).do()
-    return results
-
-#def ask_llm_chunk(chunk, questions):
-#    chunk = chunk.astype(str)
-#    try:
-#        inputs = tokenizer(table=chunk, queries=questions, padding="max_length", truncation=True, return_tensors="pt")
-#    except Exception as e:
-#        log_debug_info(f"Tokenization error: {e}")
-#        st.write(f"An error occurred: {e}")
-#        return ["Error occurred while tokenizing"] * len(questions)
-#
-##    if inputs["input_ids"].shape[1] > 512:
-#        log_debug_info("Token limit exceeded for chunk")
-#        st.warning("Token limit exceeded for chunk")
-#        return ["Token limit exceeded for chunk"] * len(questions)#
-#
-#    outputs = model(**inputs)
-#    predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
-#        inputs,
-#        outputs.logits.detach(),
-#        outputs.logits_aggregation.detach()
-#    )
-#
-#    answers = []
-#    for coordinates in predicted_answer_coordinates:
-#        if len(coordinates) == 1:
-#            row, col = coordinates[0]
-#            try:
-#                value = chunk.iloc[row, col]
-#                log_debug_info(f"Accessed value for row {row}, col {col}: {value}")
-#                answers.append(value)
-#            except Exception as e:
-#                log_debug_info(f"Error accessing value for row {row}, col {col}: {e}")
-#                st.write(f"An error occurred: {e}")
-#        else:
-#            cell_values = []
-#            for coordinate in coordinates:
-#                row, col = coordinate
-#                try:
-#                    value = chunk.iloc[row, col]
-#                    cell_values.append(value)
-#                except Exception as e:
-#                    log_debug_info(f"Error accessing value for row {row}, col {col}: {e}")
-#                    st.write(f"An error occurred: {e}")
-#            answers.append(", ".join(map(str, cell_values)))
-#
-#    return answers
-
-# MAX_ROWS_PER_CHUNK = 200
-
-# def summarize_map_reduce(data, questions):
-#     dataframe = pd.read_csv(StringIO(data))
-#     num_chunks = len(dataframe) // MAX_ROWS_PER_CHUNK + 1
-#     dataframe_chunks = [deepcopy(chunk) for chunk in np.array_split(dataframe, num_chunks)]
-#     all_answers = []
-#     for chunk in dataframe_chunks:
-#         chunk_answers = ask_llm_chunk(chunk, questions)
-#         all_answers.extend(chunk_answers)
-#     return all_answers
-
-def get_class_schema(class_name):
-    """
-    Get the schema for a specific class.
-    """
-    all_classes = client.schema.get()["classes"]
-    for cls in all_classes:
-        if cls["class"] == class_name:
-            return cls
-    return None
+# Initialize session state attributes
+if "debug" not in st.session_state:
+    st.session_state.debug = False
 
 st.title("TAPAS Table Question Answering with Weaviate")
 
@@ -217,7 +64,7 @@ csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
 class_schema = None # Initialize class_schema to None
 if selected_class != "New Class":
     st.write(f"Schema for {selected_class}:")
-    class_schema = get_class_schema(selected_class)
+    class_schema = get_class_schema(client, selected_class)
     if class_schema:
         properties = class_schema["properties"]
         schema_df = pd.DataFrame(properties)
@@ -242,7 +89,7 @@ if csv_file is not None:
         st.error("The columns in the uploaded CSV do not match the schema of the selected class. Please check and upload the correct CSV or create a new class.")
     else:
         # Ingest data into Weaviate
-        ingest_data_to_weaviate(dataframe, class_name, class_description)
+        ingest_data_to_weaviate(client, dataframe, class_name, class_description)
 
     # Input for questions
     questions = st.text_area("Enter your questions (one per line)")
@@ -251,7 +98,7 @@ if csv_file is not None:
 
     if st.button("Submit"):
         if data and questions:
-            answers = summarize_map_reduce(data, questions)
+            answers = summarize_map_reduce(tokenizer, model, data, questions)
            st.write("Answers:")
            for q, a in zip(questions, answers):
                st.write(f"Question: {q}")
@@ -274,4 +121,4 @@ st.markdown("""
    });
    });
    </script>
-    """, unsafe_allow_html=True)
+    """, unsafe_allow_html=True)