0.1.1 refactor and ui changes
Browse files- app.py +32 -130
- ethics.py +48 -0
- explainable.py +46 -0
- graph.html +2 -2
- kron/llm_predictor/KronHFHubLLM.py +85 -0
- measurable.py +26 -0
- requirements.txt +1 -0
app.py
CHANGED
@@ -33,6 +33,7 @@ index_model = "Writer/camel-5b-hf"
|
|
33 |
INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
|
34 |
persist_path = f"storage/{INDEX_NAME}"
|
35 |
MAX_LENGTH = 1024
|
|
|
36 |
|
37 |
import baseten
|
38 |
@st.cache_resource
|
@@ -68,10 +69,6 @@ f'''
|
|
68 |
)
|
69 |
st.caption('''###### corpus by [@[email protected]](https://sigmoid.social/@ArxivHealthcareNLP)''')
|
70 |
st.caption('''###### KG Questions by [arylwen](https://github.com/arylwen/mlk8s)''')
|
71 |
-
# st.write(
|
72 |
-
#f'''
|
73 |
-
##### How can <what most are doing> help with <what few are doing>?
|
74 |
-
#''')
|
75 |
|
76 |
from llama_index import StorageContext
|
77 |
from llama_index import ServiceContext
|
@@ -99,7 +96,6 @@ enc = tiktoken.get_encoding("gpt2")
|
|
99 |
tokenizer = lambda text: enc.encode(text, allowed_special={"<|endoftext|>"})
|
100 |
globals_helper._tokenizer = tokenizer
|
101 |
|
102 |
-
|
103 |
def set_openai_local():
|
104 |
openai.api_key = os.environ['LOCAL_OPENAI_API_KEY']
|
105 |
openai.api_base = os.environ['LOCAL_OPENAI_API_BASE']
|
@@ -111,12 +107,15 @@ def set_openai():
|
|
111 |
openai.api_base = os.environ['DAVINCI_OPENAI_API_BASE']
|
112 |
os.environ['OPENAI_API_KEY'] = os.environ['DAVINCI_OPENAI_API_KEY']
|
113 |
os.environ['OPENAI_API_BASE'] = os.environ['DAVINCI_OPENAI_API_BASE']
|
114 |
-
|
|
|
115 |
def get_hf_predictor(query_model):
|
116 |
# no embeddings for now
|
117 |
set_openai_local()
|
118 |
-
llm=HuggingFaceHub(repo_id=query_model, task="text-generation",
|
119 |
-
|
|
|
|
|
120 |
huggingfacehub_api_token=hf_api_key)
|
121 |
llm_predictor = LLMPredictor(llm)
|
122 |
return llm_predictor
|
@@ -264,7 +263,7 @@ else :
|
|
264 |
with query:
|
265 |
answer_model = st.radio(
|
266 |
"Choose the model used for inference:",
|
267 |
-
('
|
268 |
)
|
269 |
|
270 |
if answer_model == 'openai/text-davinci-003':
|
@@ -304,7 +303,7 @@ elif answer_model == 'baseten/Camel-5b':
|
|
304 |
most_connected = random.sample(graph_nodes[:100], 5)
|
305 |
low_connected = get_networkx_low_connected_components( "", persist_path)
|
306 |
least_connected = random.sample(low_connected, 5)
|
307 |
-
elif answer_model == '
|
308 |
query_model = 'Writer/camel-5b-hf'
|
309 |
print(answer_model)
|
310 |
clear_question(query_model)
|
@@ -314,8 +313,8 @@ elif answer_model == 'Local-Camel':
|
|
314 |
most_connected = random.sample(graph_nodes[:100], 5)
|
315 |
low_connected = get_networkx_low_connected_components( "", persist_path)
|
316 |
least_connected = random.sample(low_connected, 5)
|
317 |
-
elif answer_model == '
|
318 |
-
query_model = '
|
319 |
clear_question(query_model)
|
320 |
query_engine = build_hf_query_engine(query_model, persist_path)
|
321 |
graph_nodes = get_networkx_graph_nodes( "", persist_path)
|
@@ -325,21 +324,25 @@ elif answer_model == 'HF-TKI':
|
|
325 |
else:
|
326 |
print('This is a bug.')
|
327 |
|
328 |
-
# to clear input box
|
329 |
def submit():
|
330 |
st.session_state.question = st.session_state.question_input
|
331 |
st.session_state.question_input = ''
|
332 |
st.session_state.question_answered = False
|
333 |
|
334 |
with st.sidebar:
|
335 |
-
|
|
|
|
|
|
|
|
|
|
|
336 |
option_2 = st.selectbox("What few are studying:", least_connected, disabled=True)
|
337 |
|
338 |
with query:
|
339 |
-
st.caption(f'''######
|
340 |
-
#st.caption(f'''Model, question, answer and rating are logged to improve KG Questions.''')
|
341 |
question = st.text_input("Enter a question, e.g. What benchmarks can we use for QA?", key='question_input', on_change=submit )
|
342 |
-
|
343 |
if(st.session_state.question):
|
344 |
try :
|
345 |
with query:
|
@@ -363,121 +366,13 @@ if(st.session_state.question):
|
|
363 |
from streamlit_star_rating import st_star_rating
|
364 |
stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating")
|
365 |
st.write(answer_str)
|
366 |
-
|
367 |
with measurable:
|
368 |
-
from
|
369 |
-
|
370 |
-
from PIL import Image
|
371 |
-
wc_all, wc_question, wc_reference = st.columns([3, 3, 3])
|
372 |
-
wordcloud = WordCloud(max_font_size=50, max_words=1000, background_color="white")
|
373 |
-
with wc_all:
|
374 |
-
image = Image.open('docs/images/all_papers_wordcloud.png')
|
375 |
-
st.image(image)
|
376 |
-
st.caption('''###### Corpus term frequecy.''')
|
377 |
-
with wc_question:
|
378 |
-
wordcloud_q = wordcloud.generate(answer_str)
|
379 |
-
st.image(wordcloud_q.to_array())
|
380 |
-
st.caption('''###### Answer term frequecy.''')
|
381 |
-
with wc_reference:
|
382 |
-
all_reference_texts = ''
|
383 |
-
for nodewithscore in answer.source_nodes:
|
384 |
-
node = nodewithscore.node
|
385 |
-
from llama_index.schema import NodeRelationship
|
386 |
-
#if NodeRelationship.SOURCE in node.relationships:
|
387 |
-
all_reference_texts = all_reference_texts + '\n' + node.text
|
388 |
-
wordcloud_r = wordcloud.generate(all_reference_texts)
|
389 |
-
st.image(wordcloud_r.to_array())
|
390 |
-
st.caption('''###### Reference plus graph term frequecy.''')
|
391 |
-
|
392 |
with explainable:
|
393 |
-
|
394 |
-
|
395 |
-
graph = Network(height="450px", width="100%")
|
396 |
-
sources_table = []
|
397 |
-
#all_reference_texts = ''
|
398 |
-
for nodewithscore in answer.source_nodes:
|
399 |
-
node = nodewithscore.node
|
400 |
-
from llama_index.schema import NodeRelationship
|
401 |
-
if NodeRelationship.SOURCE in node.relationships:
|
402 |
-
node_id = node.relationships[NodeRelationship.SOURCE].node_id
|
403 |
-
node_id = node_id.split('/')[-1]
|
404 |
-
title = node_id.split('.')[2].replace('_', ' ')
|
405 |
-
link = '.'.join(node_id.split('.')[:2])[:10]
|
406 |
-
link = f'https://arxiv.org/abs/{link}'
|
407 |
-
href = f'<a target="_blank" href="{link}">{title}</a>'
|
408 |
-
sources_table.extend([[href, node.text]])
|
409 |
-
#all_reference_texts = all_reference_texts + '\n' + node.text
|
410 |
-
else:
|
411 |
-
#st.write(node.text) TODO second level relationships
|
412 |
-
rel_map = node.metadata['kg_rel_map']
|
413 |
-
for concept in rel_map.keys():
|
414 |
-
#st.write(concept)
|
415 |
-
graph.add_node(concept, concept, title=concept)
|
416 |
-
rels = rel_map[concept]
|
417 |
-
for rel in rels:
|
418 |
-
graph.add_node(rel[1], rel[1], title=rel[1])
|
419 |
-
graph.add_edge(concept, rel[1], title=rel[0])
|
420 |
-
# --- display the query terms graph
|
421 |
-
st.session_state.graph_name = 'graph.html'
|
422 |
-
graph.save_graph(st.session_state.graph_name)
|
423 |
-
import streamlit.components.v1 as components
|
424 |
-
graphHtml = open(st.session_state.graph_name, 'r', encoding='utf-8')
|
425 |
-
source_code = graphHtml.read()
|
426 |
-
components.html(source_code, height = 500)
|
427 |
-
# --- display the reference texts table
|
428 |
-
import pandas as pd
|
429 |
-
df = pd.DataFrame(sources_table)
|
430 |
-
df.columns = ['paper', 'relevant text']
|
431 |
-
st.markdown(""" <style>
|
432 |
-
table[class*="dataframe"] {
|
433 |
-
font-size: 10px;
|
434 |
-
}
|
435 |
-
</style> """, unsafe_allow_html=True)
|
436 |
-
st.write(df.to_html(escape=False), unsafe_allow_html=True)
|
437 |
-
# reference text wordcloud
|
438 |
-
#st.session_state.reference_wcloud = all_reference_texts
|
439 |
-
|
440 |
-
with ethical:
|
441 |
-
st.write('##### Bias, risks, limitations and terms of use for the models.')
|
442 |
-
ethics_statement = []
|
443 |
-
falcon = ['hf/tiiuae/falcon-7b-instruct', '<a target="_blank" href="https://huggingface.co/tiiuae/falcon-7b">Bias, Risks, and Limitations</a>']
|
444 |
-
cohere = ['cohere/command', '<a target="_blank" href="https://cohere.com/terms-of-use">Terms of use</a>']
|
445 |
-
camel = ['baseten/Camel-5b', '<a target="_blank" href="https://huggingface.co/Writer/camel-5b-hf">Bias, Risks, and Limitations</a>']
|
446 |
-
davinci = ['openai/text-davinci-003', '<a target="_blank" href="https://openai.com/policies/terms-of-use">Terms of Use</a>']
|
447 |
-
|
448 |
-
ethics_statement.extend([falcon, cohere, camel, davinci])
|
449 |
-
df = pd.DataFrame(ethics_statement)
|
450 |
-
df.columns = ['model', 'model link']
|
451 |
-
st.markdown(""" <style>
|
452 |
-
table[class*="dataframe"] {
|
453 |
-
font-size: 14px;
|
454 |
-
}
|
455 |
-
</style> """, unsafe_allow_html=True)
|
456 |
-
st.write(df.to_html(escape=False), unsafe_allow_html=True)
|
457 |
-
# license
|
458 |
-
st.write('')
|
459 |
-
st.write('##### How papers were included in the index based on license.')
|
460 |
-
st.caption(f'The paper id and title has been included in the index for a full attribution to the authors')
|
461 |
-
ccby = ['<a target="_blank" href="https://creativecommons.org/licenses/by/4.0/">CC BY</a>',
|
462 |
-
'<a target="_blank" href="https://github.com/arylwen/mlk8s/tree/main/apps/papers-kg">full content KG questions pipeline</a>']
|
463 |
-
ccbysa = ['<a target="_blank" href="https://creativecommons.org/licenses/by-sa/4.0/">CC BY-SA</a>',
|
464 |
-
'<a target="_blank" href="https://github.com/arylwen/mlk8s/tree/main/apps/papers-kg">full content KG questions pipeline</a>']
|
465 |
-
ccbyncsa = ['<a target="_blank" href="https://creativecommons.org/licenses/by-nc-sa/4.0/">CC NC-BY-NC-SA</a>',
|
466 |
-
'<a target="_blank" href="https://github.com/arylwen/mlk8s/tree/main/apps/papers-kg">full content KG questions pipeline</a>']
|
467 |
-
ccbyncnd = ['<a target="_blank" href="https://creativecommons.org/licenses/by-nc-nd/4.0/">CC NC-BY-NC-ND</a>',
|
468 |
-
'<a target="_blank" href="https://github.com/arylwen/mlk8s/tree/main/apps/papers-kg">arxiv metadata abstract KG questions pipeline</a>']
|
469 |
-
license_statement = [ccby, ccbysa, ccbyncsa, ccbyncnd]
|
470 |
-
df = pd.DataFrame(license_statement)
|
471 |
-
df.columns = ['license', 'how papers are used']
|
472 |
-
st.markdown(""" <style>
|
473 |
-
table[class*="dataframe"] {
|
474 |
-
font-size: 14px;
|
475 |
-
}
|
476 |
-
</style> """, unsafe_allow_html=True)
|
477 |
-
st.write(df.to_html(escape=False), unsafe_allow_html=True)
|
478 |
-
|
479 |
except Exception as e:
|
480 |
-
#print(f'{type(e)}, {e}')
|
481 |
answer_str = f'{type(e)}, {e}'
|
482 |
st.session_state.answer_rating = -1
|
483 |
st.write(f'An error occured, please try again. \n{answer_str}')
|
@@ -486,4 +381,11 @@ if(st.session_state.question):
|
|
486 |
req = st.session_state.question
|
487 |
if(__spaces__):
|
488 |
st.session_state.request_log.add_request_log_entry(query_model, req, answer_str, st.session_state.answer_rating)
|
489 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
|
34 |
persist_path = f"storage/{INDEX_NAME}"
|
35 |
MAX_LENGTH = 1024
|
36 |
+
MAX_NEW_TOKENS = 250
|
37 |
|
38 |
import baseten
|
39 |
@st.cache_resource
|
|
|
69 |
)
|
70 |
st.caption('''###### corpus by [@[email protected]](https://sigmoid.social/@ArxivHealthcareNLP)''')
|
71 |
st.caption('''###### KG Questions by [arylwen](https://github.com/arylwen/mlk8s)''')
|
|
|
|
|
|
|
|
|
72 |
|
73 |
from llama_index import StorageContext
|
74 |
from llama_index import ServiceContext
|
|
|
96 |
tokenizer = lambda text: enc.encode(text, allowed_special={"<|endoftext|>"})
|
97 |
globals_helper._tokenizer = tokenizer
|
98 |
|
|
|
99 |
def set_openai_local():
|
100 |
openai.api_key = os.environ['LOCAL_OPENAI_API_KEY']
|
101 |
openai.api_base = os.environ['LOCAL_OPENAI_API_BASE']
|
|
|
107 |
openai.api_base = os.environ['DAVINCI_OPENAI_API_BASE']
|
108 |
os.environ['OPENAI_API_KEY'] = os.environ['DAVINCI_OPENAI_API_KEY']
|
109 |
os.environ['OPENAI_API_BASE'] = os.environ['DAVINCI_OPENAI_API_BASE']
|
110 |
+
|
111 |
+
from kron.llm_predictor.KronHFHubLLM import KronHuggingFaceHub
|
112 |
def get_hf_predictor(query_model):
|
113 |
# no embeddings for now
|
114 |
set_openai_local()
|
115 |
+
#llm=HuggingFaceHub(repo_id=query_model, task="text-generation",
|
116 |
+
llm=KronHuggingFaceHub(repo_id=query_model, task="text-generation",
|
117 |
+
# model_kwargs={"temperature": 0.01, "max_new_tokens": MAX_NEW_TOKENS, 'frequency_penalty':1.17},
|
118 |
+
model_kwargs={"temperature": 0.01, "max_new_tokens": MAX_NEW_TOKENS },
|
119 |
huggingfacehub_api_token=hf_api_key)
|
120 |
llm_predictor = LLMPredictor(llm)
|
121 |
return llm_predictor
|
|
|
263 |
with query:
|
264 |
answer_model = st.radio(
|
265 |
"Choose the model used for inference:",
|
266 |
+
('Writer/camel-5b-hf', 'mosaicml/mpt-7b-instruct', 'hf/tiiuae/falcon-7b-instruct', 'cohere/command', 'baseten/Camel-5b', 'openai/text-davinci-003')
|
267 |
)
|
268 |
|
269 |
if answer_model == 'openai/text-davinci-003':
|
|
|
303 |
most_connected = random.sample(graph_nodes[:100], 5)
|
304 |
low_connected = get_networkx_low_connected_components( "", persist_path)
|
305 |
least_connected = random.sample(low_connected, 5)
|
306 |
+
elif answer_model == 'Writer/camel-5b-hf':
|
307 |
query_model = 'Writer/camel-5b-hf'
|
308 |
print(answer_model)
|
309 |
clear_question(query_model)
|
|
|
313 |
most_connected = random.sample(graph_nodes[:100], 5)
|
314 |
low_connected = get_networkx_low_connected_components( "", persist_path)
|
315 |
least_connected = random.sample(low_connected, 5)
|
316 |
+
elif answer_model == 'mosaicml/mpt-7b-instruct':
|
317 |
+
query_model = 'mosaicml/mpt-7b-instruct'
|
318 |
clear_question(query_model)
|
319 |
query_engine = build_hf_query_engine(query_model, persist_path)
|
320 |
graph_nodes = get_networkx_graph_nodes( "", persist_path)
|
|
|
324 |
else:
|
325 |
print('This is a bug.')
|
326 |
|
327 |
+
# to clear the input box
|
328 |
def submit():
|
329 |
st.session_state.question = st.session_state.question_input
|
330 |
st.session_state.question_input = ''
|
331 |
st.session_state.question_answered = False
|
332 |
|
333 |
with st.sidebar:
|
334 |
+
import gensim
|
335 |
+
m_connected = []
|
336 |
+
for item in most_connected:
|
337 |
+
if not item[0].lower() in gensim.parsing.preprocessing.STOPWORDS:
|
338 |
+
m_connected.extend([item[0].lower()])
|
339 |
+
option_1 = st.selectbox("What most are studying:", m_connected, disabled=True)
|
340 |
option_2 = st.selectbox("What few are studying:", least_connected, disabled=True)
|
341 |
|
342 |
with query:
|
343 |
+
st.caption(f'''###### Intended for educational and research purpose. Please do not enter any private or confidential information. Model, question, answer and rating are logged to improve KG Questions.''')
|
|
|
344 |
question = st.text_input("Enter a question, e.g. What benchmarks can we use for QA?", key='question_input', on_change=submit )
|
345 |
+
|
346 |
if(st.session_state.question):
|
347 |
try :
|
348 |
with query:
|
|
|
366 |
from streamlit_star_rating import st_star_rating
|
367 |
stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating")
|
368 |
st.write(answer_str)
|
|
|
369 |
with measurable:
|
370 |
+
from measurable import display_wordcloud
|
371 |
+
display_wordcloud(answer, answer_str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
372 |
with explainable:
|
373 |
+
from explainable import explain
|
374 |
+
explain(answer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
except Exception as e:
|
|
|
376 |
answer_str = f'{type(e)}, {e}'
|
377 |
st.session_state.answer_rating = -1
|
378 |
st.write(f'An error occured, please try again. \n{answer_str}')
|
|
|
381 |
req = st.session_state.question
|
382 |
if(__spaces__):
|
383 |
st.session_state.request_log.add_request_log_entry(query_model, req, answer_str, st.session_state.answer_rating)
|
384 |
+
else:
|
385 |
+
with measurable:
|
386 |
+
st.write(f'###### Ask a question to see a comparison between the corpus, answer and reference documents.')
|
387 |
+
with explainable:
|
388 |
+
st.write(f'###### Ask a question to see the knowledge graph and a list of reference documents.')
|
389 |
+
with ethical:
|
390 |
+
from ethics import display_ethics
|
391 |
+
display_ethics()
|
ethics.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
def display_ethics():
|
5 |
+
# ethics statement
|
6 |
+
display_ethics_statement()
|
7 |
+
# license
|
8 |
+
display_license_statement()
|
9 |
+
|
10 |
+
def display_license_statement():
|
11 |
+
st.write('')
|
12 |
+
st.write('##### How papers were included in the index based on license.')
|
13 |
+
st.caption(f'The paper id and title has been included in the index for a full attribution to the authors')
|
14 |
+
ccby = ['<a target="_blank" href="https://creativecommons.org/licenses/by/4.0/">CC BY</a>',
|
15 |
+
'<a target="_blank" href="https://github.com/arylwen/mlk8s/tree/main/apps/papers-kg">full content KG questions pipeline</a>']
|
16 |
+
ccbysa = ['<a target="_blank" href="https://creativecommons.org/licenses/by-sa/4.0/">CC BY-SA</a>',
|
17 |
+
'<a target="_blank" href="https://github.com/arylwen/mlk8s/tree/main/apps/papers-kg">full content KG questions pipeline</a>']
|
18 |
+
ccbyncsa = ['<a target="_blank" href="https://creativecommons.org/licenses/by-nc-sa/4.0/">CC NC-BY-NC-SA</a>',
|
19 |
+
'<a target="_blank" href="https://github.com/arylwen/mlk8s/tree/main/apps/papers-kg">full content KG questions pipeline</a>']
|
20 |
+
ccbyncnd = ['<a target="_blank" href="https://creativecommons.org/licenses/by-nc-nd/4.0/">CC NC-BY-NC-ND</a>',
|
21 |
+
'<a target="_blank" href="https://github.com/arylwen/mlk8s/tree/main/apps/papers-kg">arxiv metadata abstract KG questions pipeline</a>']
|
22 |
+
license_statement = [ccby, ccbysa, ccbyncsa, ccbyncnd]
|
23 |
+
df = pd.DataFrame(license_statement)
|
24 |
+
df.columns = ['license', 'how papers are used']
|
25 |
+
st.markdown(""" <style>
|
26 |
+
table[class*="dataframe"] {
|
27 |
+
font-size: 14px;
|
28 |
+
}
|
29 |
+
</style> """, unsafe_allow_html=True)
|
30 |
+
st.write(df.to_html(escape=False), unsafe_allow_html=True)
|
31 |
+
|
32 |
+
def display_ethics_statement():
|
33 |
+
st.write('##### Bias, risks, limitations and terms of use for the models.')
|
34 |
+
ethics_statement = []
|
35 |
+
falcon = ['hf/tiiuae/falcon-7b-instruct', '<a target="_blank" href="https://huggingface.co/tiiuae/falcon-7b">Bias, Risks, and Limitations</a>']
|
36 |
+
cohere = ['cohere/command', '<a target="_blank" href="https://cohere.com/terms-of-use">Terms of use</a>']
|
37 |
+
camel = ['baseten/Camel-5b', '<a target="_blank" href="https://huggingface.co/Writer/camel-5b-hf">Bias, Risks, and Limitations</a>']
|
38 |
+
davinci = ['openai/text-davinci-003', '<a target="_blank" href="https://openai.com/policies/terms-of-use">Terms of Use</a>']
|
39 |
+
|
40 |
+
ethics_statement.extend([falcon, cohere, camel, davinci])
|
41 |
+
df = pd.DataFrame(ethics_statement)
|
42 |
+
df.columns = ['model', 'model link']
|
43 |
+
st.markdown(""" <style>
|
44 |
+
table[class*="dataframe"] {
|
45 |
+
font-size: 14px;
|
46 |
+
}
|
47 |
+
</style> """, unsafe_allow_html=True)
|
48 |
+
st.write(df.to_html(escape=False), unsafe_allow_html=True)
|
explainable.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pyvis.network import Network
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
def explain(answer):
|
5 |
+
graph = Network(height="450px", width="100%")
|
6 |
+
sources_table = []
|
7 |
+
#all_reference_texts = ''
|
8 |
+
for nodewithscore in answer.source_nodes:
|
9 |
+
node = nodewithscore.node
|
10 |
+
from llama_index.schema import NodeRelationship
|
11 |
+
if NodeRelationship.SOURCE in node.relationships:
|
12 |
+
node_id = node.relationships[NodeRelationship.SOURCE].node_id
|
13 |
+
node_id = node_id.split('/')[-1]
|
14 |
+
title = node_id.split('.')[2].replace('_', ' ')
|
15 |
+
link = '.'.join(node_id.split('.')[:2])[:10]
|
16 |
+
link = f'https://arxiv.org/abs/{link}'
|
17 |
+
href = f'<a target="_blank" href="{link}">{title}</a>'
|
18 |
+
sources_table.extend([[href, node.text]])
|
19 |
+
#all_reference_texts = all_reference_texts + '\n' + node.text
|
20 |
+
else:
|
21 |
+
#st.write(node.text) TODO second level relationships
|
22 |
+
rel_map = node.metadata['kg_rel_map']
|
23 |
+
for concept in rel_map.keys():
|
24 |
+
#st.write(concept)
|
25 |
+
graph.add_node(concept, concept, title=concept)
|
26 |
+
rels = rel_map[concept]
|
27 |
+
for rel in rels:
|
28 |
+
graph.add_node(rel[1], rel[1], title=rel[1])
|
29 |
+
graph.add_edge(concept, rel[1], title=rel[0])
|
30 |
+
# --- display the query terms graph
|
31 |
+
st.session_state.graph_name = 'graph.html'
|
32 |
+
graph.save_graph(st.session_state.graph_name)
|
33 |
+
import streamlit.components.v1 as components
|
34 |
+
graphHtml = open(st.session_state.graph_name, 'r', encoding='utf-8')
|
35 |
+
source_code = graphHtml.read()
|
36 |
+
components.html(source_code, height = 500)
|
37 |
+
# --- display the reference texts table
|
38 |
+
import pandas as pd
|
39 |
+
df = pd.DataFrame(sources_table)
|
40 |
+
df.columns = ['paper', 'relevant text']
|
41 |
+
st.markdown(""" <style>
|
42 |
+
table[class*="dataframe"] {
|
43 |
+
font-size: 10px;
|
44 |
+
}
|
45 |
+
</style> """, unsafe_allow_html=True)
|
46 |
+
st.write(df.to_html(escape=False), unsafe_allow_html=True)
|
graph.html
CHANGED
@@ -88,8 +88,8 @@
|
|
88 |
|
89 |
|
90 |
// parsing and collecting nodes and edges from the python
|
91 |
-
nodes = new vis.DataSet([
|
92 |
-
edges = new vis.DataSet([
|
93 |
|
94 |
nodeColors = {};
|
95 |
allNodes = nodes.get({ returnType: "Object" });
|
|
|
88 |
|
89 |
|
90 |
// parsing and collecting nodes and edges from the python
|
91 |
+
nodes = new vis.DataSet([]);
|
92 |
+
edges = new vis.DataSet([]);
|
93 |
|
94 |
nodeColors = {};
|
95 |
allNodes = nodes.get({ returnType: "Object" });
|
kron/llm_predictor/KronHFHubLLM.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import time
|
4 |
+
|
5 |
+
from typing import Any, Callable, List, Optional
|
6 |
+
|
7 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
8 |
+
from langchain.llms.huggingface_hub import HuggingFaceHub
|
9 |
+
|
10 |
+
import logging
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
from tenacity import (
|
14 |
+
before_sleep_log,
|
15 |
+
retry,
|
16 |
+
retry_if_exception_type,
|
17 |
+
stop_after_attempt,
|
18 |
+
wait_exponential,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
def _create_retry_decorator(llm: KronHuggingFaceHub) -> Callable[[Any], Any]:
|
23 |
+
#import cohere
|
24 |
+
|
25 |
+
min_seconds = 4
|
26 |
+
max_seconds = 10
|
27 |
+
# Wait 2^x * 1 second between each retry starting with
|
28 |
+
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
29 |
+
return retry(
|
30 |
+
reraise=True,
|
31 |
+
stop=stop_after_attempt(llm.max_retries),
|
32 |
+
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
33 |
+
retry=(retry_if_exception_type(KronHFHubRateExceededException)),
|
34 |
+
before_sleep=before_sleep_log(logger, logging.WARNING),
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
def completion_with_retry(llm: KronHuggingFaceHub, **kwargs: Any) -> Any:
|
39 |
+
"""Use tenacity to retry the completion call."""
|
40 |
+
retry_decorator = _create_retry_decorator(llm)
|
41 |
+
|
42 |
+
@retry_decorator
|
43 |
+
def _completion_with_retry(**kwargs: Any) -> Any:
|
44 |
+
return llm.internal_call(**kwargs)
|
45 |
+
|
46 |
+
return _completion_with_retry(**kwargs)
|
47 |
+
|
48 |
+
class KronHFHubRateExceededException(Exception):
|
49 |
+
def __init__(self, message="HF Hub Service Unavailable: Rate exceeded."):
|
50 |
+
self.message = message
|
51 |
+
super().__init__(self.message)
|
52 |
+
|
53 |
+
|
54 |
+
class KronHuggingFaceHub(HuggingFaceHub):
|
55 |
+
|
56 |
+
max_retries: int = 10
|
57 |
+
"""Maximum number of retries to make when generating."""
|
58 |
+
|
59 |
+
def internal_call(
|
60 |
+
self,
|
61 |
+
prompt: str,
|
62 |
+
stop: Optional[List[str]] = None,
|
63 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
64 |
+
**kwargs: Any,
|
65 |
+
) -> str:
|
66 |
+
try:
|
67 |
+
print(f'**************************************\n{prompt}')
|
68 |
+
response = super()._call(prompt, stop, run_manager, **kwargs)
|
69 |
+
print(f'**************************************\n{response}')
|
70 |
+
return response
|
71 |
+
except ValueError as ve:
|
72 |
+
if "Service Unavailable" in str(ve):
|
73 |
+
raise KronHFHubRateExceededException()
|
74 |
+
else:
|
75 |
+
raise ve
|
76 |
+
|
77 |
+
def _call(
|
78 |
+
self,
|
79 |
+
prompt: str,
|
80 |
+
stop: Optional[List[str]] = None,
|
81 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
82 |
+
**kwargs: Any,
|
83 |
+
) -> str:
|
84 |
+
response = completion_with_retry(self, prompt=prompt, stop=stop, run_manager=run_manager, **kwargs)
|
85 |
+
return response
|
measurable.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from wordcloud import WordCloud, STOPWORDS, ImageColorGenerator
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
from PIL import Image
|
4 |
+
import streamlit as st
|
5 |
+
|
6 |
+
def display_wordcloud(answer, answer_str):
|
7 |
+
wc_all, wc_question, wc_reference = st.columns([3, 3, 3])
|
8 |
+
wordcloud = WordCloud(max_font_size=50, max_words=1000, background_color="white")
|
9 |
+
with wc_all:
|
10 |
+
image = Image.open('docs/images/all_papers_wordcloud.png')
|
11 |
+
st.image(image)
|
12 |
+
st.caption('''###### Corpus term frequecy.''')
|
13 |
+
with wc_question:
|
14 |
+
wordcloud_q = wordcloud.generate(answer_str)
|
15 |
+
st.image(wordcloud_q.to_array())
|
16 |
+
st.caption('''###### Answer term frequecy.''')
|
17 |
+
with wc_reference:
|
18 |
+
all_reference_texts = ''
|
19 |
+
for nodewithscore in answer.source_nodes:
|
20 |
+
node = nodewithscore.node
|
21 |
+
from llama_index.schema import NodeRelationship
|
22 |
+
#if NodeRelationship.SOURCE in node.relationships:
|
23 |
+
all_reference_texts = all_reference_texts + '\n' + node.text
|
24 |
+
wordcloud_r = wordcloud.generate(all_reference_texts)
|
25 |
+
st.image(wordcloud_r.to_array())
|
26 |
+
st.caption('''###### Reference plus graph term frequecy.''')
|
requirements.txt
CHANGED
@@ -16,6 +16,7 @@ cohere
|
|
16 |
baseten
|
17 |
st-star-rating
|
18 |
wordcloud
|
|
|
19 |
amazon-dax-client>=1.1.7
|
20 |
boto3>=1.26.79
|
21 |
pytest>=7.2.1
|
|
|
16 |
baseten
|
17 |
st-star-rating
|
18 |
wordcloud
|
19 |
+
gensim
|
20 |
amazon-dax-client>=1.1.7
|
21 |
boto3>=1.26.79
|
22 |
pytest>=7.2.1
|