"""Streamlit demo for browsing pre-computed PatentGPT-J outputs.

The app lists the patent claims for which every model in ``prefix_lst`` has a
``<prefix>_<claim key>_forward.json`` file inside ``folder``, loads the file
of one model for the chosen claim, and lets the user step through the token
positions to inspect the prompt, the top-ranked next tokens and the generated
text stored in that file.
"""
import difflib
import glob
import json
import os
import random
import re
import time
from random import randrange

import requests
import streamlit as st

# File-name prefixes of the result files, one per model size.
prefix_lst = [
    "pgj_d_4096",
    "pgj_d_2048",
    "pgj_d_1024_v2",
    "pgj_d_1024_layer_14",
    "pgj_d_1024_layer_7",
    "pgj_d_1024_layer_2",
    "pgj_d_1024_layer_1",
]

# Display name for each prefix above.
model_names = {
    prefix_lst[0]: 'PatentGPT-J-6B',
    prefix_lst[1]: 'PatentGPT-J-1.6B',
    prefix_lst[2]: 'PatentGPT-J-456M',
    prefix_lst[3]: 'PatentGPT-J-279M',
    prefix_lst[4]: 'PatentGPT-J-191M',
    prefix_lst[5]: 'PatentGPT-J-128M',
    prefix_lst[6]: 'PatentGPT-J-115M',
}

# experiment 3
# folder = os.path.join('experiments', 'non_patent')
# id_to_scroll = 1  # which of the above to scroll through
# first_claim_only = True

# experiment 2
folder = os.path.join('experiments', 'ipg20220104_500')
# folder = "device_serve_results"
id_to_scroll = 1  # which of the above to scroll through
first_claim_only = False

# prefix_lst = ["my_gptj_6b_tpu_size_8", "pgj_d_4096", "pgj_d_2048",
#               "pgj_d_1024_layer_14", "pgj_d_1024_layer_7",
#               "pgj_d_1024_layer_2", "pgj_d_1024_layer_1"]
# #, "pgj_large", "pgj_medium", "pgj_small", ]
# # "pgj_d_1024_layer_14"

# experiment 1
# folder = os.path.join('experiments', 'ipg22_500')
# # (previous) folder = "eval_ipg22_500"
# id_to_scroll = 1  # which of the above to scroll through
# first_claim_only = True

# Module-level fallback only; main() keeps the real list in st.session_state.
select_lst = []


def handle_char_return(text):
    """Return ``text`` for display, mapping the '(none)' placeholder to ''.

    NOTE(review): the original statement was the no-op comparison
    ``text == ''``; blanking the placeholder is the apparent intent.
    """
    if text == '(none)':
        return ''
    return text


def calc_height(s):
    """Return a text-area height derived from the length of ``s``."""
    return int(len(s) / 10 * 3) + 30


def remove_end_of_claim_text(gen_text):
    """Cut ``gen_text`` right after the first end marker it contains.

    '<|end_of_claim|>' takes precedence over '<|endoftext|>'. The marker
    itself is kept. Text without either marker is returned unchanged.
    """
    for tag in ('<|end_of_claim|>', '<|endoftext|>'):
        pos = gen_text.find(tag)
        # str.find returns -1 when absent; index 0 is a valid hit.
        if pos != -1:
            return gen_text[:pos + len(tag)]
    return gen_text


def update_content():
    """Slider ``on_change`` callback; intentionally does nothing."""
    # st.write("The value of the slider is:", st.session_state.myslider)
    pass


def prepare_select_lst():
    """Return the sorted claim keys that have a result file for every prefix.

    A claim key is the ``<id>_<number>`` part that follows ``<prefix>_`` in a
    file name found in ``folder``. When ``first_claim_only`` is set, only
    keys ending in ``_1`` are considered.
    """
    # Build the two patterns per prefix once instead of once per file.
    patterns = []
    for prefix in prefix_lst:
        escaped = re.escape(prefix)
        patterns.append((
            re.compile(r'(.*?)%s_(\d+_\d+)_(.*?)' % escaped),
            re.compile(r'(.*?)%s_(\w+_\d+)_(.*?)' % escaped),
        ))

    num_set = set()
    for fn in glob.glob(os.path.join(folder, '*')):
        for numeric_pat, word_pat in patterns:
            # Try the purely numeric id first, then fall back to any word id.
            match = numeric_pat.search(fn) or word_pat.search(fn)
            if match is None:
                continue
            key = match.group(2)
            if not first_claim_only or key.endswith('_1'):
                num_set.add(key)

    # Keep only the keys for which every model produced a file.
    return sorted(
        num for num in num_set
        if all(
            os.path.exists(
                os.path.join(folder, '%s_%s_forward.json' % (prefix, num)))
            for prefix in prefix_lst
        )
    )


