"""KG Questions: a Streamlit app answering questions over a persisted llama_index
knowledge-graph index, with word-cloud and graph/source views of each answer."""
import base64
import logging
import os
import re
import sys
import time

import streamlit as st

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

# Load .env before the imports below, some of which may read credentials from the environment.
from dotenv import load_dotenv

load_dotenv()

# Debug trace: dump the session state at the start of every script (re)run.
for key in st.session_state.keys():
    print(f'session state entry: {key} {st.session_state[key]}')

# Deployment flag: when __SPACES__ is set, hosted models are offered and every request is
# written to the request log.
__spaces__ = os.environ.get('__SPACES__')
if __spaces__:
    from kron.persistence.dynamodb_request_log import get_request_log

    st.session_state.request_log = get_request_log()

# third party service access
hf_api_key = os.environ['HF_TOKEN']
ch_api_key = os.environ['COHERE_TOKEN']
bs_api_key = os.environ['BASETEN_TOKEN']

# The persisted index directory is derived from the model that built it.
index_model = "Writer/camel-5b-hf"
INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
persist_path = f"storage/{INDEX_NAME}"
MAX_LENGTH = 1024

import baseten


@st.cache_resource
def set_baseten_key(bs_api_key):
    """Log in to Baseten; cached so it runs once per process and API key."""
    baseten.login(bs_api_key)


set_baseten_key(bs_api_key)


@st.cache_data
def _video_data_uri(video_path):
    """Return the MP4 file at *video_path* as a base64 ``data:`` URI (cached across reruns)."""
    with open(video_path, "rb") as f:
        video_content = f.read()
    return f"data:video/mp4;base64,{base64.b64encode(video_content).decode()}"


def autoplay_video(video_path):
    """Embed the MP4 file at *video_path* as an inline, looping, autoplaying HTML5 video."""
    video_str = _video_data_uri(video_path)
    # The HTML is deliberately unindented: Markdown renders lines indented by 4+ spaces as a
    # code block. `muted` is needed because browsers block autoplay of videos with sound.
    st.markdown(
        '<video autoplay muted loop playsinline width="100%">'
        f'<source src="{video_str}" type="video/mp4">'
        '</video>',
        unsafe_allow_html=True,
    )


# sidebar
with st.sidebar:
    st.header('KG Questions')
    video, text = st.columns([2, 2])
    with video:
        autoplay_video('docs/images/kg_construction.mp4')
    with text:
        st.write(
            '\n'
            '###### The construction of a Knowledge Graph is mesmerizing.\n'
            '###### Concepts in the middle are what most are doing. '
            'Are we considering anything different? Why? Why not?\n'
            '###### Concepts on the edge are what few are doing. '
            'Are we considering that? Why? Why not?\n'
        )
    st.caption('''###### corpus by [@ArxivHealthcareNLP@sigmoid.social](https://sigmoid.social/@ArxivHealthcareNLP)''')
    st.caption('''###### KG Questions by [arylwen](https://github.com/arylwen/mlk8s)''')

# NOTE(review): this heading reads as if placeholder text was lost ("How can <X> help with <Y>?");
# confirm the intended wording and whether it belongs in the sidebar.
st.write('\n#### How can help with ?\n')

from llama_index import LLMPredictor, ServiceContext, StorageContext, load_index_from_storage
from llama_index.langchain_helpers.text_splitter import SentenceSplitter
from llama_index.node_parser import SimpleNodeParser

from langchain import HuggingFaceHub
from langchain.llms.cohere import Cohere
from langchain.llms import Baseten

import tiktoken
import openai

# extensions to llama_index to support openai compatible endpoints, e.g. llama-api
from kron.llm_predictor.KronOpenAILLM import KronOpenAI
# baseten deployment expects a specific request format
from kron.llm_predictor.KronBasetenCamelLLM import KronBasetenCamelLLM
from kron.llm_predictor.KronLLMPredictor import KronLLMPredictor

# writer/camel uses <|endoftext|>; tiktoken refuses to encode special tokens unless they are
# explicitly allowed, so llama_index's global tokenizer is replaced (private attribute).
from llama_index.utils import globals_helper

enc = tiktoken.get_encoding("gpt2")


