import streamlit as st
import pandas as pd
import numpy as np
import json

import time
import requests

import os
import glob
import re
#import smart_open
import plotly.express as px 
import random
#import difflib
import pdb

from sentence_transformers import SentenceTransformer, models, util

# Flags for the PatentGPT-J evaluation view; not read by active code in this file.
enable_summary_button = True
dump_pos_data_for_reporting = True

# GCS bucket name; only appears in commented-out smart_open code.
bucket_name = "paper_n1"

# Result-file prefix of every evaluated model, largest model first.
# A retired entry was "my_gptj_6b_tpu_size_8" (the original GPT-J-6B).
prefix_lst = [
  "pgj_d_4096",
  "pgj_d_2048",
  "pgj_d_1024_v2",
  "pgj_d_1024_layer_14",
  "pgj_d_1024_layer_7",
  "pgj_d_1024_layer_2",
  "pgj_d_1024_layer_1",
]

# Display name for each prefix, paired positionally with prefix_lst.
model_names = dict(zip(prefix_lst, [
  'PatentGPT-J-6B',     # pgj_d_4096
  'PatentGPT-J-1.6B',   # pgj_d_2048
  'PatentGPT-J-456M',   # pgj_d_1024_v2
  'PatentGPT-J-279M',   # pgj_d_1024_layer_14
  'PatentGPT-J-191M',   # pgj_d_1024_layer_7
  'PatentGPT-J-128M',   # pgj_d_1024_layer_2
  'PatentGPT-J-115M',   # pgj_d_1024_layer_1
]))

# Experiment selection. Only one set of folder / id_to_scroll / first_claim_only
# is active; alternatives used earlier:
#   experiment 3: folder = os.path.join('experiments', 'non_patent')
#                 id_to_scroll = 1, first_claim_only = True
#   experiment 2: folder = os.path.join('experiments', 'ipg20220104_500')
#                 id_to_scroll = 1, first_claim_only = False
#                 (with "my_gptj_6b_tpu_size_8" prepended to prefix_lst)
# Active: experiment 1 (its folder was previously "eval_ipg22_500").
folder = os.path.join('experiments', 'ipg22_500')
id_to_scroll = 1  # index into prefix_lst of the model to scroll through
first_claim_only = True

# When True, tokens ranked outside the top 10 (0-based seq >= 10) are counted
# as typed by hand instead of picked from the suggestion list.
ignore_outscope = True

# def show_diff(a, b):
#   #print('{} => {}'.format(a,b))  
#   for i, s in enumerate(difflib.ndiff(a, b)):
#       if s[0]==' ': continue
#       elif s[0]=='-':
#           print(u'Delete "{}" from position {}'.format(s[-1],i))
#       elif s[0]=='+':
#           print(u'Add "{}" to position {}'.format(s[-1],i))    

def handle_char_return(text):
  """Normalise a token's text for display and keystroke counting.

  The placeholder ``'(none)'`` ("unicorn text") is mapped to the empty
  string; any other text, including newlines, is returned unchanged.
  """
  if text == '(none)':  # unicorn text
    text = ''  # assignment: this placeholder contributes no characters
  return text

def get_remaining(lst, pos):
  """Return the subword continuation that starts at index ``pos``.

  Joins ``actual_next_token_text`` of lst[pos], lst[pos + 1], ... and stops
  at the first token that begins with a space; that token is not included.
  """
  pieces = []
  for idx in range(pos, len(lst)):
    piece = lst[idx]['actual_next_token_text']
    if piece.startswith(' '):
      break  # a leading space starts a new word
    pieces.append(piece)
  return ''.join(pieces)

