Sean MacAvaney commited on
Commit
1051b11
1 Parent(s): 68730a3

fixups: base64 query components and quoting agg in code sample

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import re
2
  import json
3
  import pandas as pd
@@ -23,7 +24,13 @@ def generate_vis(df, mode='Document'):
23
  max_score = max(max(t.values()) for t in df['toks'])
24
  for row in df.itertuples(index=False):
25
  if mode == 'Query':
26
- tok_scores = {m.group(2): float(m.group(1)) for m in re.finditer(r'combine:0=([0-9.]+)\(([^)]+)\)', row.query)}
 
 
 
 
 
 
27
  max_score = max(tok_scores.values())
28
  orig_tokens = factory_max.tokenizer.tokenize(row.query_0)
29
  id = row.qid
@@ -51,7 +58,7 @@ def predict_query(input, agg):
51
  import pyterrier as pt ; pt.init()
52
  import pyt_splade
53
 
54
- factory = pyt_splade.SpladeFactory(agg={agg})
55
 
56
  query_pipeline = factory.query()
57
 
@@ -70,7 +77,7 @@ def predict_doc(input, agg):
70
  import pyterrier as pt ; pt.init()
71
  import pyt_splade
72
 
73
- factory = pyt_splade.SpladeFactory(agg={agg})
74
 
75
  doc_pipeline = factory.indexing()
76
 
 
1
+ import base64
2
  import re
3
  import json
4
  import pandas as pd
 
24
  max_score = max(max(t.values()) for t in df['toks'])
25
  for row in df.itertuples(index=False):
26
  if mode == 'Query':
27
+ tok_scores = {m.group(2): float(m.group(1)) for m in re.finditer(r'#combine:0=([0-9.]+)\((#base64\([^)]+\)|[^)]+)\)', row.query)}
28
+ for key, value in list(tok_scores.items()):
29
+ if key.startswith('#base64('):
30
+ b64 = re.search('#base64\(([^)]+)\)', key).group(1)
31
+ del tok_scores[key]
32
+ key = base64.b64decode(b64).decode()
33
+ tok_scores[key] = value
34
  max_score = max(tok_scores.values())
35
  orig_tokens = factory_max.tokenizer.tokenize(row.query_0)
36
  id = row.qid
 
58
  import pyterrier as pt ; pt.init()
59
  import pyt_splade
60
 
61
+ factory = pyt_splade.SpladeFactory(agg={repr(agg)})
62
 
63
  query_pipeline = factory.query()
64
 
 
77
  import pyterrier as pt ; pt.init()
78
  import pyt_splade
79
 
80
+ factory = pyt_splade.SpladeFactory(agg={repr(agg)})
81
 
82
  doc_pipeline = factory.indexing()
83