import streamlit as st
import sparknlp
import pandas as pd
import json
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
from sparknlp.pretrained import PretrainedPipeline
# Page configuration
st.set_page_config(
layout="wide",
initial_sidebar_state="auto"
)
# CSS for styling
st.markdown("""
<style>
.main-title {
    font-size: 36px;
    color: #4A90E2;
    font-weight: bold;
    text-align: center;
}
.section {
    background-color: #f9f9f9;
    padding: 10px;
    border-radius: 10px;
    margin-top: 10px;
}
.section p, .section ul {
    color: #666666;
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def init_spark():
    return sparknlp.start()
@st.cache_resource
def create_pipeline(model):
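    # Note: the sidebar `model` choice only keys Streamlit's cache; both TAPAS
    # variants are loaded below so answers from each can be shown side by side.
    # Wrap the raw table JSON and the question text as two document columns.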
    document_assembler = MultiDocumentAssembler() \
        .setInputCols("table_json", "questions") \
        .setOutputCols("document_table", "document_questions")
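    # Split the questions document into individual sentences, one per question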
    sentence_detector = SentenceDetector() \
        .setInputCols(["document_questions"]) \
        .setOutputCol("questions")
    table_assembler = TableAssembler() \
        .setInputCols(["document_table"]) \
        .setOutputCol("table")
    tapas_wtq = TapasForQuestionAnswering \
        .pretrained("table_qa_tapas_base_finetuned_wtq", "en") \
        .setInputCols(["questions", "table"]) \
        .setOutputCol("answers_wtq")
    tapas_sqa = TapasForQuestionAnswering \
        .pretrained("table_qa_tapas_base_finetuned_sqa", "en") \
        .setInputCols(["questions", "table"]) \
        .setOutputCol("answers_sqa")
    pipeline = Pipeline(stages=[document_assembler, sentence_detector, table_assembler, tapas_wtq, tapas_sqa])
    return pipeline
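
# Build a one-row Spark DataFrame holding the table JSON and the question, then
# fit and run the pipeline; uses the global `spark` session initialized below.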
def fit_data(pipeline, json_data, question):
    spark_df = spark.createDataFrame([[json_data, question]]).toDF("table_json", "questions")
    pipeline_model = pipeline.fit(spark_df)
    result = pipeline_model.transform(spark_df)
    return result.select("answers_wtq.result", "answers_sqa.result").collect()
# Sidebar content
model = st.sidebar.selectbox(
"Choose the pretrained model",
["table_qa_tapas_base_finetuned_wtq", "table_qa_tapas_base_finetuned_sqa"],
help="For more info about the models visit: https://sparknlp.org/models"
)
# Set up the page layout
title = 'TAPAS for Table-Based Question Answering with Spark NLP'
sub_title = ("""
TAPAS (Table Parsing Supervised via Pre-trained Language Models) enhances the BERT architecture to effectively process tabular data, allowing it to answer complex questions about tables without needing to convert them into text.<br>
<br>
<strong>table_qa_tapas_base_finetuned_wtq:</strong> This model excels at answering questions that require aggregating data across the entire table, such as calculating sums or averages.<br>
<strong>table_qa_tapas_base_finetuned_sqa:</strong> This model is designed for sequential question-answering tasks where the answer to each question may depend on the context provided by previous answers.
""")
st.markdown(f'<div class="main-title">{title}</div>', unsafe_allow_html=True)
st.markdown(f'<div class="section"><p>{sub_title}</p></div>', unsafe_allow_html=True)
# Reference notebook link in sidebar
link = """
<a href="https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/streamlit_notebooks/NER_HINDI_ENGLISH.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" style="zoom: 1.3" alt="Open In Colab"/>
</a>
"""
st.sidebar.markdown('Reference notebook:')
st.sidebar.markdown(link, unsafe_allow_html=True)
# JSON data defining the example table
json_data = '''
{
    "header": ["name", "net_worth", "age", "nationality", "company", "industry"],
    "rows": [
        ["Elon Musk", "$200,000,000,000", "52", "American", "Tesla, SpaceX", "Automotive, Aerospace"],
        ["Jeff Bezos", "$150,000,000,000", "60", "American", "Amazon", "E-commerce"],
        ["Bernard Arnault", "$210,000,000,000", "74", "French", "LVMH", "Luxury Goods"],
        ["Bill Gates", "$120,000,000,000", "68", "American", "Microsoft", "Technology"],
        ["Warren Buffett", "$110,000,000,000", "93", "American", "Berkshire Hathaway", "Conglomerate"],
        ["Larry Page", "$100,000,000,000", "51", "American", "Google", "Technology"],
        ["Mark Zuckerberg", "$85,000,000,000", "40", "American", "Meta", "Social Media"],
        ["Mukesh Ambani", "$80,000,000,000", "67", "Indian", "Reliance Industries", "Conglomerate"],
        ["Alice Walton", "$65,000,000,000", "74", "American", "Walmart", "Retail"],
        ["Francoise Bettencourt Meyers", "$70,000,000,000", "70", "French", "L'Oreal", "Cosmetics"],
        ["Amancio Ortega", "$75,000,000,000", "88", "Spanish", "Inditex (Zara)", "Retail"],
        ["Carlos Slim", "$55,000,000,000", "84", "Mexican", "America Movil", "Telecom"]
    ]
}
'''
# Define queries for selection
queries = [
"Who has a higher net worth, Bernard Arnault or Jeff Bezos?",
"List the top three individuals by net worth.",
"Who is the richest person in the technology industry?",
"Which company in the e-commerce industry has the highest net worth?",
"Who is the oldest billionaire on the list?",
"Which individual under the age of 60 has the highest net worth?",
"Who is the wealthiest American, and which company do they own?",
"Find all French billionaires and list their companies.",
"How many women are on the list, and what are their total net worths?",
"Who is the wealthiest non-American on the list?",
"Find the person who is the youngest and has a net worth over $100 billion.",
"Who owns companies in more than one industry, and what are those industries?",
"What is the total net worth of all individuals over 70?",
"How many billionaires are in the conglomerate industry?"
]
# Load the JSON data into a DataFrame and display it
table_data = json.loads(json_data)
df_table = pd.DataFrame(table_data["rows"], columns=table_data["header"])
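# Show row numbers starting from 1 in the displayed table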
df_table.index += 1
st.write("")
st.write("Context DataFrame (Click To Edit)")
edited_df = st.data_editor(df_table)
# Convert edited DataFrame back to JSON format
table_json_data = {
"header": edited_df.columns.tolist(),
"rows": edited_df.values.tolist()
}
table_json_str = json.dumps(table_json_data)
# User input for questions
selected_text = st.selectbox("Question Query", queries)
custom_input = st.text_input("Try it with your own Question!")
text_to_analyze = custom_input if custom_input else selected_text
# Initialize Spark and create the pipeline
spark = init_spark()
pipeline = create_pipeline(model)
# Run the pipeline with the selected query and the converted table data
output = fit_data(pipeline, table_json_str, text_to_analyze)
# Display the output
st.markdown("---")
st.subheader("Processed Output")
# Check if output is available
if output:
    # Extract and display results. Each result column is a list of answer strings;
    # fall back to a one-element list so join() never iterates over a bare string.
    results_wtq = output[0][0] if output[0][0] else ["No results found."]
    results_sqa = output[0][1] if output[0][1] else ["No results found."]
    st.markdown(f"**Answers from WTQ model:** {', '.join(results_wtq)}")
    st.markdown(f"**Answers from SQA model:** {', '.join(results_sqa)}")