Spaces:
Running
Running
import base64 | |
import re | |
import json | |
import pandas as pd | |
import gradio as gr | |
import pyterrier as pt | |
pt.init() | |
import pyt_splade | |
from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md | |
# Notebook file name and install commands handed to code2md() by the
# predict_* callbacks below.
COLAB_NAME = 'pyterrier_splade.ipynb'
COLAB_INSTALL = (
    '!pip install -q git+https://github.com/naver/splade\n'
    '!pip install -q git+https://github.com/seanmacavaney/pyt_splade@misc'
)

# One SPLADE factory per aggregation mode, created once at import time and
# shared by every request ('max' first, then 'sum').
factory_max = pyt_splade.SpladeFactory(agg='max')
factory_sum = pyt_splade.SpladeFactory(agg='sum')
# Matches one weighted clause of the rewritten query: "#combine:0=<weight>(<term>)",
# where <term> is either plain text or a "#base64(<payload>)" wrapper.
_WEIGHTED_TERM_RE = re.compile(r'#combine:0=([0-9.]+)\((#base64\([^)]+\)|[^)]+)\)')
# Extracts the payload of a "#base64(<payload>)" wrapper.
_BASE64_TERM_RE = re.compile(r'#base64\(([^)]+)\)')


def _query_term_weights(query):
    """Parse the ``#combine:0=<weight>(<term>)`` clauses of ``query`` into ``{term: weight}``.

    Terms wrapped as ``#base64(<payload>)`` are decoded back to text; a wrapper
    that does not parse is kept verbatim rather than raising.
    """
    weights = {}
    for clause in _WEIGHTED_TERM_RE.finditer(query):
        term = clause.group(2)
        if term.startswith('#base64('):
            wrapped = _BASE64_TERM_RE.match(term)
            if wrapped is not None:
                term = base64.b64decode(wrapped.group(1)).decode()
        weights[term] = float(clause.group(1))
    return weights


def generate_vis(df, mode='Document'):
    """Build an HTML visualisation of per-token weights, one panel per row of ``df``.

    Each panel shows the tokens produced by ``factory_max.tokenizer`` for the
    row's text, shaded by weight, followed by the "expansion" tokens: weighted
    tokens that do not occur in the tokenised text, sorted by descending
    weight and then alphabetically.

    Args:
        df: Rows to visualise. With ``mode='Query'`` each row must provide
            ``query`` (the weighted query string), ``query_0`` (the text that
            gets tokenised) and ``qid``. Otherwise each row must provide
            ``toks`` (a token -> weight mapping), ``text`` and ``docno``.
        mode: ``'Query'`` selects the query layout; any other value uses the
            document layout. The value is also shown as the panel label.

    Returns:
        The panels joined by newlines, or ``''`` when ``df`` is empty.
    """
    import html  # local import so this block does not depend on the file header

    if len(df) == 0:
        return ''

    def render(tokens, weights, max_score):
        # Opacity is the weight relative to max_score. Tokens without a weight
        # and the degenerate max_score <= 0 case render fully transparent
        # instead of dividing by zero.
        def cell(tok):
            alpha = weights.get(tok, 0) / max_score if max_score > 0 else 0
            return (
                f'<kbd style="background-color: rgba(66, 135, 245, {alpha});">'
                f'{html.escape(tok)}</kbd>'
            )
        return '<kbd> </kbd>'.join(cell(tok) for tok in tokens)

    is_query = mode == 'Query'
    tokenize = factory_max.tokenizer.tokenize
    label = html.escape(str(mode))

    shared_max = 0
    if not is_query:
        # Documents share one scale: the largest weight across all rows.
        shared_max = max(
            (max(t.values(), default=0) for t in df['toks']),
            default=0,
        )

    panels = []
    for row in df.itertuples(index=False):
        if is_query:
            weights = _query_term_weights(row.query)
            # Queries are scaled per row.
            max_score = max(weights.values(), default=0)
            source_tokens = tokenize(row.query_0)
            row_id = row.qid
        else:
            weights = row.toks
            max_score = shared_max
            source_tokens = tokenize(row.text)
            row_id = row.docno

        seen = set(source_tokens)
        ranked = sorted(weights.items(), key=lambda item: (-item[1], item[0]))
        expansion_tokens = [tok for tok, _ in ranked if tok not in seen]

        # Ids and token text can originate from user input, so they are
        # HTML-escaped before being placed into the markup.
        panels.append(f'''
<div style="font-size: 1.2em;">{label}: <strong>{html.escape(str(row_id))}</strong></div>
<div style="margin: 4px 0 16px; padding: 4px; border: 1px solid black;">
<div>
{render(source_tokens, weights, max_score)}
</div>
<div><strong>Expansion Tokens:</strong> {render(expansion_tokens, weights, max_score)}</div>
</div>
''')
    return '\n'.join(panels)