def calc_details(base_fn):
  """Compute autocomplete keystroke statistics for one result file.

  Reads ``<folder>/<base_fn>``, a JSON object whose ``'output'`` list holds one
  dict per token. Every entry except the last (``lst[:-1]``) is scored:

    * rank ``actual_next_token_top_seq`` (0-based) == 0 counts as a top-1 hit;
    * if ``ignore_outscope`` is set and rank >= 10, the token is treated as
      typed by hand: its character length is added to ``sum_outscope_len``;
    * otherwise it costs ``min(rank + 1, len(token_text))`` keystrokes
      ("down" x rank, then "tab", but never more than typing the token).

  Args:
    base_fn: file name, resolved relative to the module-level ``folder``.

  Returns:
    An 11-tuple ``(result, avg_pick, avg_prob, token_count, sum_pick,
    sum_prob, sum_outscope_count, sum_outscope_len, sum_hit_1,
    sum_top_10_len, full_text)``. If the file does not exist, ``result`` is
    None, both averages are -1 and all remaining fields are None.
  """
  full_fn = os.path.join(folder, base_fn)
  if not os.path.exists(full_fn):
    # Same arity as the success path so callers can always unpack 11 values.
    return None, -1, -1, None, None, None, None, None, None, None, None

  with open(full_fn, encoding='utf-8') as f:
    result = json.load(f)
  print("Loaded: %s" % full_fn)

  sum_pick = 0
  sum_prob = 0
  sum_outscope_count = 0
  sum_outscope_len = 0
  sum_hit_1 = 0
  sum_top_10_len = 0
  token_count = 0
  text_parts = []

  # The last 'output' entry is not scored (presumably it has no successor
  # token to compare against -- TODO confirm).
  for tk in result['output'][:-1]:
    token_text = handle_char_return(tk['actual_next_token_text'])
    next_top_seq = int(tk['actual_next_token_top_seq'])
    next_top_prob = float(tk['actual_next_token_top_prob'])

    text_parts.append(token_text)
    if next_top_seq == 0:
      sum_hit_1 += 1  # press "tab" for the top pick

    if ignore_outscope and next_top_seq >= 10:
      sum_outscope_count += 1
      sum_outscope_len += len(token_text)  # typed by hand: one key per character
    else:
      sum_pick += min(next_top_seq + 1, len(token_text))
      sum_prob += next_top_prob
      sum_top_10_len += len(token_text)

    token_count += 1

  # Guard both modes against an empty/one-entry output list.
  if token_count:
    avg_pick = float(sum_pick) / token_count
    avg_prob = float(sum_prob) / token_count
  else:
    avg_pick = 0
    avg_prob = 0

  full_text = ''.join(text_parts)
  return (result, avg_pick, avg_prob, token_count, sum_pick, sum_prob,
          sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len,
          full_text)

