import streamlit as st
import sparknlp
import pandas as pd
import json

from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
from sparknlp.pretrained import PretrainedPipeline

# Configure the Streamlit page; must run before any other st.* call.
st.set_page_config(layout="wide", initial_sidebar_state="auto")
st.markdown("""
|
|
<style>
|
|
.main-title {
|
|
font-size: 36px;
|
|
color: #4A90E2;
|
|
font-weight: bold;
|
|
text-align: center;
|
|
}
|
|
.section {
|
|
background-color: #f9f9f9;
|
|
padding: 10px;
|
|
border-radius: 10px;
|
|
margin-top: 10px;
|
|
}
|
|
.section p, .section ul {
|
|
color: #666666;
|
|
}
|
|
</style>
|
|
""", unsafe_allow_html=True)
|
|
|
|
@st.cache_resource
def init_spark():
    """Start (or attach to) the Spark NLP session, cached for the app's lifetime."""
    session = sparknlp.start()
    return session
@st.cache_resource
def create_pipeline(model):
    """Build the TAPAS table-QA pipeline.

    NOTE(review): `model` is never referenced in the body — both the WTQ and
    SQA TAPAS stages are always included, and the argument only keys the
    Streamlit resource cache. The caller selects which answer column to show.

    Returns an unfitted pyspark.ml Pipeline producing `answers_wtq` and
    `answers_sqa` columns.
    """
    # Wrap the raw JSON table and the question text as annotator documents.
    multi_assembler = MultiDocumentAssembler() \
        .setInputCols("table_json", "questions") \
        .setOutputCols("document_table", "document_questions")

    # Split the question document into individual sentences/questions.
    question_splitter = SentenceDetector() \
        .setInputCols(["document_questions"]) \
        .setOutputCol("questions")

    # Parse the JSON document into a structured table annotation.
    json_to_table = TableAssembler() \
        .setInputCols(["document_table"]) \
        .setOutputCol("table")

    # TAPAS fine-tuned on WikiTableQuestions (aggregation-style questions).
    wtq_stage = TapasForQuestionAnswering \
        .pretrained("table_qa_tapas_base_finetuned_wtq", "en") \
        .setInputCols(["questions", "table"]) \
        .setOutputCol("answers_wtq")

    # TAPAS fine-tuned on SQA (sequential/cell-selection questions).
    sqa_stage = TapasForQuestionAnswering \
        .pretrained("table_qa_tapas_base_finetuned_sqa", "en") \
        .setInputCols(["questions", "table"]) \
        .setOutputCol("answers_sqa")

    return Pipeline(
        stages=[multi_assembler, question_splitter, json_to_table, wtq_stage, sqa_stage]
    )
def fit_data(pipeline, json_data, question, spark_session=None):
    """Run the TAPAS pipeline over a single (table, question) pair.

    Args:
        pipeline: unfitted pyspark.ml Pipeline from create_pipeline().
        json_data: the table serialized as a JSON string
            ({"header": [...], "rows": [[...], ...]}).
        question: natural-language question text about the table.
        spark_session: optional SparkSession. Defaults to the module-level
            `spark` global (kept for backward compatibility with existing
            callers); passing it explicitly makes the function testable
            without relying on module state.

    Returns:
        List of Rows, each carrying the WTQ and SQA answer-string arrays.
    """
    session = spark_session if spark_session is not None else spark
    spark_df = session.createDataFrame([[json_data, question]]).toDF("table_json", "questions")
    model = pipeline.fit(spark_df)
    result = model.transform(spark_df)
    return result.select("answers_wtq.result", "answers_sqa.result").collect()
# The two TAPAS checkpoints the user can choose between.
_MODEL_CHOICES = [
    "table_qa_tapas_base_finetuned_wtq",
    "table_qa_tapas_base_finetuned_sqa",
]
model = st.sidebar.selectbox(
    "Choose the pretrained model",
    _MODEL_CHOICES,
    help="For more info about the models visit: https://sparknlp.org/models",
)
# Page header text.
title = 'TAPAS for Table-Based Question Answering with Spark NLP'
sub_title = (
    'TAPAS (Table Parsing Supervised via Pre-trained Language Models) is a model that extends '
    'the BERT architecture to handle tabular data. Unlike traditional models that require flattening '
    'tables into text, TAPAS can directly interpret tables, making it a powerful tool for answering '
    'questions that involve tabular data.'
)