def predict_query(input, agg):
    """Run the SPLADE query pipeline for ``agg`` over ``input`` for the demo.

    Args:
        input: The query frame passed to the pipeline and to ``df2code``.
        agg: Aggregation mode, ``'max'`` or ``'sum'``; any other value raises
            ``KeyError``.

    Returns:
        A 3-tuple: the pipeline output, the result of ``code2md`` for the
        equivalent code snippet, and the HTML from ``generate_vis`` in query mode.
    """
    # Snippet shown to the user that reproduces this call.
    snippet = f'''import pandas as pd
import pyterrier as pt ; pt.init()
import pyt_splade
factory = pyt_splade.SpladeFactory(agg={agg!r})
query_pipeline = factory.query()
query_pipeline({df2code(input)})
'''
    factories = {'max': factory_max, 'sum': factory_sum}
    query_pipeline = factories[agg].query()
    encoded = query_pipeline(input)
    html_vis = generate_vis(encoded, mode='Query')
    markdown = code2md(snippet, COLAB_INSTALL, COLAB_NAME)
    return (encoded, markdown, html_vis)
def predict_doc(input, agg):
    """Run the SPLADE indexing pipeline for ``agg`` over ``input`` for the demo.

    Args:
        input: The document frame passed to the pipeline and to ``df2code``.
        agg: Aggregation mode, ``'max'`` or ``'sum'``; any other value raises
            ``KeyError``.

    Returns:
        A 3-tuple: the pipeline output with its ``toks`` column converted to
        JSON strings (weights rounded to 4 places), the result of ``code2md``
        for the equivalent code snippet, and the HTML from ``generate_vis`` in
        document mode.
    """
    # Snippet shown to the user that reproduces this call.
    snippet = f'''import pandas as pd
import pyterrier as pt ; pt.init()
import pyt_splade
factory = pyt_splade.SpladeFactory(agg={agg!r})
doc_pipeline = factory.indexing()
doc_pipeline({df2code(input)})
'''
    factories = {'max': factory_max, 'sum': factory_sum}
    doc_pipeline = factories[agg].indexing()
    encoded = doc_pipeline(input)

    # The visualisation needs the raw token->weight dicts, so build it before
    # the column is turned into JSON text for display.
    html_vis = generate_vis(encoded, mode='Document')

    serialised = []
    for weights in encoded['toks']:
        rounded = {tok: round(score, 4) for tok, score in weights.items()}
        serialised.append(json.dumps(rounded))
    encoded['toks'] = serialised

    markdown = code2md(snippet, COLAB_INSTALL, COLAB_NAME)
    return (encoded, markdown, html_vis)
# Page components, constructed in the same order in which they are passed to
# interface() below.
readme_section = MarkdownFile('README.md')

query_section = MarkdownFile('query.md')
query_demo = Demo(
    predict_query,
    pd.DataFrame([
        {'qid': '1112389', 'query': 'what is the county for grand rapids, mn'},
    ]),
    [
        gr.Dropdown(choices=['max', 'sum'], value='max', label='Aggregation'),
    ],
    scale=2/3
)

doc_section = MarkdownFile('doc.md')
doc_demo = Demo(
    predict_doc,
    pd.DataFrame([
        {'docno': '0', 'text': 'The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.'},
    ]),
    [
        gr.Dropdown(choices=['max', 'sum'], value='max', label='Aggregation'),
    ],
    scale=2/3
)

wrapup_section = MarkdownFile('wrapup.md')

# Assemble the page and start the app without a public share link.
interface(
    readme_section,
    query_section,
    query_demo,
    doc_section,
    doc_demo,
    wrapup_section,
).launch(share=False)