def show_avg(base_fn, model_name, patent_claim_num, show_pick=False):
  """Render one model's token-by-token autocomplete ranks for a single claim.

  Writes three things to the Streamlit page:
    * a keystroke summary built from calc_details();
    * a colour legend;
    * the claim text, with each token highlighted by its rank bucket
      (top 1, 2-10, out of top 10).

  Rendering stops at '<|end_of_claim|>' or at the '(none)' placeholder.

  Args:
    base_fn: result file name, resolved by calc_details() under ``folder``.
    model_name: display name printed in the summary.
    patent_claim_num: unused; kept for call compatibility.
    show_pick: if True, append the 0-based rank in brackets to each token.

  Returns:
    ``[n_top_1, n_top_2_10, n_out_of_top_10]`` for the rendered tokens, or
    None if the result file does not exist.
  """
  import html  # used only to escape token text before st.markdown

  details = calc_details(base_fn)
  if details[0] is None:  # result file missing
    return None
  (result, _avg_pick, _avg_prob, _token_count, sum_pick, _sum_prob,
   _sum_outscope_count, sum_outscope_len, sum_hit_1, _sum_top_10_len,
   full_text) = details

  # [background colour, foreground colour, bucket label]
  colors = [
    ['00ff00', '000000', '1'],
    ['008800', 'ffffff', '2-10'],
    ['ff0000', 'ffffff', 'out of top 10'],
  ]
  # Per-bucket token counts, initialised once for the whole claim.
  sum_all = {label: 0 for _bg, _fg, label in colors}
  span_fmt = "<span style=background-color:#%s;color:#%s;border-radius:5px;>%s%s</span>  "

  spans = []
  for tk in result['output']:
    raw_text = tk['actual_next_token_text']
    token_text = handle_char_return(raw_text)
    # '(none)' (unicorn text) is tested on the raw text, before normalisation.
    if token_text == '<|end_of_claim|>' or raw_text == '(none)':
      break

    pick = int(tk['actual_next_token_top_seq'])
    if pick == 0:
      bucket = colors[0]
    elif 1 <= pick < 10:
      bucket = colors[1]
    else:  # pick >= 10
      bucket = colors[2]
    bg_color, fg_color, tag = bucket
    sum_all[tag] += 1

    pick_suffix = '[%s]' % pick if show_pick else ''
    # Escape so characters like '<' or '&' in claim text display literally.
    spans.append(span_fmt % (bg_color, fg_color, html.escape(token_text), pick_suffix))

  color_msg = ''.join(
    "<span style=background-color:#%s;color:#%s;border-radius:5px;>&nbsp;%s&nbsp;</span> " % (bg, fg, label)
    for bg, fg, label in colors)

  # sum_pick: keystrokes for tokens picked from the suggestion list;
  # sum_outscope_len: characters typed by hand (see calc_details).
  keys_with_auto = sum_pick + sum_outscope_len
  keys_without_auto = len(full_text)
  if keys_without_auto:
    saved_ratio = float(keys_without_auto - keys_with_auto) / keys_without_auto * 100
  else:
    saved_ratio = 0.0  # empty claim text: avoid ZeroDivisionError
  s = 'model: %s\n' \
    'Autocomplete Effectiveness: %.1f%% (keystrokes saved)\n' \
    'Total keystrokes: %s (with autocomplete), %s (without autocomplete)\n' \
    'Keystroke distribution: top 1~10: %s (top 1: %s), out of top 10: %s' % (model_name, saved_ratio, keys_with_auto, keys_without_auto,  sum_pick, sum_hit_1, sum_outscope_len)
  st.text(s)

  st.markdown(color_msg, unsafe_allow_html=True)
  st.markdown(''.join(spans), unsafe_allow_html=True)

  return [sum_all['1'], sum_all['2-10'], sum_all['out of top 10']]

def show_overall_summary(prefix_lst, select_lst):
  """Print per-model keystroke totals aggregated over several claims.

  For every model prefix, sums the counters returned by calc_details() over
  the files '<prefix>_<num>_forward.json' for each ``num`` in select_lst.
  Missing files are skipped. It then writes two st.text lines per model:
    * a readable summary;
    * the same figures as a LaTeX table row.

  Models with no scored tokens are not printed.
  """
  def fmt(n):
    return '{:,}'.format(n)  # thousands separators

  for prefix in prefix_lst:
    acc_token_count = 0
    acc_sum_pick = 0
    acc_sum_prob = 0
    acc_sum_outscope_len = 0
    acc_sum_hit_1 = 0
    acc_sum_top_10_len = 0
    acc_full_text_len = 0

    for num in select_lst:
      details = calc_details('%s_%s_forward.json' % (prefix, num))
      if details[0] is None:
        continue  # no result file for this model/claim
      (_result, _avg_pick, _avg_prob, token_count, sum_pick, sum_prob,
       _sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len,
       full_text) = details

      acc_token_count += token_count
      acc_sum_pick += sum_pick
      acc_sum_prob += sum_prob
      acc_sum_outscope_len += sum_outscope_len
      acc_sum_hit_1 += sum_hit_1
      acc_sum_top_10_len += sum_top_10_len
      acc_full_text_len += len(full_text)

    if acc_token_count == 0:
      continue

    # acc_sum_pick --> keystrokes for tokens picked from the top 10
    keys_with_auto = acc_sum_pick + acc_sum_outscope_len
    keys_without_auto = acc_full_text_len
    if keys_without_auto:
      saved_ratio = float(keys_without_auto - keys_with_auto) / keys_without_auto * 100
    else:
      saved_ratio = 0.0  # all token texts empty: avoid ZeroDivisionError
    name = model_names[prefix]

    st.text('[ %s ]\n' \
      'Autocomplete Effectiveness: %.1f%% (ratio of saving keystroke)\n' \
      '(sum) keys_with_auto: %s, top_10_keys: %s, out_of_scope: %s, sum_hit_1: %s\n' \
      'keys_without_auto: %s, top_10_len: %s, prob: %.2f' % (
      name, saved_ratio,
      fmt(keys_with_auto),
      fmt(acc_sum_pick),
      fmt(acc_sum_outscope_len),
      fmt(acc_sum_hit_1),
      fmt(keys_without_auto),
      fmt(acc_sum_top_10_len),
      acc_sum_prob,
      ))

    # LaTeX table row
    st.text('%s & %.1f\\%% & %s & %s & %s & %s & %s \\\\' % (
      name, saved_ratio, fmt(keys_with_auto), fmt(acc_sum_pick),
      fmt(acc_sum_outscope_len), fmt(acc_sum_hit_1), fmt(keys_without_auto)))