def update_selected():
    """Selectbox ``on_change`` callback: load the newly chosen claim."""
    selected = st.session_state.myselectbox
    # main() stores the list in session state; the module-level list is
    # only a fallback and stays empty.
    if 'select_lst' in st.session_state:
        choices = st.session_state['select_lst']
    else:
        choices = select_lst
    pick_and_load(choices, selected)


def pick_and_load(select_lst, selected=None):
    """Load the result file of one claim and remember it in session state.

    Args:
        select_lst: Available claim keys (``<id>_<number>``).
        selected: A claim key, or its display form ``<id> (claim <number>)``.
            When ``None`` a random entry of ``select_lst`` is used.

    Returns:
        ``(pick, num, result)``: the index of the claim in ``select_lst``,
        its key, and the parsed JSON content of its result file.

    Side effects:
        Stores ``picked_flag``, ``num`` and ``result`` in
        ``st.session_state`` and prints the loaded path.
    """
    if selected is None:
        pick = random.randrange(len(select_lst))
        selected = select_lst[pick]
        num = selected.replace(')', '').replace(' (claim ', '_')
    else:
        # Convert the display form back to the key used in file names.
        num = selected.replace(')', '').replace(' (claim ', '_')
        # Fall back to the first entry if the key is not in the list, so
        # that a valid index is always returned.
        pick = select_lst.index(num) if num in select_lst else 0
    st.session_state['picked_flag'] = pick
    st.session_state['num'] = num

    prefix = "pgj_d_1024_v2"  # size: 456M
    base_fn = '%s_%s_forward.json' % (prefix, num)
    full_fn = os.path.join(folder, base_fn)
    with open(full_fn, encoding='utf-8') as f:
        result = json.load(f)
    print("Loaded: %s" % full_fn)
    st.session_state['result'] = result

    return pick, num, result


def main():
    """Render the demo page."""
    st.set_page_config(
        layout="wide",  # Can be "centered" or "wide".
        initial_sidebar_state="auto",  # "auto", "expanded" or "collapsed"
        page_title="Patent-GPT-J demo",  # String or None.
        page_icon=None,  # String, anything supported by st.image, or None.
    )
    st.subheader("PatentGPT-J Demo 2 (Autocomplete Effectiveness)")
    st.text("Data coverage: ipg220104 (in 2022-01-04)")

    # Scan the result folder once per session.
    if 'select_lst' not in st.session_state:
        select_lst = prepare_select_lst()
        st.session_state['select_lst'] = select_lst
    else:
        select_lst = st.session_state['select_lst']

    if not select_lst:
        st.text('select_lst is empty')
        return

    show_patent_lst = [s.replace('_', ' (claim ') + ')' for s in select_lst]

    if 'picked_flag' not in st.session_state:
        pick, num, result = pick_and_load(select_lst)
    else:
        pick = st.session_state['picked_flag']
        num = st.session_state['num']
        result = st.session_state['result']

    if st.button('Random pick'):
        pick, num, result = pick_and_load(select_lst)

    st.selectbox("Choose a patent claim", show_patent_lst, index=pick,
                 key='myselectbox', on_change=update_selected)

    recv = result['recv']
    lst = result['output']
    input_tokens = result['input']

    height = calc_height(recv['context'])
    st.text_area('context:', recv['context'], height=height)

    if not lst:
        st.text('output is empty')
        return

    # Valid positions are 0 .. len(lst) - 1; the former upper bound of
    # len(lst) made lst[pos] fail at the last slider position.
    last_pos = len(lst) - 1
    if last_pos > 0:
        pos = st.slider("Token position", 0, last_pos, key="myslider",
                        on_change=update_content)
    else:
        # A single position leaves nothing to slide over.
        pos = 0

    prompt = ''.join(token['text'] for token in input_tokens[:pos + 1])
    height = calc_height(prompt)
    st.text_area('prompt:', prompt, height=height)

    entry = lst[pos]
    ch = handle_char_return(entry['actual_next_token_text'])
    st.text(
        'actual_next_token_text: %s --> pick seq: %s (prob: %.2f) top 10 tokens:'
        % (ch, int(entry['actual_next_token_top_seq']) + 1,
           float(entry['actual_next_token_top_prob'])))

    # Show the first five candidates on one line and the rest on the next.
    msg = ''
    for i, v in enumerate(entry['top_n_lst']):
        ch = handle_char_return(v['top_n_text'])
        msg += '(%s)[%s](%.2f) ' % (i + 1, ch, float(v['top_n_prob']))
        if i == 4:
            st.text(msg)
            msg = ''
    st.text(msg)

    gen_text = remove_end_of_claim_text(entry['gen_text'])
    height = calc_height(gen_text)
    st.text_area('generated:', gen_text, height=height)


if __name__ == "__main__":
    main()