# mlk8s/app.py
import streamlit as st
import os
import re
import sys
import time
import base64
import logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
load_dotenv()
#os.environ['AWS_DEFAULT_REGION'] = 'us-west-2'
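# on each rerun, log the current session-state entries (useful when debugging stale state)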
for key in st.session_state.keys():
#del st.session_state[key]
print(f'session state entry: {key} {st.session_state[key]}')
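# __SPACES__ is presumably set only in the Hugging Face Spaces deployment; log requests to DynamoDB there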
__spaces__ = os.environ.get('__SPACES__')
if __spaces__:
from kron.persistence.dynamodb_request_log import get_request_log;
st.session_state.request_log = get_request_log()
#third party service access
#hf inference api
hf_api_key = os.environ['HF_TOKEN']
ch_api_key = os.environ['COHERE_TOKEN']
bs_api_key = os.environ['BASETEN_TOKEN']
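# the knowledge-graph index was built ahead of time with Writer/camel-5b-hf (no coreference resolution, per the index name) and is loaded from local storage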
index_model = "Writer/camel-5b-hf"
INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
persist_path = f"storage/{INDEX_NAME}"
MAX_LENGTH = 1024
import baseten
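# authenticate with Baseten once; st.cache_resource keeps the login from re-running on every Streamlit rerun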
@st.cache_resource
def set_baseten_key(bs_api_key):
baseten.login(bs_api_key)
set_baseten_key(bs_api_key)
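# render a small looping, autoplaying MP4 by inlining it as a base64 data URL in raw HTML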
def autoplay_video(video_path):
with open(video_path, "rb") as f:
video_content = f.read()
video_str = f"data:video/mp4;base64,{base64.b64encode(video_content).decode()}"
st.markdown(f"""
<video style="display: block; margin: auto; width: 140px;" controls loop autoplay width="140" height="180">
<source src="{video_str}" type="video/mp4">
</video>
""", unsafe_allow_html=True)
# sidebar
with st.sidebar:
st.header('KG Questions')
video, text = st.columns([2, 2])
with video:
autoplay_video('docs/images/kg_construction.mp4')
with text:
st.write(
f'''
###### The construction of a Knowledge Graph is mesmerizing.
###### Concepts in the middle are what most are doing. Are we considering anything different? Why? Why not?
###### Concepts on the edge are what few are doing. Are we considering that? Why? Why not?
'''
)
st.caption('''###### corpus by [@ArxivHealthcareNLP@sigmoid.social](https://sigmoid.social/@ArxivHealthcareNLP)''')
st.caption('''###### KG Questions by [arylwen](https://github.com/arylwen/mlk8s)''')
st.write(
f'''
#### How can <what most are doing> help with <what few are doing>?
''')
from llama_index import StorageContext
from llama_index import ServiceContext
from llama_index import load_index_from_storage
from llama_index.langchain_helpers.text_splitter import SentenceSplitter
from llama_index.node_parser import SimpleNodeParser
from llama_index import LLMPredictor
from langchain import HuggingFaceHub
from langchain.llms.cohere import Cohere
from langchain.llms import Baseten
import tiktoken
import openai
#extensions to llama_index to support openai compatible endpoints, e.g. llama-api
from kron.llm_predictor.KronOpenAILLM import KronOpenAI
#baseten deployment expects a specific request format
from kron.llm_predictor.KronBasetenCamelLLM import KronBasetenCamelLLM
from kron.llm_predictor.KronLLMPredictor import KronLLMPredictor
#writer/camel emits <|endoftext|>; use a gpt2 tokenizer that allows the special token so llama_index token counting does not fail
from llama_index.utils import globals_helper
enc = tiktoken.get_encoding("gpt2")
tokenizer = lambda text: enc.encode(text, allowed_special={"<|endoftext|>"})
globals_helper._tokenizer = tokenizer
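# the two helpers below point the openai client (and llama_index, via env vars) at either a local OpenAI-compatible endpoint or the hosted davinci endpoint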
def set_openai_local():
openai.api_key = os.environ['LOCAL_OPENAI_API_KEY']
openai.api_base = os.environ['LOCAL_OPENAI_API_BASE']
os.environ['OPENAI_API_KEY'] = os.environ['LOCAL_OPENAI_API_KEY']
os.environ['OPENAI_API_BASE'] = os.environ['LOCAL_OPENAI_API_BASE']
def set_openai():
openai.api_key = os.environ['DAVINCI_OPENAI_API_KEY']
openai.api_base = os.environ['DAVINCI_OPENAI_API_BASE']
os.environ['OPENAI_API_KEY'] = os.environ['DAVINCI_OPENAI_API_KEY']
os.environ['OPENAI_API_BASE'] = os.environ['DAVINCI_OPENAI_API_BASE']
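# one predictor factory per backend; each wraps the remote LLM in an LLMPredictor that llama_index can use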
def get_hf_predictor(query_model):
# no embeddings for now
set_openai_local()
llm=HuggingFaceHub(repo_id=query_model, task="text-generation",
model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH},
huggingfacehub_api_token=hf_api_key)
llm_predictor = LLMPredictor(llm)
return llm_predictor
def get_cohere_predictor(query_model):
# no embeddings for now
set_openai_local()
llm=Cohere(model='command', temperature = 0.01,
# model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH},
cohere_api_key=ch_api_key)
llm_predictor = LLMPredictor(llm)
return llm_predictor
def get_baseten_predictor(query_model):
# no embeddings for now
set_openai_local()
# '3yd1ke3' is presumably the Baseten deployment id of the hosted Camel-5b model
llm=KronBasetenCamelLLM(model='3yd1ke3', temperature = 0.01,
# model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'repetition_penalty':1.07},
model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'frequency_penalty':1})
llm_predictor = LLMPredictor(llm)
return llm_predictor
def get_kron_openai_predictor(query_model):
# define LLM
llm=KronOpenAI(temperature=0.01, model=query_model)
llm.max_tokens = MAX_LENGTH
llm_predictor = KronLLMPredictor(llm)
return llm_predictor
def get_service_context(llm_predictor):
# define TextSplitter
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=48, paragraph_separator='\n')
#define NodeParser
node_parser = SimpleNodeParser(text_splitter=text_splitter)
#define ServiceContext
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, node_parser=node_parser)
return service_context
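# reload the persisted knowledge-graph index, attaching the model-specific service context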
def get_index(service_context, persist_path):
print(f'Loading index from {persist_path}')
# rebuild storage context
storage_context = StorageContext.from_defaults(persist_dir=persist_path)
# load index
index = load_index_from_storage(storage_context=storage_context,
service_context=service_context,
max_triplets_per_chunk=2,
show_progress = False)
return index
def get_query_engine(index):
#writer/camel does not understand the refine prompt
RESPONSE_MODE = 'accumulate'
query_engine = index.as_query_engine(response_mode = RESPONSE_MODE)
return query_engine
def load_query_engine(llm_predictor, persist_path):
service_context = get_service_context(llm_predictor)
index = get_index(service_context, persist_path)
print(f'No query engine for {persist_path}; creating')
query_engine = get_query_engine(index)
return query_engine
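# query engines are cached by st.cache_resource on (query_model, persist_path), so each model loads the index only once per process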
@st.cache_resource
def build_kron_query_engine(query_model, persist_path):
llm_predictor = get_kron_openai_predictor(query_model)
query_engine = load_query_engine(llm_predictor, persist_path)
return query_engine
@st.cache_resource
def build_hf_query_engine(query_model, persist_path):
llm_predictor = get_hf_predictor(query_model)
query_engine = load_query_engine(llm_predictor, persist_path)
return query_engine
@st.cache_resource
def build_cohere_query_engine(query_model, persist_path):
llm_predictor = get_cohere_predictor(query_model)
query_engine = load_query_engine(llm_predictor, persist_path)
return query_engine
@st.cache_resource
def build_baseten_query_engine(query_model, persist_path):
llm_predictor = get_baseten_predictor(query_model)
query_engine = load_query_engine(llm_predictor, persist_path)
return query_engine
def format_response(answer):
# strip any runs of dashes (--) the model occasionally emits
dashes = r'(\-{2,50})'
answer.response = re.sub(dashes, '', answer.response)
return answer.response or "None"
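# reset the question/answer state whenever the selected model changes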
def clear_question(query_model):
if ('prev_model' not in st.session_state) or (st.session_state.prev_model != query_model):
if 'prev_model' in st.session_state:
print(f'clearing question {st.session_state.prev_model} {query_model}')
else:
print(f'clearing question None {query_model}')
if('question_input' in st.session_state):
st.session_state.question = st.session_state.question_input
st.session_state.question_input = ''
st.session_state.question_answered = False
st.session_state.answer = ''
st.session_state.answer_rating = 3
st.session_state.elapsed = 0
st.session_state.prev_model = query_model
query, measurable, explainable, ethical = st.tabs(["Query", "Measurable", "Explainable", "Ethical"])
initial_query = ''
if 'question' not in st.session_state:
st.session_state.question = ''
if __spaces__ :
with query:
answer_model = st.radio(
"Choose the model used for inference:",
('baseten/Camel-5b', 'cohere/command','hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003') #TODO start hf inference container on demand
)
else :
with query:
answer_model = st.radio(
"Choose the model used for inference:",
('Local-Camel', 'HF-TKI', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
)
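# map the selected radio option to a concrete model id and its (cached) query engine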
if answer_model == 'openai/text-davinci-003':
print(answer_model)
query_model = 'text-davinci-003'
clear_question(query_model)
set_openai()
query_engine = build_kron_query_engine(query_model, persist_path)
elif answer_model == 'hf/tiiuae/falcon-7b-instruct':
print(answer_model)
query_model = 'tiiuae/falcon-7b-instruct'
clear_question(query_model)
query_engine = build_hf_query_engine(query_model, persist_path)
elif answer_model == 'cohere/command':
print(answer_model)
query_model = 'cohere/command'
clear_question(query_model)
query_engine = build_cohere_query_engine(query_model, persist_path)
elif answer_model == 'baseten/Camel-5b':
print(answer_model)
query_model = 'baseten/Camel-5b'
clear_question(query_model)
query_engine = build_baseten_query_engine(query_model, persist_path)
elif answer_model == 'Local-Camel':
query_model = 'Writer/camel-5b-hf'
print(answer_model)
clear_question(query_model)
set_openai_local()
query_engine = build_kron_query_engine(query_model, persist_path)
elif answer_model == 'HF-TKI':
query_model = 'allenai/tk-instruct-3b-def-pos-neg-expl'
clear_question(query_model)
query_engine = build_hf_query_engine(query_model, persist_path)
else:
print('This is a bug.')
# to clear input box
def submit():
st.session_state.question = st.session_state.question_input
st.session_state.question_input = ''
st.session_state.question_answered = False
with query:
st.caption(f'''###### Model, question, answer and rating are logged to improve KG Questions.''')
question = st.text_input("Enter a question, e.g. What benchmarks can we use for QA?", key='question_input', on_change=submit )
if(st.session_state.question):
#with col1:
# st.write(f'Answering: {st.session_state.question} with {query_model}.')
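# answer only when a new question arrives; the cached answer is reused on reruns (e.g. those triggered by the rating widget)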
try :
with query:
col1, col2 = st.columns([2, 2])
if not st.session_state.question_answered:
with st.spinner(f'Answering: {st.session_state.question} with {query_model}.'):
start = time.time()
answer = query_engine.query(st.session_state.question)
st.session_state.answer = answer
st.session_state.question_answered = True
end = time.time()
st.session_state.elapsed = (end-start)
else:
answer = st.session_state.answer
answer_str = format_response(answer)
with col1:
if answer_str:
elapsed = '{:.2f}'.format(st.session_state.elapsed)
st.write(f'Answered: {st.session_state.question} with {query_model} in {elapsed}s. Please rate this answer.')
with col2:
from streamlit_star_rating import st_star_rating
stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating")
st.write(answer_str)
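# "Measurable" tab: word cloud of the answer shown next to the precomputed corpus word cloud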
with measurable:
from wordcloud import WordCloud, STOPWORDS, ImageColorGenerator
import matplotlib.pyplot as plt
from PIL import Image
wc_all, wq_question = st.columns([2, 2])
wordcloud = WordCloud(max_font_size=50, max_words=1000, background_color="white").generate(answer_str)
with wc_all:
#st.write('''### Corpus''')
image = Image.open('docs/images/all_papers_wordcloud.png')
st.image(image)
st.caption('''###### Corpus word frequency.''')
with wq_question:
#st.write('''### Question''')
st.image(wordcloud.to_array())
st.caption('''###### Query word frequency.''')
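# "Explainable" tab: interactive pyvis graph of the KG relations behind the answer, plus the source passages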
with explainable:
#st.write(answer.source_nodes)
from pyvis.network import Network
graph = Network(height="450px", width="100%")
sources_table = []
for nodewithscore in answer.source_nodes:
node = nodewithscore.node
from llama_index.schema import NodeRelationship
if NodeRelationship.SOURCE in node.relationships:
#st.write(node.relationships[NodeRelationship.SOURCE].node_id)
#st.write(node.text)
node_id = node.relationships[NodeRelationship.SOURCE].node_id
node_id = node_id.split('/')[-1]
title = node_id.split('.')[2].replace('_', ' ')
sources_table.extend([[title, node.text]])
else:
#st.write(node.metadata['kg_rel_map'])
#st.write(node.text) TODO
rel_map = node.metadata['kg_rel_map']
for concept in rel_map.keys():
#st.write(concept)
graph.add_node(concept, concept, title=concept)
rels = rel_map[concept]
for rel in rels:
graph.add_node(rel[1], rel[1], title=rel[1])
graph.add_edge(concept, rel[1], title=rel[0])
st.session_state.graph_name = 'graph.html'
graph.save_graph(st.session_state.graph_name)
import streamlit.components.v1 as components
with open(st.session_state.graph_name, 'r', encoding='utf-8') as graph_html:
    source_code = graph_html.read()
#print(source_code)
components.html(source_code, height = 500)
import pandas as pd
df = pd.DataFrame(sources_table)
df.columns = ['paper', 'relevant text']
st.markdown(""" <style> .font {
font-size:8px;}
</style> """, unsafe_allow_html=True)
st.table(df)
#st.write(answer.source_nodes[0].node)
except Exception as e:
#print(f'{type(e)}, {e}')
answer_str = f'{type(e)}, {e}'
st.session_state.answer_rating = -1
st.write(f'An error occurred, please try again. \n{answer_str}')
finally:
if 'question' in st.session_state:
req = st.session_state.question
if(__spaces__):
st.session_state.request_log.add_request_log_entry(query_model, req, answer_str, st.session_state.answer_rating)