def calc_height(s):
  """Return a text-area height for ``s``: 30 plus 3 per 10 characters (floored)."""
  return 30 + (3 * len(s)) // 10

def remove_end_of_claim_text(gen_text):
  """Truncate generated text right after its first end marker.

  '<|end_of_claim|>' is searched first, then '<|endoftext|>'. The text is cut
  just after the first occurrence of the first marker found, and the marker
  is kept. A marker at position 0 also counts. Text containing neither
  marker is returned unchanged.
  """
  for tag in ('<|end_of_claim|>', '<|endoftext|>'):
    head, sep, _tail = gen_text.partition(tag)
    if sep:
      return head + sep
  return gen_text

def dump_pos_data(prefix_lst, select_lst, model_filter='456M', dump_file='dump4.txt'):
  """Write per-position rank-bucket counts for one model to a text file.

  For each prefix whose display name (minus 'PatentGPT-J-') equals
  ``model_filter``, every existing '<prefix>_<num>_forward.json' under
  ``folder`` is read. Each token position of ``output[:-1]`` is classified
  as top-1 (seq 0), top-2~10 (0 < seq < 10) or out of scope. The file gets
  three lines per 1-based position: '<pos>, <bucket>, <count>'.

  Args:
    prefix_lst: model prefixes to consider.
    select_lst: claim ids substituted into the file names.
    model_filter: model size label to keep (default '456M').
    dump_file: output path (default 'dump4.txt').
  """
  # statistics[pos] = [top-1, top-2~10, out-of-scope]; grows to the longest output.
  statistics = []
  for prefix in prefix_lst:
    if model_names[prefix].replace('PatentGPT-J-', '') != model_filter:
      continue

    for num in select_lst:
      full_fn = os.path.join(folder, '%s_%s_forward.json' % (prefix, num))
      if not os.path.exists(full_fn):
        continue

      with open(full_fn, encoding='utf-8') as f:
        result = json.load(f)
      print("Loaded: %s" % full_fn)

      for j, tk in enumerate(result['output'][:-1]):
        while len(statistics) <= j:
          statistics.append([0, 0, 0])
        next_top_seq = int(tk['actual_next_token_top_seq'])
        if next_top_seq == 0:
          bucket = 0
        elif 0 < next_top_seq < 10:
          bucket = 1
        else:
          bucket = 2
        statistics[j][bucket] += 1

  with open(dump_file, 'w') as f:
    for pos, (n_top_1, n_top_2_10, n_out) in enumerate(statistics, start=1):
      f.write('%s, top-1, %s\n' % (pos, n_top_1))
      f.write('%s, top-2~10, %s\n' % (pos, n_top_2_10))
      f.write('%s, out_of_scope, %s\n' % (pos, n_out))
  print('saved: %s' % dump_file)


