Update Demo.py
Browse files
Demo.py
CHANGED
@@ -53,22 +53,24 @@ def create_pipeline(model):
|
|
53 |
.setInputCols(["document_table"])\
|
54 |
.setOutputCol("table")
|
55 |
|
56 |
-
|
57 |
-
.pretrained(
|
58 |
.setInputCols(["questions", "table"])\
|
59 |
.setOutputCol("answers_wtq")
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
-
pipeline = Pipeline(stages=[document_assembler, sentence_detector, table_assembler,
|
62 |
return pipeline
|
63 |
|
64 |
def fit_data(pipeline, json_data, question):
|
65 |
spark_df = spark.createDataFrame([[json_data, question]]).toDF("table_json", "questions")
|
66 |
model = pipeline.fit(spark_df)
|
67 |
-
|
68 |
-
|
69 |
-
result = lightPipelineModel.fullAnnotate([table, data])
|
70 |
-
st.write(result)
|
71 |
-
return result
|
72 |
|
73 |
# Sidebar content
|
74 |
model = st.sidebar.selectbox(
|
@@ -170,6 +172,7 @@ output = fit_data(pipeline, table_json_str, text_to_analyze)
|
|
170 |
st.markdown("---")
|
171 |
st.subheader("Processed Output")
|
172 |
|
|
|
173 |
# # Check if output is available
|
174 |
# if output:
|
175 |
# results_wtq = output[0][0] if output[0][0] else "No results found."
|
|
|
53 |
.setInputCols(["document_table"])\
|
54 |
.setOutputCol("table")
|
55 |
|
56 |
+
tapas_wtq = TapasForQuestionAnswering\
|
57 |
+
.pretrained("table_qa_tapas_base_finetuned_wtq", "en")\
|
58 |
.setInputCols(["questions", "table"])\
|
59 |
.setOutputCol("answers_wtq")
|
60 |
+
|
61 |
+
tapas_sqa = TapasForQuestionAnswering\
|
62 |
+
.pretrained("table_qa_tapas_base_finetuned_sqa", "en")\
|
63 |
+
.setInputCols(["questions", "table"])\
|
64 |
+
.setOutputCol("answers_sqa")
|
65 |
|
66 |
+
pipeline = Pipeline(stages=[document_assembler, sentence_detector, table_assembler, tapas_wtq, tapas_sqa])
|
67 |
return pipeline
|
68 |
|
69 |
def fit_data(pipeline, json_data, question):
|
70 |
spark_df = spark.createDataFrame([[json_data, question]]).toDF("table_json", "questions")
|
71 |
model = pipeline.fit(spark_df)
|
72 |
+
result = model.transform(spark_df)
|
73 |
+
return str(result.select("answers_wtq.result", "answers_sqa.result"))
|
|
|
|
|
|
|
74 |
|
75 |
# Sidebar content
|
76 |
model = st.sidebar.selectbox(
|
|
|
172 |
st.markdown("---")
|
173 |
st.subheader("Processed Output")
|
174 |
|
175 |
+
output
|
176 |
# # Check if output is available
|
177 |
# if output:
|
178 |
# results_wtq = output[0][0] if output[0][0] else "No results found."
|