import base64
import html
import json
import re
import pandas as pd
import gradio as gr
import pyterrier as pt
pt.init()
import pyt_splade
from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md, EX_Q, EX_D
# One SPLADE factory per aggregation mode, constructed once at import time and
# reused by predict_query / predict_doc (selected by the 'Aggregation' dropdown).
factory_max = pyt_splade.SpladeFactory(agg='max')
factory_sum = pyt_splade.SpladeFactory(agg='sum')

# Both constants are passed to code2md alongside each generated code snippet.
# NOTE(review): presumably the notebook file name and the install cell shown for
# Colab -- confirm against pyterrier_gradio.code2md.
COLAB_NAME = 'pyterrier_splade.ipynb'
COLAB_INSTALL = '''
!pip install -q git+https://github.com/naver/splade
!pip install -q git+https://github.com/seanmacavaney/pyt_splade@misc
'''.strip()

def _query_tok_scores(query):
  """Parse ``{token: weight}`` out of a weighted query string.

  Recognises terms of the form ``#combine:0=<weight>(<token>)``, where ``<token>``
  is either literal text or ``#base64(<payload>)``. Wrapped payloads are
  base64-decoded to recover the token text.
  """
  plain = {}
  decoded = {}
  for m in re.finditer(r'#combine:0=([0-9.]+)\((#base64\([^)]+\)|[^)]+)\)', query):
    key, weight = m.group(2), float(m.group(1))
    wrapped = re.search(r'#base64\(([^)]+)\)', key) if key.startswith('#base64(') else None
    if wrapped is None:
      # Literal token (or a malformed '#base64(' wrapper, kept verbatim rather than crashing).
      plain[key] = weight
    else:
      decoded[base64.b64decode(wrapped.group(1)).decode()] = weight
  # Decoded entries are applied last so they win over a literal token with the same text.
  plain.update(decoded)
  return plain


def _toks2span(toks, tok_scores, max_score):
  """Render tokens as ``<kbd>`` elements whose background opacity is weight / max_score."""
  def opacity(tok):
    # Guard against max_score == 0 (no weights at all, or all weights zero).
    return tok_scores.get(tok, 0) / max_score if max_score else 0.0
  return '<kbd> </kbd>'.join(
    f'<kbd style="background-color: rgba(66, 135, 245, {opacity(t)});">{html.escape(t)}</kbd>'
    for t in toks
  )


def generate_vis(df, mode='Document'):
  """Build an HTML visualisation of token weights for each row of ``df``.

  Args:
    df: Result frame. In ``'Query'`` mode each row is read for ``qid``, ``query``
      (the weighted query string) and ``query_0`` (the text that is tokenised for
      display). In any other mode each row is read for ``docno``, ``text`` and
      ``toks`` (a ``{token: weight}`` mapping).
    mode: ``'Query'`` or ``'Document'``; also used as the label in the output.

  Returns:
    One HTML fragment per row joined by newlines, or ``''`` when ``df`` is empty.
    Each fragment shows the tokens of the original text shaded by weight, followed
    by the weighted tokens that are absent from the original text ("expansion
    tokens"), ordered by descending weight and then alphabetically.

  Token text and row ids are HTML-escaped because they derive from user input.
  """
  if len(df) == 0:
    return ''
  is_query = mode == 'Query'
  result = []
  if not is_query:
    # Documents share a single scale: the largest weight across all rows.
    max_score = max((max(t.values(), default=0) for t in df['toks']), default=0)
  for row in df.itertuples(index=False):
    if is_query:
      tok_scores = _query_tok_scores(row.query)
      # Queries are scaled per row; default=0 covers a query with no weighted terms.
      max_score = max(tok_scores.values(), default=0)
      orig_tokens = factory_max.tokenizer.tokenize(row.query_0)
      row_id = row.qid
    else:
      tok_scores = row.toks
      orig_tokens = factory_max.tokenizer.tokenize(row.text)
      row_id = row.docno
    orig_tokens_set = set(orig_tokens)
    exp_tokens = [t for t, v in sorted(tok_scores.items(), key=lambda x: (-x[1], x[0])) if t not in orig_tokens_set]
    result.append(f'''
<div style="font-size: 1.2em;">{mode}: <strong>{html.escape(str(row_id))}</strong></div>
<div style="margin: 4px 0 16px; padding: 4px; border: 1px solid black;">
<div>
{_toks2span(orig_tokens, tok_scores, max_score)}
</div>
<div><strong>Expansion Tokens:</strong> {_toks2span(exp_tokens, tok_scores, max_score)}</div>
</div>
''')
  return '\n'.join(result)