def calc_sentence_similarity(sent_model, sent1, sent2):
  """Return the cosine similarity between the embeddings of two sentences.

  Both sentences are encoded with ``sent_model.encode(..., convert_to_tensor=True)``
  and compared with ``util.cos_sim``. Element ``[0][0]`` of the resulting
  similarity matrix is returned as-is; it is not converted to a Python float.
  """
  embedding1 = sent_model.encode(sent1, convert_to_tensor=True)
  embedding2 = sent_model.encode(sent2, convert_to_tensor=True)
  return util.cos_sim(embedding1, embedding2)[0][0]

# Sentence-embedding model for calc_sentence_similarity(). Streamlit re-runs
# this whole script on every widget interaction, so the model is loaded
# through st.cache_resource and reused across reruns instead of being
# reloaded each time.
sent_model = 'patent/st-aipd-nlp-g'

@st.cache_resource(show_spinner=False)  # no spinner: this runs before main() calls st.set_page_config()
def _load_sentence_model(model_name):
  """Load and return the SentenceTransformer ``model_name`` (cached across reruns)."""
  print('loading SentenceTransformer: %s' % model_name)
  return SentenceTransformer(model_name)

sent_aipd = _load_sentence_model(sent_model)

def load_data(demo, fn='ppo_open_llama_3b_v2.run.12.delta.txt'):
  """Load PPO result rows and keep those matching a demo direction.

  ``fn`` contains a JSON list of row dicts, each with an 'instruction' key.
  Filtering by ``demo``:
    * 'demo1' keeps rows whose instruction contains 'child';
    * 'demo2' keeps rows whose instruction contains 'parent';
    * any other value returns [].

  Args:
    demo: 'demo1' or 'demo2'.
    fn: path of the JSON file (defaults to the run-12 delta file).
  """
  with open(fn, 'r', encoding='utf-8') as f:
    rows = json.load(f)

  keyword = {'demo1': 'child', 'demo2': 'parent'}.get(demo)
  if keyword is None:
    return []
  # Substring test rather than str.find(...) > 0, which missed a match at index 0.
  return [row for row in rows if keyword in row['instruction']]

# CSS defining bordered ".container1" and ".container2" boxes, kept as a raw
# <style> block string. Not referenced by any active (uncommented) code in
# this file.
container_style = """
  <style>
      .container1 {
          border: 2px solid #3498db;
          border-radius: 8px;
          padding: 10px;
          margin-bottom: 20px;
      }
      .container2 {
          /* Add styles for Container 2 if needed */
      }
  </style>
"""

