import streamlit as st
import pandas as pd
import numpy as np
import json
import time
import requests
import os
import glob
import re
#import smart_open
import plotly.express as px
import random
#import difflib
import pdb

from sentence_transformers import SentenceTransformer, models, util

enable_summary_button = True
dump_pos_data_for_reporting = True

bucket_name = "paper_n1"

prefix_lst = [
    "pgj_d_4096",
    "pgj_d_2048",
    "pgj_d_1024_v2",
    "pgj_d_1024_layer_14",
    "pgj_d_1024_layer_7",
    "pgj_d_1024_layer_2",
    "pgj_d_1024_layer_1",
]
# "my_gptj_6b_tpu_size_8",

model_names = {
    prefix_lst[0]: 'PatentGPT-J-6B',
    prefix_lst[1]: 'PatentGPT-J-1.6B',
    prefix_lst[2]: 'PatentGPT-J-456M',
    prefix_lst[3]: 'PatentGPT-J-279M',
    prefix_lst[4]: 'PatentGPT-J-191M',
    prefix_lst[5]: 'PatentGPT-J-128M',
    prefix_lst[6]: 'PatentGPT-J-115M',
}
# prefix_lst[7]: 'GPT-J-6B'

# experiment 3
# folder = os.path.join('experiments', 'non_patent')
# id_to_scroll = 1  # which of the above to scroll through
# first_claim_only = True

# experiment 2
# folder = os.path.join('experiments', 'ipg20220104_500')
# #folder = "device_serve_results"
# id_to_scroll = 1  # which of the above to scroll through
# first_claim_only = False
# prefix_lst = ["my_gptj_6b_tpu_size_8", "pgj_d_4096", "pgj_d_2048",
#               "pgj_d_1024_layer_14", "pgj_d_1024_layer_7",
#               "pgj_d_1024_layer_2", "pgj_d_1024_layer_1"]
# #, "pgj_large", "pgj_medium", "pgj_small", ]
# # "pgj_d_1024_layer_14"

# experiment 1
folder = os.path.join('experiments', 'ipg22_500')  # (previous) folder = "eval_ipg22_500"
id_to_scroll = 1        # which of the above to scroll through
first_claim_only = True
ignore_outscope = True  # ignore pick > 10


# def show_diff(a, b):
#     #print('{} => {}'.format(a, b))
#     for i, s in enumerate(difflib.ndiff(a, b)):
#         if s[0] == ' ':
#             continue
#         elif s[0] == '-':
#             print(u'Delete "{}" from position {}'.format(s[-1], i))
#         elif s[0] == '+':
#             print(u'Add "{}" to position {}'.format(s[-1], i))


def handle_char_return(text):
    # Normalize the '(none)' sentinel (used for unicorn text) to an empty string.
    if text == '(none)':
        text = ''
    return text
    #return ch.replace('\n', '\\n')
    #if ch == '\n':
    #    ch = "'\\n'"
    #return ch


def get_remaining(lst, pos):
    # Collect subword tokens from `pos` until the next token that starts a new word.
    s = ''
    for i in range(pos, len(lst)):
        text = lst[i]['actual_next_token_text']
        if not text.startswith(' '):
            s += text
        else:
            break
    return s


def calc_details(base_fn):
    full_fn = os.path.join(folder, base_fn)
    #gs_fn = "gs://%s/%s/%s" % (bucket_name, folder, base_fn)
    #with smart_open.open(gs_fn) as f:
    if not os.path.exists(full_fn):
        return None, -1, -1, None, None, None, None, None, None, None, None

    with open(full_fn) as f:
        result = json.loads(f.read())
    print("Loaded: %s" % full_fn)

    lst = result['output']
    recv = result['recv']
    sum_pick = 0
    sum_prob = 0
    sum_outscope_count = 0
    sum_outscope_len = 0
    sum_hit_1 = 0
    sum_top_10_len = 0
    full_text = ''
    token_count = 0
    for i, tk in enumerate(lst[:-1]):
        token_text = handle_char_return(tk['actual_next_token_text'])
        # Due to tokenizer differences, the following needs more work in the future.
        # if base_fn.find('gptj') >= 0:
        #     # using the original gpt-j-6b model: need to skip special tokens
        #     if i <= 7:
        #         continue  # skip <|start_of_claim|>
        #     remaining_text = get_remaining(lst, i)
        #     if remaining_text.find('<|end_of_claim|>') >= 0:
        #         pos1 = remaining_text.find('<|end_of_claim|>')
        #         token_text = remaining_text[:pos1]
        #         found_end = True
        # The following was for GPT-J-6B; not needed for PatentGPT-J.
        # if token_text.find('<|end_of_claim|>') == 0:
        #     break
        next_top_seq = int(tk['actual_next_token_top_seq'])
        next_top_prob = float(tk['actual_next_token_top_prob'])
        full_text += token_text
        if next_top_seq == 0:
            sum_hit_1 += 1  # press "tab" for the top pick
        if ignore_outscope and next_top_seq >= 10:
            sum_outscope_count += 1
            sum_outscope_len += len(token_text)  # use length as keystrokes
        else:
            sum_pick += min(next_top_seq + 1, len(token_text))
            #sum_pick += (next_top_seq + 1)  # press "down" & "tab"
            sum_prob += next_top_prob
            sum_top_10_len += len(token_text)
        token_count += 1

    if token_count == 0:  # unlikely
        avg_pick = 0
        avg_prob = 0
    else:
        avg_pick = float(sum_pick) / token_count
        avg_prob = float(sum_prob) / token_count

    # if len(lst) < 2048:  # for debugging
    #     s = '<|start_of_claim|>' + full_text
    #     if len(s) != len(recv['context']):
    #         print('length mismatch --> full_text: %s, recv: %s' % (len(s), len(recv['context'])))
    #         show_diff(s, recv['context'])
    #         pdb.set_trace()

    return result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, \
        sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len, full_text
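
# Usage sketch for calc_details (illustrative only; the file name below is a
# hypothetical example of the '<prefix>_<patent>_<claim>_forward.json' naming
# convention used throughout this script):
#
#   (result, avg_pick, avg_prob, token_count, sum_pick, sum_prob,
#    sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len,
#    full_text) = calc_details('pgj_d_1024_v2_11212952_1_forward.json')
#   if result is not None:
#       print('avg pick: %.2f, avg prob: %.2f' % (avg_pick, avg_prob))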
def show_avg(base_fn, model_name, patent_claim_num, show_pick=False):
    (result, avg_pick, avg_prob, token_count, sum_pick, sum_prob,
     sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len,
     full_text) = calc_details(base_fn)
    if result is None:
        return None
    if token_count == 0:
        print('debug 2')
        pdb.set_trace()

    lst = result['output']
    result = ''

    # Color legend: background color, foreground color, tag.
    colors = [
        ['00ff00', '000000', '1'],
        ['008800', 'ffffff', '2-10'],
        ['ff0000', 'ffffff', 'out of top 10'],
    ]
    #colors = [
    #    ['00ff00', '000000', '1'],
    #    ['008800', 'ffffff', '2-10'],
    #    ['aa0000', 'ffffff', '11-100'],
    #    ['ff0000', 'ffffff', '101~'],
    #]
    sum_all = {}
    for j, item in enumerate(colors):
        sum_all[item[2]] = 0

    for i, tk in enumerate(lst):
        token_text = handle_char_return(tk['actual_next_token_text'])
        if token_text == '<|end_of_claim|>':
            break
        if token_text == '(none)':  # for unicorn text
            break
        # Skip GPT-J, due to different tokenization
        # if base_fn.find('gptj') >= 0:
        #     # using the original gpt-j-6b model: need to skip special tokens
        #     if i <= 7:
        #         continue  # skip <|start_of_claim|>
        #     if token_text == '.<':  # assuming .<|end_of_claim|>
        #         break
        pick = int(tk['actual_next_token_top_seq'])
        prob = float(tk['actual_next_token_top_prob'])

        # skip follow-up subword
        # if token_text.startswith(' ') == False:
        #     bg_color = ''
        #     fg_color = ''
        # else:
        if pick == 0:
            bg_color, fg_color, tag = colors[0]
        elif pick >= 1 and pick < 10:
            bg_color, fg_color, tag = colors[1]
        else:  # pick >= 10
        #elif pick >= 10 and pick < 100:
            bg_color, fg_color, tag = colors[2]
        sum_all[tag] += 1

        if show_pick:
            pick = '[%s]' % pick
        else:
            pick = ''
        # Colored token span, rendered below via st.markdown(..., unsafe_allow_html=True).
        result += "<span style='background-color:#%s;color:#%s'>%s%s</span> " % (
            bg_color, fg_color, token_text, pick)

    # Color legend spans for the header line.
    color_msg = ''
    for i, v in enumerate(colors):
        color_msg += "<span style='background-color:#%s;color:#%s'> %s </span>&nbsp; " % (
            v[0], v[1], v[2])

    # sum_pick as top 1~10
    keys_with_auto = sum_pick + sum_outscope_len
    keys_without_auto = len(full_text)
    saved_ratio = float(keys_without_auto - keys_with_auto) / keys_without_auto * 100
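    # The autocomplete-effectiveness metric above, as a minimal worked example
    # (the numbers are made up): typing a claim verbatim costs len(full_text)
    # keystrokes; with autocomplete it costs the in-scope picks (sum_pick) plus
    # the characters typed for out-of-scope tokens (sum_outscope_len).
    #
    #   keys_without_auto = 1000       # characters in the claim
    #   keys_with_auto = 350 + 50      # sum_pick + sum_outscope_len
    #   saved = float(keys_without_auto - keys_with_auto) / keys_without_auto * 100
    #   # saved == 60.0  -> "Autocomplete Effectiveness: 60.0%"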
    s = 'model: %s\n' \
        'Autocomplete Effectiveness: %.1f%% (keystrokes saved)\n' \
        'Total keystrokes: %s (with autocomplete), %s (without autocomplete)\n' \
        'Keystroke distribution: top 1~10: %s (top 1: %s), out of top 10: %s' % (
            model_name, saved_ratio, keys_with_auto, keys_without_auto,
            sum_pick, sum_hit_1, sum_outscope_len)
    st.text(s)
    # s = 'file: %s, sum_pick: %s, sum_hit_1: %s, token_count: %s, sum_outscope: %s, ' \
    #     'avg_pick: %.2f, avg_prob: %.2f, sum_prob: %.2f, hit_1 ratio: %.2f' % (
    #         base_fn, sum_pick, sum_hit_1, token_count, sum_outscope,
    #         avg_pick, avg_prob, sum_prob, float(sum_hit_1)/token_count)
    #s += color_msg
    s = color_msg
    st.markdown(s, unsafe_allow_html=True)
    #st.text('file: %s, avg_pick: %5.2f, avg_prob: %.2f, hit count: %s/%s' % (base_fn, avg_pick, avg_prob, hit_0_count, len(lst)))

    # show histogram
    st.markdown(result, unsafe_allow_html=True)
    #st.text_area('context with top seq & prob:', result, height=400)

    sum_lst = [sum_all['1'], sum_all['2-10'], sum_all['out of top 10']]
    #sum_lst = [['1', sum_all['1']], ['2-10', sum_all['2-10']]]
    #sum_lst = [sum_all['1'], sum_all['2-10'], sum_all['11-100'], sum_all['101~']]
    return sum_lst


def show_overall_summary(prefix_lst, select_lst):
    # accumulate all
    # debug
    # for i, num in enumerate(select_lst):
    #     pre_full_text = ''
    #     for prefix in prefix_lst:
    #         base_fn = '%s_%s_forward.json' % (prefix, num)
    #         result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope, \
    #             sum_hit_1, sum_top_10_len, full_text = calc_details(base_fn)
    #         if pre_full_text == '':
    #             pre_full_text = full_text
    #         elif pre_full_text != full_text:
    #             print('debug')
    #             pdb.set_trace()
    for prefix in prefix_lst:
        acc_token_count = 0
        acc_sum_pick = 0
        acc_sum_prob = 0
        acc_sum_outscope_count = 0
        acc_sum_outscope_len = 0
        acc_sum_hit_1 = 0
        acc_sum_top_10_len = 0
        acc_full_text_len = 0
        for i, num in enumerate(select_lst):
            base_fn = '%s_%s_forward.json' % (prefix, num)
            (result, avg_pick, avg_prob, token_count, sum_pick, sum_prob,
             sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len,
             full_text) = calc_details(base_fn)
            acc_token_count += token_count
            acc_sum_pick += sum_pick
            acc_sum_prob += sum_prob
            acc_sum_outscope_count += sum_outscope_count
            acc_sum_outscope_len += sum_outscope_len
            acc_sum_hit_1 += sum_hit_1
            acc_sum_top_10_len += sum_top_10_len
            acc_full_text_len += len(full_text)

        if acc_token_count > 0:
            # acc_sum_pick --> top 1~10
            keys_with_auto = acc_sum_pick + acc_sum_outscope_len
            keys_without_auto = acc_full_text_len
            saved_ratio = float(keys_without_auto - keys_with_auto) / keys_without_auto * 100
            st.text('[ %s ]\n'
                    'Autocomplete Effectiveness: %.1f%% (ratio of saved keystrokes)\n'
                    '(sum) keys_with_auto: %s, top_10_keys: %s, out_of_scope: %s, sum_hit_1: %s\n'
                    'keys_without_auto: %s, top_10_len: %s, prob: %.2f' % (
                        model_names[prefix], saved_ratio,
                        '{:,}'.format(keys_with_auto),
                        '{:,}'.format(acc_sum_pick),
                        '{:,}'.format(acc_sum_outscope_len),
                        '{:,}'.format(acc_sum_hit_1),
                        '{:,}'.format(keys_without_auto),
                        '{:,}'.format(acc_sum_top_10_len),
                        acc_sum_prob))
            # Ready-to-paste LaTeX table row for the paper.
            st.text('%s & %.1f\\%% & %s & %s & %s & %s & %s \\\\' % (
                model_names[prefix], saved_ratio,
                '{:,}'.format(keys_with_auto),
                '{:,}'.format(acc_sum_pick),
                '{:,}'.format(acc_sum_outscope_len),
                '{:,}'.format(acc_sum_hit_1),
                '{:,}'.format(keys_without_auto)))
            # st.text('* acc_token_count = %s --> (avg) hits: %.2f, keys: %.2f, prob: %.2f, outscope: %.2f' % (
            #     acc_token_count,
            #     float(acc_sum_hit_1)/acc_token_count,
            #     float(acc_sum_pick)/acc_token_count,
            #     float(acc_sum_prob)/acc_token_count,
            #     float(acc_sum_outscope_count)/acc_token_count))
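
# Usage sketch for show_overall_summary (the claim IDs are hypothetical
# '<patent>_<claim>' strings; real ones come from the result files on disk):
#
#   show_overall_summary(prefix_lst, ['11212952_1', '11212953_1'])
#
# This prints one per-model summary block plus one LaTeX table row per model.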
def calc_height(s):
    # Rough text_area height: ~3 pixels per 10 characters, plus padding.
    return int(len(s) / 10 * 3) + 30


def remove_end_of_claim_text(gen_text):
    # Truncate generated text right after the first end-of-claim/end-of-text tag.
    tag = '<|end_of_claim|>'
    pos = gen_text.find(tag)
    if pos > 0:
        gen_text = gen_text[:pos + len(tag)]
        return gen_text
    tag = '<|endoftext|>'
    pos = gen_text.find(tag)
    if pos > 0:
        gen_text = gen_text[:pos + len(tag)]
    return gen_text


def dump_pos_data(prefix_lst, select_lst):
    # Per-position counts across all selected claims:
    # statistics[j] = [top-1, top-2~10, out-of-scope] counts at token position j.
    #statistics = [[0]*3]*2048  # buggy: all rows would alias the same list
    statistics = []
    for i in range(2048):
        statistics.append([0, 0, 0])
    #results.append(['model', 'pos', 'key'])
    #results.append(['model', 'patent_claim', 'pos', 'top-1', 'top-2~10', 'out of top 10'])
    max_len = -1
    for prefix in prefix_lst:
        model_name = model_names[prefix].replace('PatentGPT-J-', '')
        if model_name != '456M':
            continue
        #total = {}
        for i, num in enumerate(select_lst):
            base_fn = '%s_%s_forward.json' % (prefix, num)
            full_fn = os.path.join(folder, base_fn)
            if not os.path.exists(full_fn):
                continue
            with open(full_fn) as f:
                result = json.loads(f.read())
            print("Loaded: %s" % full_fn)
            lst = result['output']
            for j, tk in enumerate(lst[:-1]):
                max_len = max(j, max_len)
                next_top_seq = int(tk['actual_next_token_top_seq'])
                #next_top_prob = float(tk['actual_next_token_top_prob'])
                if next_top_seq == 0:
                    statistics[j][0] += 1  # top-1
                elif next_top_seq > 0 and next_top_seq < 10:
                    statistics[j][1] += 1  # top-2~10
                else:
                    statistics[j][2] += 1  # out-of-scope
                #total[tag] = total.get(tag, 0) + 1
                #results.append([model_name, num, i+1, top_1, top_2_to_10, out_of_scope])

    dump_file = 'dump4.txt'
    with open(dump_file, 'w') as f:
        for i in range(max_len + 1):
            f.write('%s, top-1, %s\n' % (i + 1, statistics[i][0]))
            f.write('%s, top-2~10, %s\n' % (i + 1, statistics[i][1]))
            f.write('%s, out_of_scope, %s\n' % (i + 1, statistics[i][2]))
            # f.write('%s\n' % ', '.join([str(i+1)] + [str(v) for v in statistics[i]]))
    print('saved: %s' % dump_file)
    # dump_file = 'dump2.txt'
    # with open(dump_file, 'w') as f:
    #     for line in results:
    #         f.write('%s\n' % ', '.join(line))
    # print('saved: %s' % dump_file)


def calc_sentence_similarity(sent_model, sent1, sent2):
    embedding1 = sent_model.encode(sent1, convert_to_tensor=True)
    embedding2 = sent_model.encode(sent2, convert_to_tensor=True)
    similarity = util.cos_sim(embedding1, embedding2)[0][0]
    return similarity


sent_model = 'patent/st-aipd-nlp-g'
print('loading SentenceTransformer: %s' % sent_model)
sent_aipd = SentenceTransformer(sent_model)


def load_data(demo):
    fn = 'ppo_open_llama_3b_v2.run.12.delta.txt'
    #fn = 'ppo_output/ppo_open_llama_3b_v2.run.12.delta.txt'
    with open(fn, 'r') as f:
        rows = json.load(f)
    if demo == 'demo1':
        new_rows = [row for row in rows if row['instruction'].find('child') > 0]
    elif demo == 'demo2':
        new_rows = [row for row in rows if row['instruction'].find('parent') > 0]
    else:
        new_rows = []
    return new_rows


container_style = """ """
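
# Minimal usage sketch for calc_sentence_similarity (the two sentences are
# made-up examples; sent_aipd is the SentenceTransformer loaded above):
#
#   sim = calc_sentence_similarity(sent_aipd,
#                                  'A device comprising a processor.',
#                                  'An apparatus that includes a CPU.')
#   print(float(sim))  # cosine similarity, roughly in [-1, 1]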
def main():
    st.set_page_config(  # alternate names: setup_page, page, layout
        layout="wide",                 # "centered" or "wide"
        initial_sidebar_state="auto",  # "auto", "expanded", or "collapsed"
        page_title="Demo 1",           # string or None; strings get appended with "• Streamlit"
        page_icon=None,                # string, anything supported by st.image, or None
    )

    opt_1 = 'parent --> child'
    opt_2 = 'child --> parent'
    options = [opt_1, opt_2]

    rows = None
    pos = None
    patent_num = ''
    claim_num1 = ''
    claim_num2 = ''
    instruction = ''
    input_text = ''
    output_text = ''
    response = ''
    query = ''
    score_lst_1 = 0
    score_lst_2 = 0
    rewards = ''

    with st.container():
        col1, col2, col3 = st.columns([3, 5, 2])
        with col1:
            selected_option = st.selectbox('Select a demo:', options)
        if selected_option == opt_1:
            rows = load_data('demo1')
            msg = 'novelty = sim1-sim2'
            #msg = 'delta of similarities<br>(sim1-sim2)'
            c1_tag = 'pc'
            c2_tag = 'cc1'
            c3_tag = 'cc2'
        elif selected_option == opt_2:
            rows = load_data('demo2')
            msg = 'similarity of<br>(pc1) and (pc2)'
            c1_tag = 'cc'
            c2_tag = 'pc1'
            c3_tag = 'pc2'
        else:
            st.text('Unknown option')
            return
        #rows = rows[:5000]  # for debugging

        with col2:
            pos = st.slider("", 1, len(rows))
            #pos = st.slider("Degree of novelty (Generated v. Actual)", 1, len(rows))

        # The slider is 1-based; pick out the selected row.
        row = rows[pos - 1]
        patent_num = row['patent_num']
        claim_num1 = row['claim_num1']
        claim_num2 = row['claim_num2']
        instruction = row['instruction']
        input_text = row['input']
        output_text = row['output']
        response = row['response']
        query = row['query']
        score_lst_1 = row['score_lst_1']
        score_lst_2 = row['score_lst_2']
        delta = row['delta']
        rewards = row['rewards']

        with col3:
            #v = round(float(score_lst_1) - float(score_lst_2), 4)
            #v = delta  #round(delta, 10)
            st.markdown("<div style='text-align: center; color: black;'>%s<br>%s</div>"
                        % (msg, delta), unsafe_allow_html=True)

    # selectbox_placeholder = st.empty()
    # selected_option = selectbox_placeholder.selectbox('Select a demo:', options)
    # container1 = st.container()
    # with st.container():
    #     col1, col2 = st.columns(2)
    #     with col1:
    #         st.write('Caption for first chart')
    #     with col2:
    #         st.line_chart((0, 1), height=100)
    # with st.container():
    #     col1, col2 = st.columns(2)
    #     with col1:
    #         st.write('Caption for second chart')
    #     with col2:
    #         st.line_chart((1, 0), height=100)

    #st.write('patent_num:', patent_num)
    # st.write('claim_num1:', claim_num1)
    # st.write('claim_num2:', claim_num2)
    st.write('(instruction) ', instruction)
    with st.container():
        with st.container(border=True):
            st.write('(%s) [ %s ]\n%s' % (c1_tag, patent_num, input_text))
            #st.write('input:' % patent_num)
            #st.write('input:\n', input_text)
    #container1.markdown("...", unsafe_allow_html=True)
", unsafe_allow_html=True) col1, col2 = st.columns(2) with col1: with st.container(border=True): st.write('(%s) (actual)' % c2_tag) st.write(output_text) with col2: with st.container(border=True): st.write('(%s) (generated)' % c3_tag) st.write(response) col1, col2 = st.columns(2) with col1: with st.container(border=True): st.write('(sim1) similarity between %s and %s+%s: %s' % (c1_tag, c1_tag, c2_tag, str(score_lst_1))) with col2: with st.container(border=True): st.write('(sim2) similarity between %s and %s+%s: %s' % (c1_tag, c1_tag, c3_tag, str(score_lst_2))) #container1.markdown("
", unsafe_allow_html=True) # st.write("In Container 1") # table_name = st.radio("Please Select Table", list_of_tables) # st.write('output:') # st.write(output_text) # st.write('response:') # st.write(response) #st.write('query:', query) # st.write('score_lst_1:', score_lst_1) # st.write('score_lst_2:', score_lst_2) # st.write('rewards:', rewards) # st.text('hello') # dict_keys(['patent_num', 'claim_num1', 'claim_num2', 'instruction', 'input', 'output', 'query', 'response', 'score_lst_1', 'score_lst_2', 'rewards']) # st.subheader("Inspecting PatentGPT-J Model Evaluation") # num_set = set() # fn_lst = glob.glob(os.path.join(folder, '*')) # for i, fn in enumerate(fn_lst): # for prefix in prefix_lst: # v = re.search('(.*?)%s\_(\d+\_\d+)\_(.*?)' % prefix, fn) # if v is None: # v = re.search('(.*?)%s\_(\w+\_\d+)\_(.*?)' % prefix, fn) # #pdb.set_trace() # if v is None: # #pdb.set_trace() # continue # v = v.group(2) # if first_claim_only: # if v.endswith('_1'): # num_set.add(v) # else: # num_set.add(v) # num_lst = list(num_set) # num_lst.sort() # select_lst = [] # for i, num in enumerate(num_lst): # all_existed = True # for prefix in prefix_lst: # fn = os.path.join(folder, '%s_%s_forward.json' % (prefix, num)) # if os.path.exists(fn) == False: # all_existed = False # break # if all_existed: # select_lst.append(num) # select_lst.sort() # if len(select_lst) == 0: # st.text('select_lst is empty') # return # if dump_pos_data_for_reporting: # dump_pos_data(prefix_lst, select_lst) # st.text('Dump data: done') # return # # debug # #base_fn = 'my_gptj_6b_tpu_size_8_11212952_1_forward.json' # #base_fn = 'pgj_small_text-1_1_forward.json' # #_ = show_avg(base_fn) # if enable_summary_button: # if st.button('Show Summary'): # st.text('len(select_lst) = %s' % len(select_lst)) # show_overall_summary(prefix_lst, select_lst) # # if 'num' not in st.session_state: # # num = random.choice(select_lst) # # st.session_state['num'] = num # # set_state('num', num) # # def set_state(k, v): # # if k not in st.session_state: # # st.session_state[ k ] = v # show_patent_lst = [ s.replace('_', ' (claim ') + ')' for s in select_lst] # selected = st.selectbox("Choose a patent claim", show_patent_lst) # num = selected.replace(')', '').replace(' (claim ', '_') # if st.button('Random pick'): # num = random.choice(select_lst) # st.text('Selected: %s' % num) # st.session_state['num'] = num # avgs = [] # for prefix in prefix_lst: # base_fn = '%s_%s_forward.json' % (prefix, num) # one_avg = show_avg(base_fn, model_names[prefix], num) # if one_avg is not None: # avgs.append(one_avg) # # debug # #pdb.set_trace() # #return # # # data_lst = [] # for i in range(len(avgs[0])): # row = [] # for j, prefix in enumerate(prefix_lst): # row.append(avgs[j][i]) # data_lst.append(row) # df = pd.DataFrame(data_lst, index=['1','2-10','out of top 10']) # #df = pd.DataFrame(data_lst, index=['1','2-10','11-100','101~']) # # ], index=['(a) 1','(b) 2-10','(c) 11-100','(d) 101~']) # # [avgs[0][0], avgs[1][0], avgs[2][0]], # # [avgs[0][1], avgs[1][1], avgs[2][1]], # # [avgs[0][2], avgs[1][2], avgs[2][2]], # # [avgs[0][3], avgs[1][3], avgs[2][3]], # #df = pd.DataFrame([[1,2],[3,1]], columns=['a', 'b']) # #df = pd.DataFrame([ # # [sum1[0], sum1[1], sum1[2], sum1[3]], # # [sum2[0], sum2[1], sum2[2], sum2[3]], # # [sum3[0], sum3[1], sum3[2], sum3[3]], # # ]) #, index=['(a) 1','(b) 2-10','(c) 11-100','(d) 101~']) # #df = pd.DataFrame.from_dict(sum_all, orient='index') # #st.line_chart(df) # #data_canada = px.data.gapminder().query("country == 'Canada'") # #fig = 
px.bar(data_canada, x='year', y='pop') # if st.button('Show chart'): # fig = px.bar(df, barmode='group') # st.plotly_chart(fig, use_container_width=True) # #fig.show() # #st.area_chart(df) # #st.bar_chart(df) # # # base_fn = '%s_%s_forward.json' % (prefix_lst[ id_to_scroll ], st.session_state['num']) # result, avg_pick, avg_prob, _, _, _, _, _, _, _, _ = calc_details(base_fn) # recv = result['recv'] # lst = result['output'] # input_tokens = result['input'] # # (Pdb) print(token_pos_lst[0].keys()) # #dict_keys(['idx', 'gen_text', 'actual_next_token_text', 'actual_next_token_top_seq', 'actual_next_token_top_prob', 'top_n_lst']) # height = calc_height(recv['context']) # st.text_area('context:', recv['context'], height=height) # pos = st.slider("Token position", 0, len(lst)) # prompt = '' # for i in range(pos+1): # prompt += input_tokens[i]['text'] # height = calc_height(prompt) # st.text_area('prompt:', prompt, height=height) # ch = handle_char_return(lst[pos]['actual_next_token_text']) # st.text('actual_next_token_text: %s --> pick seq: %s (prob: %.2f)' % (ch, int(lst[pos]['actual_next_token_top_seq'])+1, # float(lst[pos]['actual_next_token_top_prob']))) # st.text('top 10 tokens:') # for i, v in enumerate(lst[pos]['top_n_lst']): # ch = handle_char_return(v['top_n_text']) # st.text('[ %s ][ %s ]( %.2f )' % (i+1, ch, float(v['top_n_prob']))) # gen_text = lst[pos]['gen_text'] # gen_text = remove_end_of_claim_text(gen_text) # st.text('gen_text: %s' % gen_text) # #st.text("done. ok.") # #st.text('result:\n%s' % result) if __name__ == "__main__": main() #def load_data_pre(demo): # fn = 'ppo_output/ppo_open_llama_3b_v2.run.12.keep.txt' # with open(fn, 'r') as f: # rows = json.load(f) # new_rows = [] # for i, row in enumerate(rows): # item1 = {} # item2 = {} # if demo == 'demo1': # item1[ 'delta' ] = abs(row['score_lst_1'][0] - row['score_lst_2'][0]) # item2[ 'delta' ] = abs(row['score_lst_1'][1] - row['score_lst_2'][1]) # elif demo == 'demo2': # #pdb.set_trace() # item1[ 'delta' ] = calc_sentence_similarity(sent_aipd, row['output'][0], row['response'][0]) # item2[ 'delta' ] = calc_sentence_similarity(sent_aipd, row['output'][1], row['response'][1]) # print('[ %s ] detla = %s' % (i, item1[ 'delta' ])) # for k in row.keys(): # item1[ k ] = row[ k ][0] # item2[ k ] = row[ k ][1] # if demo == 'demo1': # if item1['instruction'].find('child') > 0: # new_rows.append(item1) # if item2['instruction'].find('child') > 0: # new_rows.append(item2) # elif demo == 'demo2': # if item1['instruction'].find('parent') > 0: # new_rows.append(item1) # if item2['instruction'].find('parent') > 0: # new_rows.append(item2) # # Assuming new_rows is your list of dictionaries # sorted_rows = sorted(new_rows, key=lambda x: x['delta']) # # kv = {} # # for i, row in enumerate(new_rows): # # if diff > 0.0001: # # kv[i] = round(diff, 4) # # sorted_rows = [] # # sorted_kv = sorted(kv.items(), key=lambda x:x[1]) # # for k, v in sorted_kv: # # sorted_rows.append(new_rows[k]) # #pdb.set_trace() # return sorted_rows
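
# Expected row schema for the JSON file read by load_data(), inferred from the
# keys accessed in main(); the values below are illustrative placeholders only:
#
#   {
#     "patent_num": "11212952", "claim_num1": "1", "claim_num2": "2",
#     "instruction": "...", "input": "...", "output": "...",
#     "query": "...", "response": "...",
#     "score_lst_1": 0.91, "score_lst_2": 0.78, "delta": 0.13, "rewards": 0.13
#   }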