# Render title and subtitle with the CSS classes defined in the page styles.
for _html in (
    f'<div class="main-title">{title}</div>',
    f'<div class="section"><p>{sub_title}</p></div>',
):
    st.markdown(_html, unsafe_allow_html=True)
link = """
|
|
<a href="https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/15.1_Table_Question_Answering.ipynb">
|
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" style="zoom: 1.3" alt="Open In Colab"/>
|
|
</a>
|
|
"""
|
|
st.sidebar.markdown('Reference notebook:')
|
|
st.sidebar.markdown(link, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
# Demo table (billionaires: net worth, age, nationality, company, industry).
# Kept as a raw JSON string because the pipeline's MultiDocumentAssembler
# consumes the serialized text, not a DataFrame.
json_data = '''
{
    "header": ["name", "net_worth", "age", "nationality", "company", "industry"],
    "rows": [
        ["Elon Musk", "$200,000,000,000", "52", "American", "Tesla, SpaceX", "Automotive, Aerospace"],
        ["Jeff Bezos", "$150,000,000,000", "60", "American", "Amazon", "E-commerce"],
        ["Bernard Arnault", "$210,000,000,000", "74", "French", "LVMH", "Luxury Goods"],
        ["Bill Gates", "$120,000,000,000", "68", "American", "Microsoft", "Technology"],
        ["Warren Buffett", "$110,000,000,000", "93", "American", "Berkshire Hathaway", "Conglomerate"],
        ["Larry Page", "$100,000,000,000", "51", "American", "Google", "Technology"],
        ["Mark Zuckerberg", "$85,000,000,000", "40", "American", "Meta", "Social Media"],
        ["Mukesh Ambani", "$80,000,000,000", "67", "Indian", "Reliance Industries", "Conglomerate"],
        ["Alice Walton", "$65,000,000,000", "74", "American", "Walmart", "Retail"],
        ["Francoise Bettencourt Meyers", "$70,000,000,000", "70", "French", "L'Oreal", "Cosmetics"],
        ["Amancio Ortega", "$75,000,000,000", "88", "Spanish", "Inditex (Zara)", "Retail"],
        ["Carlos Slim", "$55,000,000,000", "84", "Mexican", "America Movil", "Telecom"]
    ]
}
'''
# Canned example questions shown in the "Question Query" selectbox.
queries = [
    "Who has a higher net worth, Bernard Arnault or Jeff Bezos?",
    "List the top three individuals by net worth.",
    "Who is the richest person in the technology industry?",
    "Which company in the e-commerce industry has the highest net worth?",
    "Who is the oldest billionaire on the list?",
    "Which individual under the age of 60 has the highest net worth?",
    "Who is the wealthiest American, and which company do they own?",
    "Find all French billionaires and list their companies.",
    "How many women are on the list, and what are their total net worths?",
    "Who is the wealthiest non-American on the list?",
    "Find the person who is the youngest and has a net worth over $100 billion.",
    "Who owns companies in more than one industry, and what are those industries?",
    "What is the total net worth of all individuals over 70?",
    "How many billionaires are in the conglomerate industry?",
]
# Materialize the demo JSON as a DataFrame the user can edit in place.
_table_dict = json.loads(json_data)
_preview_df = pd.DataFrame(_table_dict["rows"], columns=_table_dict["header"])
_preview_df.index += 1  # 1-based row labels read more naturally in the UI

st.write("")
st.write("Context DataFrame (Click To Edit)")
edited_df = st.data_editor(_preview_df)
# Re-serialize the (possibly edited) table so the pipeline sees the user's data.
table_json_str = json.dumps(
    {
        "header": edited_df.columns.tolist(),
        "rows": edited_df.values.tolist(),
    }
)

# Question selection: a canned query, overridden by any custom text entered.
selected_text = st.selectbox("Question Query", queries)
custom_input = st.text_input("Try it with your own Question!")
text_to_analyze = custom_input or selected_text
spark = init_spark()
pipeline = create_pipeline(model)

output = fit_data(pipeline, table_json_str, text_to_analyze)

st.markdown("---")
st.subheader("Processed output:")

# Each result Row carries two columns: answers_wtq.result (index 0) and
# answers_sqa.result (index 1). Previously output[0][0] was always shown,
# so the sidebar model selection was silently ignored — pick the column
# matching the selected checkpoint instead.
answer_idx = 0 if model.endswith("_wtq") else 1
st.write("**Answer:**", ', '.join(output[0][answer_idx]))