def tokenizer(text):
    """Tokenize *text* with GPT-2 BPE, allowing the ``<|endoftext|>`` special token."""
    return enc.encode(text, allowed_special={"<|endoftext|>"})


globals_helper._tokenizer = tokenizer


def _configure_openai(prefix):
    """Point the process-global openai client and OPENAI_* env vars at ``{prefix}_OPENAI_API_*``.

    Raises:
        KeyError: if either ``{prefix}_OPENAI_API_KEY`` or ``{prefix}_OPENAI_API_BASE`` is unset.
    """
    api_key = os.environ[f'{prefix}_OPENAI_API_KEY']
    api_base = os.environ[f'{prefix}_OPENAI_API_BASE']
    openai.api_key = api_key
    openai.api_base = api_base
    os.environ['OPENAI_API_KEY'] = api_key
    os.environ['OPENAI_API_BASE'] = api_base


def set_openai_local():
    """Use the LOCAL_OPENAI_API_* OpenAI-compatible endpoint."""
    _configure_openai('LOCAL')


def set_openai():
    """Use the DAVINCI_OPENAI_API_* endpoint."""
    _configure_openai('DAVINCI')


def get_hf_predictor(query_model):
    """Return an LLMPredictor backed by the Hugging Face Hub model *query_model*."""
    # no embeddings for now
    set_openai_local()
    llm = HuggingFaceHub(
        repo_id=query_model,
        task="text-generation",
        model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH},
        huggingfacehub_api_token=hf_api_key,
    )
    return LLMPredictor(llm)


def get_cohere_predictor(query_model):
    """Return an LLMPredictor backed by Cohere 'command' (*query_model* is not used)."""
    # no embeddings for now
    set_openai_local()
    llm = Cohere(model='command', temperature=0.01, cohere_api_key=ch_api_key)
    return LLMPredictor(llm)


def get_baseten_predictor(query_model):
    """Return an LLMPredictor backed by the hard-coded Baseten Camel deployment (*query_model* unused)."""
    # no embeddings for now
    set_openai_local()
    # NOTE(review): cohere_api_key looks like a copy/paste leftover from get_cohere_predictor; it is
    # kept because KronBasetenCamelLLM's signature is not visible here - confirm and drop.
    llm = KronBasetenCamelLLM(
        model='3yd1ke3',
        temperature=0.01,
        model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'frequency_penalty': 1},
        cohere_api_key=ch_api_key,
    )
    return LLMPredictor(llm)


def get_kron_openai_predictor(query_model):
    """Return a KronLLMPredictor for *query_model* on the currently configured OpenAI endpoint."""
    llm = KronOpenAI(temperature=0.01, model=query_model)
    llm.max_tokens = MAX_LENGTH
    return KronLLMPredictor(llm)


def get_servce_context(llm_predictor):
    """Build a ServiceContext using *llm_predictor* and a 192/48 sentence-splitting node parser."""
    text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=48, paragraph_separator='\n')
    node_parser = SimpleNodeParser(text_splitter=text_splitter)
    return ServiceContext.from_defaults(llm_predictor=llm_predictor, node_parser=node_parser)


def get_index(service_context, persist_path):
    """Load the persisted index stored under *persist_path*."""
    print(f'Loading index from {persist_path}')
    storage_context = StorageContext.from_defaults(persist_dir=persist_path)
    # Extra keyword arguments are presumably forwarded to the (knowledge-graph) index constructor.
    return load_index_from_storage(storage_context=storage_context,
                                   service_context=service_context,
                                   max_triplets_per_chunk=2,
                                   show_progress=False)


def get_query_engine(index):
    """Return a query engine over *index* using the 'accumulate' response mode."""
    # writer/camel does not understand the refine prompt
    return index.as_query_engine(response_mode='accumulate')


def load_query_engine(llm_predictor, persist_path):
    """Load the index at *persist_path* and wrap it in a query engine driven by *llm_predictor*."""
    service_context = get_servce_context(llm_predictor)
    index = get_index(service_context, persist_path)
    print(f'No query engine for {persist_path}; creating')
    return get_query_engine(index)


