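"""
KG Questions: a Streamlit app that answers questions over a persisted llama_index
knowledge-graph index, using a selectable LLM backend (OpenAI-compatible endpoints,
HuggingFace Hub, or Cohere; Baseten support is currently commented out). When running
on HuggingFace Spaces (__SPACES__ set), model, question, answer and rating are logged
via kron.persistence.dynamodb_request_log.
"""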
import streamlit as st
import os
import re
import sys
import time
import base64
import random
import logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
load_dotenv()
for key in st.session_state.keys():
    #del st.session_state[key]
    print(f'session state entry: {key} {st.session_state[key]}')
__spaces__ = os.environ.get('__SPACES__')
if __spaces__:
    from kron.persistence.dynamodb_request_log import get_request_log
    st.session_state.request_log = get_request_log()
#third party service access
#hf inference api
hf_api_key = os.environ['HF_TOKEN']
ch_api_key = os.environ['COHERE_TOKEN']
bs_api_key = os.environ['BASETEN_TOKEN']
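# Knowledge-graph index configuration: INDEX_NAME and persist_path are derived
# from the model the index was built with.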
#index_model = "Writer/camel-5b-hf"
index_model = "Arylwen/instruct-palmyra-20b-gptq-8"
INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
persist_path = f"storage/{INDEX_NAME}"
MAX_LENGTH = 1024
MAX_NEW_TOKENS = 250
#import baseten
#@st.cache_resource
#def set_baseten_key(bs_api_key):
# baseten.login(bs_api_key)
#set_baseten_key(bs_api_key)
def autoplay_video(video_path):
    with open(video_path, "rb") as f:
        video_content = f.read()
    video_str = f"data:video/mp4;base64,{base64.b64encode(video_content).decode()}"
    st.markdown(f"""
        <video style="display: block; margin: auto; width: 140px;" controls loop autoplay width="140" height="180">
            <source src="{video_str}" type="video/mp4">
        </video>
        """, unsafe_allow_html=True)
# sidebar
with st.sidebar:
    st.header('KG Questions')
    video, text = st.columns([2, 2])
    with video:
        autoplay_video('docs/images/kg_construction.mp4')
    with text:
        st.write(
            f'''
            ###### The construction of a Knowledge Graph is mesmerizing.
            ###### Concepts in the middle are what most are doing. Are we considering anything different? Why? Why not?
            ###### Concepts on the edge are what few are doing. Are we considering that? Why? Why not?
            '''
        )
    st.caption('''###### corpus by [@[email protected]](https://sigmoid.social/@ArxivHealthcareNLP)''')
    st.caption('''###### KG Questions by [arylwen](https://github.com/arylwen/mlk8s)''')
from llama_index.core import StorageContext, ServiceContext, load_index_from_storage
#from llama_index import ServiceContext
# from llama_index import load_index_from_storage
from llama_index.core.node_parser import SentenceSplitter
#from llama_index.node_parser import SimpleNodeParser
from llama_index.core.service_context_elements.llm_predictor import LLMPredictor
from langchain import HuggingFaceHub
from langchain.llms.cohere import Cohere
#from langchain.llms import Baseten
import tiktoken
import openai
#extensions to llama_index to support openai compatible endpoints, e.g. llama-api
from kron.llm_predictor.KronOpenAILLM import KronOpenAI
#baseten deployment expects a specific request format
#from kron.llm_predictor.KronBasetenCamelLLM import KronBasetenCamelLLM
from kron.llm_predictor.KronLLMPredictor import KronLLMPredictor
#writer/camel uses endoftext
from llama_index.core.utils import globals_helper
enc = tiktoken.get_encoding("gpt2")
tokenizer = lambda text: enc.encode(text, allowed_special={"<|endoftext|>"})
globals_helper._tokenizer = tokenizer
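# note: overriding the global tokenizer means llama_index token counting uses the
# GPT-2 encoding above, which allows <|endoftext|> instead of failing on it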
def set_openai_local():
    openai.api_key = os.environ['LOCAL_OPENAI_API_KEY']
    openai.api_base = os.environ['LOCAL_OPENAI_API_BASE']
    os.environ['OPENAI_API_KEY'] = os.environ['LOCAL_OPENAI_API_KEY']
    os.environ['OPENAI_API_BASE'] = os.environ['LOCAL_OPENAI_API_BASE']
def set_openai():
    openai.api_key = os.environ['DAVINCI_OPENAI_API_KEY']
    openai.api_base = os.environ['DAVINCI_OPENAI_API_BASE']
    os.environ['OPENAI_API_KEY'] = os.environ['DAVINCI_OPENAI_API_KEY']
    os.environ['OPENAI_API_BASE'] = os.environ['DAVINCI_OPENAI_API_BASE']
from kron.llm_predictor.KronHFHubLLM import KronHuggingFaceHub
def get_hf_predictor(query_model):
    # no embeddings for now
    set_openai_local()
    #llm=HuggingFaceHub(repo_id=query_model, task="text-generation",
    llm = KronHuggingFaceHub(repo_id=query_model, task="text-generation",
                             # model_kwargs={"temperature": 0.01, "max_new_tokens": MAX_NEW_TOKENS, 'frequency_penalty':1.17},
                             model_kwargs={"temperature": 0.01, "max_new_tokens": MAX_NEW_TOKENS},
                             huggingfacehub_api_token=hf_api_key)
    llm_predictor = LLMPredictor(llm)
    return llm_predictor
def get_cohere_predictor(query_model):
    # no embeddings for now
    set_openai_local()
    llm = Cohere(model='command', temperature=0.01,
                 # model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH},
                 cohere_api_key=ch_api_key)
    llm_predictor = LLMPredictor(llm)
    return llm_predictor
#def get_baseten_predictor(query_model):
# # no embeddings for now
# set_openai_local()
# llm=KronBasetenCamelLLM(model='3yd1ke3', temperature = 0.01,
# model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'repetition_penalty':1.07},
# model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'frequency_penalty':1},
# cohere_api_key=ch_api_key)
# llm_predictor = LLMPredictor(llm)
# return llm_predictor
def get_kron_openai_predictor(query_model):
    # define LLM
    llm = KronOpenAI(temperature=0.01, model=query_model)
    llm.max_tokens = MAX_LENGTH
    llm_predictor = KronLLMPredictor(llm)
    return llm_predictor
def get_service_context(llm_predictor):
    # define TextSplitter
    text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=48, paragraph_separator='\n')
    # define NodeParser
    #node_parser = SimpleNodeParser(text_splitter=text_splitter)
    node_parser = text_splitter
    # define ServiceContext
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, node_parser=node_parser)
    return service_context
# hack - on subsequent calls we can pass anything as index
@st.cache_data
def get_networkx_graph_nodes(_index, persist_path):
    g = _index.get_networkx_graph(100000)
    sorted_nodes = sorted(g.degree, key=lambda x: x[1], reverse=True)
    return sorted_nodes
@st.cache_data
def get_networkx_low_connected_components(_index, persist_path):
    g = _index.get_networkx_graph(100000)
    import networkx as nx
    sorted_c = sorted(nx.connected_components(g), key=len, reverse=False)
    #print(sorted_c[:100])
    low_terms = []
    for c in sorted_c:
        for cc in c:
            low_terms.append(cc)
    #print(low_terms)
    return low_terms
def get_index(service_context, persist_path):
    print(f'Loading index from {persist_path}')
    # rebuild storage context
    storage_context = StorageContext.from_defaults(persist_dir=persist_path)
    # load index
    index = load_index_from_storage(storage_context=storage_context,
                                    service_context=service_context,
                                    max_triplets_per_chunk=2,
                                    show_progress=False)
    get_networkx_graph_nodes(index, persist_path)
    get_networkx_low_connected_components(index, persist_path)
    return index
def get_query_engine(index):
    # writer/camel does not understand the refine prompt
    RESPONSE_MODE = 'accumulate'
    query_engine = index.as_query_engine(response_mode=RESPONSE_MODE)
    return query_engine
def load_query_engine(llm_predictor, persist_path):
    service_context = get_service_context(llm_predictor)
    index = get_index(service_context, persist_path)
    print(f'No query engine for {persist_path}; creating')
    query_engine = get_query_engine(index)
    return query_engine
@st.cache_resource
def build_kron_query_engine(query_model, persist_path):
    llm_predictor = get_kron_openai_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine
@st.cache_resource
def build_hf_query_engine(query_model, persist_path):
    llm_predictor = get_hf_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine
@st.cache_resource
def build_cohere_query_engine(query_model, persist_path):
    llm_predictor = get_cohere_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine
#@st.cache_resource
#def build_baseten_query_engine(query_model, persist_path):
# llm_predictor = get_baseten_predictor(query_model)
# query_engine = load_query_engine(llm_predictor, persist_path)
# return query_engine
def format_response(answer):
    # strip any runs of dashes the model may emit
    dashes = r'(\-{2,50})'
    answer.response = re.sub(dashes, '', answer.response)
    return answer.response or "None"
def clear_question(query_model):
    if ('prev_model' not in st.session_state) or (st.session_state.prev_model != query_model):
        if 'prev_model' in st.session_state:
            print(f'clearing question {st.session_state.prev_model} {query_model}')
        else:
            print(f'clearing question None {query_model}')
        if 'question_input' in st.session_state:
            st.session_state.question = st.session_state.question_input
            st.session_state.question_input = ''
        st.session_state.question_answered = False
        st.session_state.answer = ''
        st.session_state.answer_rating = 3
        st.session_state.elapsed = 0
        st.session_state.prev_model = query_model
query, measurable, explainable, ethical = st.tabs(["Query", "Measurable", "Explainable", "Ethical"])
initial_query = ''
if 'question' not in st.session_state:
    st.session_state.question = ''
if __spaces__:
    with query:
        answer_model = st.radio(
            "Choose the model used for inference:",
            ('hf/tiiuae/falcon-7b-instruct', 'cohere/command', 'openai/gpt-3.5-turbo-instruct')  #TODO start hf inference container on demand
        )
else:
    with query:
        answer_model = st.radio(
            "Choose the model used for inference:",
            ('Writer/camel-5b-hf', 'mosaicml/mpt-7b-instruct', 'hf/tiiuae/falcon-7b-instruct', 'cohere/command', 'baseten/Camel-5b', 'openai/gpt-3.5-turbo-instruct')
        )
if answer_model == 'openai/gpt-3.5-turbo-instruct':
    print(answer_model)
    query_model = 'gpt-3.5-turbo-instruct'
    clear_question(query_model)
    set_openai()
    query_engine = build_kron_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'hf/tiiuae/falcon-7b-instruct':
    print(answer_model)
    query_model = 'tiiuae/falcon-7b-instruct'
    clear_question(query_model)
    query_engine = build_hf_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'cohere/command':
    print(answer_model)
    query_model = 'cohere/command'
    clear_question(query_model)
    query_engine = build_cohere_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'baseten/Camel-5b':
    print(answer_model)
    query_model = 'baseten/Camel-5b'
    clear_question(query_model)
    # requires the Baseten support above to be re-enabled (currently commented out)
    query_engine = build_baseten_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'Writer/camel-5b-hf':
    query_model = 'Writer/camel-5b-hf'
    print(answer_model)
    clear_question(query_model)
    set_openai_local()
    query_engine = build_kron_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'mosaicml/mpt-7b-instruct':
    query_model = 'mosaicml/mpt-7b-instruct'
    clear_question(query_model)
    query_engine = build_hf_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
else:
    print('This is a bug.')
# to clear the input box
def submit():
    st.session_state.question = st.session_state.question_input
    st.session_state.question_input = ''
    st.session_state.question_answered = False
with st.sidebar:
    import gensim
    m_connected = []
    for item in most_connected:
        if item[0].lower() not in gensim.parsing.preprocessing.STOPWORDS:
            m_connected.append(item[0].lower())
    option_1 = st.selectbox("What most are studying:", m_connected, disabled=True)
    option_2 = st.selectbox("What few are studying:", least_connected, disabled=True)
with query:
    st.caption('''###### Intended for educational and research purposes. Please do not enter any private or confidential information. Model, question, answer and rating are logged to improve KG Questions.''')
    question = st.text_input("Enter a question, e.g. What benchmarks can we use for QA?", key='question_input', on_change=submit)
if st.session_state.question:
    try:
        with query:
            col1, col2 = st.columns([2, 2])
            if not st.session_state.question_answered:
                with st.spinner(f'Answering: {st.session_state.question} with {query_model}.'):
                    start = time.time()
                    answer = query_engine.query(st.session_state.question)
                    st.session_state.answer = answer
                    st.session_state.question_answered = True
                    end = time.time()
                    st.session_state.elapsed = (end - start)
            else:
                answer = st.session_state.answer
            answer_str = format_response(answer)
            with col1:
                if answer_str:
                    elapsed = '{:.2f}'.format(st.session_state.elapsed)
                    st.write(f'Answered: {st.session_state.question} with {query_model} in {elapsed}s. Please rate this answer.')
            with col2:
                from streamlit_star_rating import st_star_rating
                stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating")
            st.write(answer_str)
        with measurable:
            from measurable import display_wordcloud
            display_wordcloud(answer, answer_str)
        with explainable:
            from explainable import explain
            explain(answer)
    except Exception as e:
        answer_str = f'{type(e)}, {e}'
        st.session_state.answer_rating = -1
        st.write(f'An error occurred, please try again. \n{answer_str}')
    finally:
        if 'question' in st.session_state:
            req = st.session_state.question
            if __spaces__:
                st.session_state.request_log.add_request_log_entry(query_model, req, answer_str, st.session_state.answer_rating)
else:
    with measurable:
        st.write('###### Ask a question to see a comparison between the corpus, answer and reference documents.')
    with explainable:
        st.write('###### Ask a question to see the knowledge graph and a list of reference documents.')
with ethical:
    from ethics import display_ethics
    display_ethics()