abdullahmubeen10 commited on
Commit
00035e6
·
verified ·
1 Parent(s): 0c3388a

Update Demo.py

Browse files
Files changed (1) hide show
  1. Demo.py +11 -8
Demo.py CHANGED
@@ -53,22 +53,24 @@ def create_pipeline(model):
53
  .setInputCols(["document_table"])\
54
  .setOutputCol("table")
55
 
56
- tapas = TapasForQuestionAnswering\
57
- .pretrained(model, "en")\
58
  .setInputCols(["questions", "table"])\
59
  .setOutputCol("answers_wtq")
 
 
 
 
 
60
 
61
- pipeline = Pipeline(stages=[document_assembler, sentence_detector, table_assembler, tapas])
62
  return pipeline
63
 
64
  def fit_data(pipeline, json_data, question):
65
  spark_df = spark.createDataFrame([[json_data, question]]).toDF("table_json", "questions")
66
  model = pipeline.fit(spark_df)
67
- lightPipelineModel = LightPipeline(model)
68
- table, data = f""" table_json:{json_data} """ , """ questions:{question} """
69
- result = lightPipelineModel.fullAnnotate([table, data])
70
- st.write(result)
71
- return result
72
 
73
  # Sidebar content
74
  model = st.sidebar.selectbox(
@@ -170,6 +172,7 @@ output = fit_data(pipeline, table_json_str, text_to_analyze)
170
  st.markdown("---")
171
  st.subheader("Processed Output")
172
 
 
173
  # # Check if output is available
174
  # if output:
175
  # results_wtq = output[0][0] if output[0][0] else "No results found."
 
53
  .setInputCols(["document_table"])\
54
  .setOutputCol("table")
55
 
56
+ tapas_wtq = TapasForQuestionAnswering\
57
+ .pretrained("table_qa_tapas_base_finetuned_wtq", "en")\
58
  .setInputCols(["questions", "table"])\
59
  .setOutputCol("answers_wtq")
60
+
61
+ tapas_sqa = TapasForQuestionAnswering\
62
+ .pretrained("table_qa_tapas_base_finetuned_sqa", "en")\
63
+ .setInputCols(["questions", "table"])\
64
+ .setOutputCol("answers_sqa")
65
 
66
+ pipeline = Pipeline(stages=[document_assembler, sentence_detector, table_assembler, tapas_wtq, tapas_sqa])
67
  return pipeline
68
 
69
  def fit_data(pipeline, json_data, question):
70
  spark_df = spark.createDataFrame([[json_data, question]]).toDF("table_json", "questions")
71
  model = pipeline.fit(spark_df)
72
+ result = model.transform(spark_df)
73
+ return str(result.select("answers_wtq.result", "answers_sqa.result"))
 
 
 
74
 
75
  # Sidebar content
76
  model = st.sidebar.selectbox(
 
172
  st.markdown("---")
173
  st.subheader("Processed Output")
174
 
175
+ output
176
  # # Check if output is available
177
  # if output:
178
  # results_wtq = output[0][0] if output[0][0] else "No results found."