def main():
  """Streamlit entry point: browse PPO-generated claims one row at a time.

  A select box chooses the demo direction (parent --> child or
  child --> parent). The direction decides which rows load_data() returns and
  which tags appear in the labels. A 1-based slider then picks one row, and
  the page shows its instruction, input claim, actual and generated claims,
  and the two similarity scores.
  """
  st.set_page_config(  # Alternate names: setup_page, page, layout
    layout="wide",  # Can be "centered" or "wide". In the future also "dashboard", etc.
    initial_sidebar_state="auto",  # Can be "auto", "expanded", "collapsed"
    page_title="Demo 1",  # String or None. Strings get appended with "• Streamlit".
    page_icon=None,  # String, anything supported by st.image, or None.
  )

  opt_1 = 'parent --> child'
  opt_2 = 'child --> parent'
  options = [opt_1, opt_2]

  with st.container():
    col1, col2, col3 = st.columns([3, 5, 2])
    with col1:
      selected_option = st.selectbox('Select a demo:', options)

    if selected_option == opt_1:
      rows = load_data('demo1')
      msg = 'novelty = sim1-sim2'
      c1_tag, c2_tag, c3_tag = 'pc', 'cc1', 'cc2'
    elif selected_option == opt_2:
      rows = load_data('demo2')
      msg = 'similarity of<br>(pc1) and (pc2)'
      c1_tag, c2_tag, c3_tag = 'cc', 'pc1', 'pc2'
    else:
      st.text('Unknown option')
      return

    if not rows:
      # Nothing to browse (st.slider would also reject min 1 > max 0).
      st.text('No rows found for: %s' % selected_option)
      return

    with col2:
      if len(rows) > 1:
        # Non-empty label (Streamlit warns on empty labels), hidden from view.
        pos = st.slider("Row", 1, len(rows), label_visibility="collapsed")
      else:
        pos = 1  # a single row needs no slider; equal min/max may be rejected

    row = rows[pos - 1]  # slider is 1-based; only the selected row is read
    with col3:
      st.markdown("<center><h7>%s<br>%s</h7></center>" % (msg, row['delta']), unsafe_allow_html=True)

  st.write('(instruction) ', row['instruction'])

  with st.container():
    with st.container(border=True):
      st.write('(%s) [ %s ]\n%s' % (c1_tag, row['patent_num'], row['input']))

    col1, col2 = st.columns(2)
    with col1:
      with st.container(border=True):
        st.write('(%s) (actual)' % c2_tag)
        st.write(row['output'])
    with col2:
      with st.container(border=True):
        st.write('(%s) (generated)' % c3_tag)
        st.write(row['response'])

    col1, col2 = st.columns(2)
    with col1:
      with st.container(border=True):
        st.write('(sim1) similarity between %s and %s+%s: %s' % (c1_tag, c1_tag, c2_tag, str(row['score_lst_1'])))
    with col2:
      with st.container(border=True):
        st.write('(sim2) similarity between %s and %s+%s: %s' % (c1_tag, c1_tag, c3_tag, str(row['score_lst_2'])))

  # The earlier PatentGPT-J autocomplete-evaluation view (which drove
  # calc_details, show_avg, show_overall_summary and dump_pos_data) used to be
  # kept here as commented-out code; presumably recoverable from version
  # control history if needed again.

if __name__ == "__main__":
  main()

#def load_data_pre(demo):
#   fn = 'ppo_output/ppo_open_llama_3b_v2.run.12.keep.txt'
#   with open(fn, 'r') as f:
#     rows = json.load(f)

#   new_rows = []
#   for i, row in enumerate(rows):
#     item1 = {}
#     item2 = {}
#     if demo == 'demo1':
#       item1[ 'delta' ] = abs(row['score_lst_1'][0] - row['score_lst_2'][0])
#       item2[ 'delta' ] = abs(row['score_lst_1'][1] - row['score_lst_2'][1])
#     elif demo == 'demo2':
#       #pdb.set_trace()
#       item1[ 'delta' ] = calc_sentence_similarity(sent_aipd, row['output'][0], row['response'][0])
#       item2[ 'delta' ] = calc_sentence_similarity(sent_aipd, row['output'][1], row['response'][1])

#       print('[ %s ] detla = %s' % (i, item1[ 'delta' ]))

#     for k in row.keys():
#       item1[ k ] = row[ k ][0]
#       item2[ k ] = row[ k ][1]

#     if demo == 'demo1':
#       if item1['instruction'].find('child') > 0:
#         new_rows.append(item1)
#       if item2['instruction'].find('child') > 0:
#         new_rows.append(item2)
#     elif demo == 'demo2':
#       if item1['instruction'].find('parent') > 0:
#         new_rows.append(item1)
#       if item2['instruction'].find('parent') > 0:
#         new_rows.append(item2)

#   # Assuming new_rows is your list of dictionaries
#   sorted_rows = sorted(new_rows, key=lambda x: x['delta'])

#   # kv = {}
#   # for i, row in enumerate(new_rows):
#   #   if diff > 0.0001:
#   #     kv[i] = round(diff, 4)

#   # sorted_rows = []
#   # sorted_kv = sorted(kv.items(), key=lambda x:x[1])
#   # for k, v in sorted_kv:
#   #   sorted_rows.append(new_rows[k])

#   #pdb.set_trace() 

#   return sorted_rows