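# KG Questions: Streamlit app that answers questions over a persisted knowledge-graph
# index, with a choice of hosted LLM backends (OpenAI-compatible, Hugging Face Hub,
# Cohere, Baseten).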
import streamlit as st

import os
import re
import sys
import time
import base64
import random
import logging

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

from dotenv import load_dotenv

load_dotenv()

# Debug: dump the current Streamlit session state on each rerun.
for key in st.session_state.keys():
    print(f'session state entry: {key} {st.session_state[key]}')

__spaces__ = os.environ.get('__SPACES__')

# When the __SPACES__ environment variable is set (hosted deployment), log requests to DynamoDB.
if __spaces__:
    from kron.persistence.dynamodb_request_log import get_request_log
    st.session_state.request_log = get_request_log()

# API keys for the hosted inference backends.
hf_api_key = os.environ['HF_TOKEN']
ch_api_key = os.environ['COHERE_TOKEN']
bs_api_key = os.environ['BASETEN_TOKEN']

# The persisted index storage folder is derived from the model the index was built with.
index_model = "Arylwen/instruct-palmyra-20b-gptq-8"
INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
persist_path = f"storage/{INDEX_NAME}"
MAX_LENGTH = 1024
MAX_NEW_TOKENS = 250

def autoplay_video(video_path):
    # Embed a looping, autoplaying video from a local file as a base64 data URL.
    with open(video_path, "rb") as f:
        video_content = f.read()

    video_str = f"data:video/mp4;base64,{base64.b64encode(video_content).decode()}"
    st.markdown(f"""
<video style="display: block; margin: auto; width: 140px;" controls loop autoplay width="140" height="180">
<source src="{video_str}" type="video/mp4">
</video>
""", unsafe_allow_html=True)

with st.sidebar:
    st.header('KG Questions')
    video, text = st.columns([2, 2])
    with video:
        autoplay_video('docs/images/kg_construction.mp4')
    with text:
        st.write(
            '''
###### The construction of a Knowledge Graph is mesmerizing.
###### Concepts in the middle are what most are doing. Are we considering anything different? Why? Why not?
###### Concepts on the edge are what few are doing. Are we considering that? Why? Why not?
'''
        )
    st.caption('''###### corpus by [@[email protected]](https://sigmoid.social/@ArxivHealthcareNLP)''')
    st.caption('''###### KG Questions by [arylwen](https://github.com/arylwen/mlk8s)''')

from llama_index.core import StorageContext, ServiceContext, load_index_from_storage
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.service_context_elements.llm_predictor import LLMPredictor

from langchain import HuggingFaceHub
from langchain.llms.cohere import Cohere

import tiktoken
import openai

from kron.llm_predictor.KronOpenAILLM import KronOpenAI
from kron.llm_predictor.KronLLMPredictor import KronLLMPredictor

# llama_index token counting uses the gpt2 tokenizer (with <|endoftext|> allowed).
from llama_index.core.utils import globals_helper
enc = tiktoken.get_encoding("gpt2")
tokenizer = lambda text: enc.encode(text, allowed_special={"<|endoftext|>"})
globals_helper._tokenizer = tokenizer

def set_openai_local():
    # Point the OpenAI client at the endpoint configured via the LOCAL_* variables.
    openai.api_key = os.environ['LOCAL_OPENAI_API_KEY']
    openai.api_base = os.environ['LOCAL_OPENAI_API_BASE']
    os.environ['OPENAI_API_KEY'] = os.environ['LOCAL_OPENAI_API_KEY']
    os.environ['OPENAI_API_BASE'] = os.environ['LOCAL_OPENAI_API_BASE']

def set_openai():
    # Point the OpenAI client at the endpoint configured via the DAVINCI_* variables.
    openai.api_key = os.environ['DAVINCI_OPENAI_API_KEY']
    openai.api_base = os.environ['DAVINCI_OPENAI_API_BASE']
    os.environ['OPENAI_API_KEY'] = os.environ['DAVINCI_OPENAI_API_KEY']
    os.environ['OPENAI_API_BASE'] = os.environ['DAVINCI_OPENAI_API_BASE']

from kron.llm_predictor.KronHFHubLLM import KronHuggingFaceHub

# Predictor backed by the Hugging Face Hub inference API.
def get_hf_predictor(query_model):
    set_openai_local()
    llm = KronHuggingFaceHub(repo_id=query_model, task="text-generation",
                             model_kwargs={"temperature": 0.01, "max_new_tokens": MAX_NEW_TOKENS},
                             huggingfacehub_api_token=hf_api_key)
    llm_predictor = LLMPredictor(llm)
    return llm_predictor

# Predictor backed by the Cohere API (always uses the 'command' model).
def get_cohere_predictor(query_model):
    set_openai_local()
    llm = Cohere(model='command', temperature=0.01, cohere_api_key=ch_api_key)
    llm_predictor = LLMPredictor(llm)
    return llm_predictor

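# NOTE: the 'baseten/Camel-5b' option below calls build_baseten_query_engine, but no
# Baseten predictor is defined in this file. The sketch below is a best-guess wiring,
# assuming langchain's Baseten LLM wrapper and the `baseten` client package; how
# BASETEN_TOKEN and the deployed model id are passed is an assumption and may need
# adjusting for your deployment.
def get_baseten_predictor(query_model):
    set_openai_local()
    import baseten  # assumption: the baseten client package is installed
    from langchain.llms import Baseten  # assumption: langchain's Baseten wrapper is available
    baseten.login(bs_api_key)  # assumption: authenticate with the BASETEN_TOKEN key
    llm = Baseten(model=query_model)  # assumption: query_model maps to a deployed Baseten model id
    llm_predictor = LLMPredictor(llm)
    return llm_predictor
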
def get_kron_openai_predictor(query_model):
    # Predictor backed by an OpenAI-compatible endpoint (local or hosted).
    llm = KronOpenAI(temperature=0.01, model=query_model)
    llm.max_tokens = MAX_LENGTH
    llm_predictor = KronLLMPredictor(llm)
    return llm_predictor

def get_service_context(llm_predictor):
    # Split documents into small, overlapping sentence chunks.
    text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=48, paragraph_separator='\n')
    node_parser = text_splitter
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, node_parser=node_parser)
    return service_context

@st.cache_data
def get_networkx_graph_nodes(_index, persist_path):
    # _index is excluded from the cache key (leading underscore); persist_path identifies the index.
    g = _index.get_networkx_graph(100000)
    sorted_nodes = sorted(g.degree, key=lambda x: x[1], reverse=True)
    return sorted_nodes

@st.cache_data
def get_networkx_low_connected_components(_index, persist_path):
    import networkx as nx
    g = _index.get_networkx_graph(100000)
    sorted_c = sorted(nx.connected_components(g), key=len, reverse=False)

    low_terms = []
    for c in sorted_c:
        for cc in c:
            low_terms.append(cc)

    return low_terms

def get_index(service_context, persist_path):
    print(f'Loading index from {persist_path}')
    storage_context = StorageContext.from_defaults(persist_dir=persist_path)
    index = load_index_from_storage(storage_context=storage_context,
                                    service_context=service_context,
                                    max_triplets_per_chunk=2,
                                    show_progress=False)
    # Warm the cached graph views so later calls can pass a placeholder index.
    get_networkx_graph_nodes(index, persist_path)
    get_networkx_low_connected_components(index, persist_path)
    return index

def get_query_engine(index):
    RESPONSE_MODE = 'accumulate'
    query_engine = index.as_query_engine(response_mode=RESPONSE_MODE)
    return query_engine

def load_query_engine(llm_predictor, persist_path):
    service_context = get_service_context(llm_predictor)
    index = get_index(service_context, persist_path)
    print(f'No query engine for {persist_path}; creating')
    query_engine = get_query_engine(index)
    return query_engine

@st.cache_resource
def build_kron_query_engine(query_model, persist_path):
    llm_predictor = get_kron_openai_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine

@st.cache_resource
def build_hf_query_engine(query_model, persist_path):
    llm_predictor = get_hf_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine

@st.cache_resource
def build_cohere_query_engine(query_model, persist_path):
    llm_predictor = get_cohere_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine

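# Sketch of the missing engine builder for the Baseten option; mirrors the other
# @st.cache_resource builders and relies on the hypothetical get_baseten_predictor above.
@st.cache_resource
def build_baseten_query_engine(query_model, persist_path):
    llm_predictor = get_baseten_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine
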
def format_response(answer):
    # Strip runs of dashes from the response text.
    dashes = r'(\-{2,50})'
    answer.response = re.sub(dashes, '', answer.response)
    return answer.response or "None"

def clear_question(query_model):
    # Reset the question/answer state whenever the selected model changes.
    if 'prev_model' not in st.session_state or st.session_state.prev_model != query_model:
        if 'prev_model' in st.session_state:
            print(f'clearing question {st.session_state.prev_model} {query_model}')
        else:
            print(f'clearing question None {query_model}')
        if 'question_input' in st.session_state:
            st.session_state.question = st.session_state.question_input
            st.session_state.question_input = ''
        st.session_state.question_answered = False
        st.session_state.answer = ''
        st.session_state.answer_rating = 3
        st.session_state.elapsed = 0
        st.session_state.prev_model = query_model

query, measurable, explainable, ethical = st.tabs(["Query", "Measurable", "Explainable", "Ethical"])

initial_query = ''

if 'question' not in st.session_state:
    st.session_state.question = ''

if __spaces__:
    with query:
        answer_model = st.radio(
            "Choose the model used for inference:",
            ('hf/tiiuae/falcon-7b-instruct', 'cohere/command', 'openai/gpt-3.5-turbo-instruct')
        )
else:
    with query:
        answer_model = st.radio(
            "Choose the model used for inference:",
            ('Writer/camel-5b-hf', 'mosaicml/mpt-7b-instruct', 'hf/tiiuae/falcon-7b-instruct', 'cohere/command', 'baseten/Camel-5b', 'openai/gpt-3.5-turbo-instruct')
        )

if answer_model == 'openai/gpt-3.5-turbo-instruct':
    print(answer_model)
    query_model = 'gpt-3.5-turbo-instruct'
    clear_question(query_model)
    set_openai()
    query_engine = build_kron_query_engine(query_model, persist_path)
    # The graph views are cached by persist_path (the index argument is ignored for hashing),
    # so a placeholder index can be passed once the engine has been built.
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'hf/tiiuae/falcon-7b-instruct':
    print(answer_model)
    query_model = 'tiiuae/falcon-7b-instruct'
    clear_question(query_model)
    query_engine = build_hf_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'cohere/command':
    print(answer_model)
    query_model = 'cohere/command'
    clear_question(query_model)
    query_engine = build_cohere_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'baseten/Camel-5b':
    print(answer_model)
    query_model = 'baseten/Camel-5b'
    clear_question(query_model)
    query_engine = build_baseten_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'Writer/camel-5b-hf':
    query_model = 'Writer/camel-5b-hf'
    print(answer_model)
    clear_question(query_model)
    set_openai_local()
    query_engine = build_kron_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'mosaicml/mpt-7b-instruct':
    query_model = 'mosaicml/mpt-7b-instruct'
    clear_question(query_model)
    query_engine = build_hf_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
else:
    print(f'Unhandled model option: {answer_model}. This is a bug.')

def submit():
    # Move the typed question into session state and clear the input box.
    st.session_state.question = st.session_state.question_input
    st.session_state.question_input = ''
    st.session_state.question_answered = False

with st.sidebar:
    import gensim
    # Drop stopwords from the most-connected graph terms before displaying them.
    m_connected = []
    for item in most_connected:
        if not item[0].lower() in gensim.parsing.preprocessing.STOPWORDS:
            m_connected.append(item[0].lower())
    option_1 = st.selectbox("What most are studying:", m_connected, disabled=True)
    option_2 = st.selectbox("What few are studying:", least_connected, disabled=True)

with query:
    st.caption('''###### Intended for educational and research purposes. Please do not enter any private or confidential information. Model, question, answer and rating are logged to improve KG Questions.''')
    question = st.text_input("Enter a question, e.g. What benchmarks can we use for QA?", key='question_input', on_change=submit)

if st.session_state.question:
    try:
        with query:
            col1, col2 = st.columns([2, 2])
            if not st.session_state.question_answered:
                with st.spinner(f'Answering: {st.session_state.question} with {query_model}.'):
                    start = time.time()
                    answer = query_engine.query(st.session_state.question)
                    st.session_state.answer = answer
                    st.session_state.question_answered = True
                    end = time.time()
                    st.session_state.elapsed = (end - start)
            else:
                answer = st.session_state.answer
            answer_str = format_response(answer)
            with col1:
                if answer_str:
                    elapsed = '{:.2f}'.format(st.session_state.elapsed)
                    st.write(f'Answered: {st.session_state.question} with {query_model} in {elapsed}s. Please rate this answer.')
            with col2:
                from streamlit_star_rating import st_star_rating
                stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating")
            st.write(answer_str)
        with measurable:
            from measurable import display_wordcloud
            display_wordcloud(answer, answer_str)
        with explainable:
            from explainable import explain
            explain(answer)
    except Exception as e:
        answer_str = f'{type(e)}, {e}'
        st.session_state.answer_rating = -1
        st.write(f'An error occurred, please try again. \n{answer_str}')
    finally:
        if 'question' in st.session_state:
            req = st.session_state.question
            if __spaces__:
                st.session_state.request_log.add_request_log_entry(query_model, req, answer_str, st.session_state.answer_rating)
else:
    with measurable:
        st.write('###### Ask a question to see a comparison between the corpus, answer and reference documents.')
    with explainable:
        st.write('###### Ask a question to see the knowledge graph and a list of reference documents.')

with ethical:
    from ethics import display_ethics
    display_ethics()