# Each builder is cached by Streamlit per (query_model, persist_path), so an engine is built once.
@st.cache_resource
def build_kron_query_engine(query_model, persist_path):
    """Build (and cache) a query engine using an OpenAI-compatible model."""
    return load_query_engine(get_kron_openai_predictor(query_model), persist_path)


@st.cache_resource
def build_hf_query_engine(query_model, persist_path):
    """Build (and cache) a query engine using a Hugging Face Hub model."""
    return load_query_engine(get_hf_predictor(query_model), persist_path)


@st.cache_resource
def build_cohere_query_engine(query_model, persist_path):
    """Build (and cache) a query engine using Cohere."""
    return load_query_engine(get_cohere_predictor(query_model), persist_path)


@st.cache_resource
def build_baseten_query_engine(query_model, persist_path):
    """Build (and cache) a query engine using the Baseten deployment."""
    return load_query_engine(get_baseten_predictor(query_model), persist_path)


_DASH_RUNS = re.compile(r'(\-{2,50})')


def format_response(answer):
    """Strip runs of 2-50 dashes from ``answer.response`` (in place) and return the text.

    Returns the string "None" when the response is empty or None.
    """
    # Guard: re.sub raises TypeError on a None response.
    if answer.response:
        answer.response = _DASH_RUNS.sub('', answer.response)
    return answer.response or "None"


def clear_question(query_model):
    """Reset the question/answer state when the selected model differs from the previous run's."""
    if 'prev_model' in st.session_state and st.session_state.prev_model == query_model:
        return
    print(f"clearing question {st.session_state.get('prev_model')} {query_model}")
    if 'question_input' in st.session_state:
        st.session_state.question = st.session_state.question_input
        st.session_state.question_input = ''
    st.session_state.question_answered = False
    st.session_state.answer = ''
    st.session_state.answer_rating = 3
    st.session_state.elapsed = 0
    st.session_state.prev_model = query_model


query, measurable, explainable, ethical = st.tabs(["Query", "Measurable", "Explainable", "Ethical"])

if 'question' not in st.session_state:
    st.session_state.question = ''