def predict_query(input, agg):
  """Run the SPLADE query encoder on ``input`` for the chosen aggregation.

  Returns a 3-tuple: the pipeline's result frame, the markdown produced by
  ``code2md`` for an equivalent stand-alone snippet, and the HTML visualisation
  of the query token weights.
  """
  # Snippet shown to the user; mirrors what is executed below.
  snippet = f'''import pandas as pd
import pyterrier as pt ; pt.init()
import pyt_splade

splade = pyt_splade.SpladeFactory(agg={agg!r})

query_pipeline = splade.query()

query_pipeline({df2code(input)})
'''
  factories = {'max': factory_max, 'sum': factory_sum}
  query_pipeline = factories[agg].query()
  encoded = query_pipeline(input)
  html_vis = generate_vis(encoded, mode='Query')
  markdown = code2md(snippet, COLAB_INSTALL, COLAB_NAME)
  return (encoded, markdown, html_vis)

def predict_doc(input, agg):
  """Run the SPLADE document encoder on ``input`` for the chosen aggregation.

  Returns a 3-tuple: the pipeline's result frame (with ``toks`` replaced by JSON
  strings of the weights rounded to 4 places), the markdown produced by
  ``code2md`` for an equivalent stand-alone snippet, and the HTML visualisation
  of the document token weights.
  """
  # Snippet shown to the user; mirrors what is executed below.
  snippet = f'''import pandas as pd
import pyterrier as pt ; pt.init()
import pyt_splade

splade = pyt_splade.SpladeFactory(agg={agg!r})

doc_pipeline = splade.indexing()

doc_pipeline({df2code(input)})
'''
  factories = {'max': factory_max, 'sum': factory_sum}
  doc_pipeline = factories[agg].indexing()
  encoded = doc_pipeline(input)
  # The visualisation needs the raw weight mappings, so render it before
  # 'toks' is overwritten with JSON text for display.
  html_vis = generate_vis(encoded, mode='Document')
  serialised = []
  for weights in encoded['toks']:
    rounded = {tok: round(weight, 4) for tok, weight in weights.items()}
    serialised.append(json.dumps(rounded))
  encoded['toks'] = serialised
  markdown = code2md(snippet, COLAB_INSTALL, COLAB_NAME)
  return (encoded, markdown, html_vis)

# Assemble the page from markdown files and the two demos, in the order listed,
# then start the app. This runs at import time (there is no __main__ guard).
interface(
  MarkdownFile('README.md'),
  MarkdownFile('query.md'),
  # Query demo: predict_query with the EX_Q examples and an aggregation selector.
  Demo(
    predict_query,
    EX_Q,
    [
      gr.Dropdown(choices=['max', 'sum'], value='max', label='Aggregation'),
    ],
    scale=2/3
  ),
  MarkdownFile('doc.md'),
  # Document demo: predict_doc with the EX_D examples and an aggregation selector.
  Demo(
    predict_doc,
    EX_D,
    [
      gr.Dropdown(choices=['max', 'sum'], value='max', label='Aggregation'),
    ],
    scale=2/3
  ),
  MarkdownFile('wrapup.md'),
).launch(share=False)  # share=False: do not request a public share link