if __spaces__:
    # TODO start hf inference container on demand
    model_options = ('baseten/Camel-5b', 'cohere/command', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
else:
    model_options = ('Local-Camel', 'HF-TKI', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
with query:
    answer_model = st.radio("Choose the model used for inference:", model_options)

# set_openai_local() is re-applied explicitly for the non-davinci models. The cached builders only
# call it on a cache miss, and the openai config is process-global, so a previous davinci
# selection would otherwise stay in effect.
if answer_model == 'openai/text-davinci-003':
    print(answer_model)
    query_model = 'text-davinci-003'
    clear_question(query_model)
    set_openai()
    query_engine = build_kron_query_engine(query_model, persist_path)
elif answer_model == 'hf/tiiuae/falcon-7b-instruct':
    print(answer_model)
    query_model = 'tiiuae/falcon-7b-instruct'
    clear_question(query_model)
    set_openai_local()
    query_engine = build_hf_query_engine(query_model, persist_path)
elif answer_model == 'cohere/command':
    print(answer_model)
    query_model = 'cohere/command'
    clear_question(query_model)
    set_openai_local()
    query_engine = build_cohere_query_engine(query_model, persist_path)
elif answer_model == 'baseten/Camel-5b':
    print(answer_model)
    query_model = 'baseten/Camel-5b'
    clear_question(query_model)
    set_openai_local()
    query_engine = build_baseten_query_engine(query_model, persist_path)
elif answer_model == 'Local-Camel':
    query_model = 'Writer/camel-5b-hf'
    print(answer_model)
    clear_question(query_model)
    set_openai_local()
    query_engine = build_kron_query_engine(query_model, persist_path)
elif answer_model == 'HF-TKI':
    query_model = 'allenai/tk-instruct-3b-def-pos-neg-expl'
    clear_question(query_model)
    set_openai_local()
    query_engine = build_hf_query_engine(query_model, persist_path)
else:
    print('This is a bug.')
    # Without a model there is no query_model/query_engine; stop instead of failing later.
    st.stop()


def submit():
    """Move the typed question into session state and clear the input box."""
    st.session_state.question = st.session_state.question_input
    st.session_state.question_input = ''
    st.session_state.question_answered = False


with query:
    st.caption('''###### Model, question, answer and rating are logged to improve KG Questions.''')
    st.text_input("Enter a question, e.g. What benchmarks can we use for QA?",
                  key='question_input', on_change=submit)

if st.session_state.question:
    # None until an answer (or error text) exists; stays None if the run is interrupted first.
    answer_str = None
    answer_failed = False
    try:
        with query:
            col1, col2 = st.columns([2, 2])
            if not st.session_state.question_answered:
                with st.spinner(f'Answering: {st.session_state.question} with {query_model}.'):
                    start = time.perf_counter()
                    answer = query_engine.query(st.session_state.question)
                    st.session_state.answer = answer
                    st.session_state.question_answered = True
                    st.session_state.elapsed = time.perf_counter() - start
            else:
                answer = st.session_state.answer
            answer_str = format_response(answer)
            with col1:
                if answer_str:
                    elapsed = f'{st.session_state.elapsed:.2f}'
                    st.write(f'Answered: {st.session_state.question} with {query_model} in {elapsed}s. Please rate this answer.')
            with col2:
                from streamlit_star_rating import st_star_rating

                st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating")
            st.write(answer_str)

        with measurable:
            from wordcloud import WordCloud
            from PIL import Image

            wc_all, wq_question = st.columns([2, 2])
            wordcloud = WordCloud(max_font_size=50, max_words=1000, background_color="white").generate(answer_str)
            with wc_all:
                with Image.open('docs/images/all_papers_wordcloud.png') as image:
                    st.image(image)
                st.caption('''###### Corpus word frequency.''')
            with wq_question:
                st.image(wordcloud.to_array())
                st.caption('''###### Query word frequency.''')

        with explainable:
            import pandas as pd
            import streamlit.components.v1 as components
            from llama_index.schema import NodeRelationship
            from pyvis.network import Network

            graph = Network(height="450px", width="100%")
            sources_table = []
            for node_with_score in answer.source_nodes:
                node = node_with_score.node
                if NodeRelationship.SOURCE in node.relationships:
                    # Text chunk: derive a readable paper title from the source document id.
                    node_id = node.relationships[NodeRelationship.SOURCE].node_id.split('/')[-1]
                    # NOTE(review): assumes the file name has at least three '.'-separated parts
                    # with the title third; otherwise this raises IndexError.
                    title = node_id.split('.')[2].replace('_', ' ')
                    sources_table.append([title, node.text])
                else:
                    # KG node: draw concept -[rel[0]]-> rel[1] edges from its relation map.
                    rel_map = node.metadata['kg_rel_map']
                    for concept, rels in rel_map.items():
                        graph.add_node(concept, concept, title=concept)
                        for rel in rels:
                            graph.add_node(rel[1], rel[1], title=rel[1])
                            graph.add_edge(concept, rel[1], title=rel[0])
            st.session_state.graph_name = 'graph.html'
            graph.save_graph(st.session_state.graph_name)
            with open(st.session_state.graph_name, 'r', encoding='utf-8') as graph_html:
                source_code = graph_html.read()
            components.html(source_code, height=500)

            # Passing columns= keeps an empty table valid (assigning df.columns on an empty
            # DataFrame raises a length-mismatch ValueError).
            df = pd.DataFrame(sources_table, columns=['paper', 'relevant text'])
            # NOTE(review): this markdown payload is empty; it looks like stripped <style> markup.
            st.markdown(""" """, unsafe_allow_html=True)
            st.table(df)
    except Exception as e:
        logger.exception('failed to answer question')
        answer_str = f'{type(e)}, {e}'
        # Do not write st.session_state.answer_rating here: once the rating widget has been
        # created in this run, Streamlit forbids modifying its key. -1 is only used for the log.
        answer_failed = True
        st.write(f'An error occurred, please try again. \n{answer_str}')
    finally:
        # Log model, question, answer and rating (-1 marks a failure). Skipped when no answer
        # text exists, e.g. the run was interrupted by a Streamlit rerun before answering.
        if __spaces__ and answer_str is not None:
            rating = -1 if answer_failed else st.session_state.get('answer_rating')
            st.session_state.request_log.add_request_log_entry(
                query_model, st.session_state.question, answer_str, rating)