import streamlit as st |
import os |
import re |
import sys |
import logging |
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

from dotenv import load_dotenv

# Load variables from a .env file (if present) into os.environ.
load_dotenv()
# Force the AWS region; this runs after load_dotenv, so it overrides any
# value coming from the environment or the .env file.
os.environ['AWS_DEFAULT_REGION'] = 'us-west-2'

# Debug dump of every session-state entry at the start of each script run.
for entry_key, entry_value in st.session_state.items():
    print(f'session state entry: {entry_key} {entry_value}')

# Set only in the hosted deployment; gates model choices and request logging.
__spaces__ = os.environ.get('__SPACES__')
if __spaces__:
    from kron.persistence.dynamodb_request_log import get_request_log
    st.session_state.request_log = get_request_log()

# Provider tokens; a missing variable raises KeyError at startup.
hf_api_key, ch_api_key, bs_api_key = [
    os.environ[var_name] for var_name in ('HF_TOKEN', 'COHERE_TOKEN', 'BASETEN_TOKEN')
]

# The persisted index directory is derived from the model it was built with.
index_model = "Writer/camel-5b-hf"
INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
persist_path = f"storage/{INDEX_NAME}"
# Generation length limit passed to the predictors below.
MAX_LENGTH = 1024
import baseten


@st.cache_resource
def set_baseten_key(bs_api_key):
    """Log the baseten client in with *bs_api_key*.

    Memoized by st.cache_resource on the argument value, so the login is
    performed once per server process instead of on every Streamlit rerun.
    """
    baseten.login(bs_api_key)
    return None


set_baseten_key(bs_api_key)
from llama_index import StorageContext |
from llama_index import ServiceContext |
from llama_index import load_index_from_storage |
from llama_index.langchain_helpers.text_splitter import SentenceSplitter |
from llama_index.node_parser import SimpleNodeParser |
from llama_index import LLMPredictor |
from langchain import HuggingFaceHub |
from langchain.llms.cohere import Cohere |
from langchain.llms import Baseten |
import tiktoken |
import openai |
from kron.llm_predictor.KronOpenAILLM import KronOpenAI |
from kron.llm_predictor.KronBasetenCamelLLM import KronBasetenCamelLLM |
from kron.llm_predictor.KronLLMPredictor import KronLLMPredictor |
from llama_index.utils import globals_helper |
# Replace llama_index's global tokenizer (the private attribute
# globals_helper._tokenizer) with tiktoken's "gpt2" encoding.
enc = tiktoken.get_encoding("gpt2")


def tokenizer(text):
    """Return the gpt2 token ids for *text*.

    "<|endoftext|>" is passed in allowed_special so text containing that
    literal is encoded as the special token instead of making tiktoken raise
    (tiktoken disallows special tokens in input by default).
    """
    return enc.encode(text, allowed_special={"<|endoftext|>"})


globals_helper._tokenizer = tokenizer
def _use_openai_endpoint(prefix):
    """Point the openai client and the OPENAI_* env vars at one credential pair.

    Reads ``{prefix}_OPENAI_API_KEY`` and ``{prefix}_OPENAI_API_BASE`` and
    copies them to ``openai.api_key`` / ``openai.api_base`` and to the
    ``OPENAI_API_KEY`` / ``OPENAI_API_BASE`` environment variables
    (presumably for libraries that read the environment directly — confirm).

    Both variables are read before anything is modified, so a missing one
    raises KeyError without leaving the client half-configured.

    NOTE: this mutates process-wide state shared by every Streamlit session.
    """
    api_key = os.environ[f'{prefix}_OPENAI_API_KEY']
    api_base = os.environ[f'{prefix}_OPENAI_API_BASE']
    openai.api_key = api_key
    openai.api_base = api_base
    os.environ['OPENAI_API_KEY'] = api_key
    os.environ['OPENAI_API_BASE'] = api_base


def set_openai_local():
    """Configure OpenAI access from LOCAL_OPENAI_API_KEY / LOCAL_OPENAI_API_BASE."""
    _use_openai_endpoint('LOCAL')


def set_openai():
    """Configure OpenAI access from DAVINCI_OPENAI_API_KEY / DAVINCI_OPENAI_API_BASE."""
    _use_openai_endpoint('DAVINCI')
def get_hf_predictor(query_model):
    """Build an LLMPredictor backed by a Hugging Face Hub text-generation model.

    Also switches the global OpenAI configuration to the LOCAL endpoint. When
    called through the cached build_hf_query_engine, that only happens on a
    cache miss.

    Args:
        query_model: Hugging Face Hub repo id, e.g. 'tiiuae/falcon-7b-instruct'.

    Returns:
        An llama_index LLMPredictor wrapping the LangChain HuggingFaceHub LLM.
    """
    set_openai_local()
    generation_kwargs = {"temperature": 0.01, "max_length": MAX_LENGTH}
    hub_llm = HuggingFaceHub(
        repo_id=query_model,
        task="text-generation",
        model_kwargs=generation_kwargs,
        huggingfacehub_api_token=hf_api_key,
    )
    return LLMPredictor(hub_llm)
def get_cohere_predictor(query_model):
    """Build an LLMPredictor backed by Cohere's 'command' model.

    Also switches the global OpenAI configuration to the LOCAL endpoint.

    Args:
        query_model: currently unused; the Cohere model name is hard-coded to
            'command'. NOTE(review): the caller passes 'cohere/command' —
            consider deriving the model name from it.

    Returns:
        An llama_index LLMPredictor wrapping the LangChain Cohere LLM.
    """
    set_openai_local()
    cohere_llm = Cohere(model='command', temperature=0.01, cohere_api_key=ch_api_key)
    return LLMPredictor(cohere_llm)
def get_baseten_predictor(query_model):
    """Build an LLMPredictor backed by the Baseten-hosted Camel model '3yd1ke3'.

    Also switches the global OpenAI configuration to the LOCAL endpoint.

    Args:
        query_model: currently unused; the Baseten model id is hard-coded.

    Returns:
        An llama_index LLMPredictor wrapping KronBasetenCamelLLM.
    """
    set_openai_local()
    camel_kwargs = {"temperature": 0.01, "max_length": MAX_LENGTH, 'frequency_penalty': 1}
    camel_llm = KronBasetenCamelLLM(
        model='3yd1ke3',
        temperature=0.01,
        model_kwargs=camel_kwargs,
        # NOTE(review): a Cohere key handed to a Baseten LLM looks like a
        # copy-paste from get_cohere_predictor; kept as-is because
        # KronBasetenCamelLLM's constructor may require it — confirm.
        cohere_api_key=ch_api_key,
    )
    return LLMPredictor(camel_llm)
def get_kron_openai_predictor(query_model):
    """Build a KronLLMPredictor around a KronOpenAI completion model.

    Unlike the other get_*_predictor helpers, this does not touch the global
    OpenAI configuration; callers must set it (set_openai/set_openai_local).

    Args:
        query_model: OpenAI-style model name, e.g. 'text-davinci-003'.

    Returns:
        A KronLLMPredictor wrapping the configured KronOpenAI LLM.
    """
    kron_llm = KronOpenAI(temperature=0.01, model=query_model)
    # max_tokens is assigned after construction rather than via the constructor.
    kron_llm.max_tokens = MAX_LENGTH
    return KronLLMPredictor(kron_llm)
def get_service_context(llm_predictor):
    """Build a ServiceContext that uses *llm_predictor* and a sentence-based node parser.

    The node parser uses SentenceSplitter with chunk_size=192,
    chunk_overlap=48 and a single newline as the paragraph separator.

    Args:
        llm_predictor: predictor returned by one of the get_*_predictor helpers.

    Returns:
        A llama_index ServiceContext built with ServiceContext.from_defaults.
    """
    text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=48, paragraph_separator='\n')
    node_parser = SimpleNodeParser(text_splitter=text_splitter)
    return ServiceContext.from_defaults(llm_predictor=llm_predictor, node_parser=node_parser)


# Backward-compatible alias for the original (misspelled) public name.
get_servce_context = get_service_context
def get_index(service_context, persist_path):
    """Load the persisted llama_index index stored under *persist_path*.

    Args:
        service_context: ServiceContext the loaded index should use.
        persist_path: directory previously written by the index's persist step.

    Returns:
        The index object returned by load_index_from_storage.
    """
    print(f'Loading index from {persist_path}')
    # Extra keyword arguments are forwarded to the index constructor;
    # max_triplets_per_chunk presumably targets a knowledge-graph index — confirm.
    return load_index_from_storage(
        storage_context=StorageContext.from_defaults(persist_dir=persist_path),
        service_context=service_context,
        max_triplets_per_chunk=2,
        show_progress=False,
    )
def get_query_engine(index):
    """Return a query engine over *index* using the 'accumulate' response mode."""
    return index.as_query_engine(response_mode='accumulate')
def load_query_engine(llm_predictor, persist_path):
    """Load the index at *persist_path* and wrap it in a query engine.

    Args:
        llm_predictor: predictor used to answer queries.
        persist_path: directory of the persisted index.

    Returns:
        The query engine produced by get_query_engine.
    """
    index = get_index(get_servce_context(llm_predictor), persist_path)
    print(f'No query engine for {persist_path}; creating')
    return get_query_engine(index)
# One cached builder per provider. st.cache_resource memoizes each function
# separately on (query_model, persist_path), so the index is loaded only on a
# cache miss and the engine object is shared across reruns and sessions.
# Side effects inside the get_*_predictor helpers therefore also run only on
# a cache miss.


@st.cache_resource
def build_kron_query_engine(query_model, persist_path):
    """Cached query engine answered by a KronOpenAI model."""
    return load_query_engine(get_kron_openai_predictor(query_model), persist_path)


@st.cache_resource
def build_hf_query_engine(query_model, persist_path):
    """Cached query engine answered by a Hugging Face Hub model."""
    return load_query_engine(get_hf_predictor(query_model), persist_path)


@st.cache_resource
def build_cohere_query_engine(query_model, persist_path):
    """Cached query engine answered by Cohere."""
    return load_query_engine(get_cohere_predictor(query_model), persist_path)


@st.cache_resource
def build_baseten_query_engine(query_model, persist_path):
    """Cached query engine answered by the Baseten-hosted Camel model."""
    return load_query_engine(get_baseten_predictor(query_model), persist_path)
def format_response(answer):
    """Strip dash-separator runs from a query response and return its text.

    Runs of 2-50 consecutive '-' characters (presumably the separators the
    'accumulate' response mode puts between partial answers — confirm) are
    removed. The cleaned text is written back to ``answer.response``, so
    re-formatting a cached answer is idempotent.

    Args:
        answer: object with a ``response`` attribute holding a str or None,
            e.g. the result of ``query_engine.query``.

    Returns:
        The cleaned response text, or the string "None" when the response is
        None or empty after cleaning.
    """
    if answer.response is None:
        # re.sub would raise TypeError on None; report it like an empty answer.
        return "None"
    dashes = r'(\-{2,50})'
    answer.response = re.sub(dashes, '', answer.response)
    return answer.response or "None"
def clear_question(query_model):
    """Reset the question/answer session state when the selected model changes.

    Called on every script run with the currently selected model id. If it
    differs from the model seen on the previous run (or there was no previous
    run), the following happens:
      - the question input box is emptied, and its contents are promoted to
        the active question (normally '' because submit() already moved it);
      - the stored answer is discarded, so the question is re-asked on the
        next submit.

    Must run before the 'question_input' text_input widget is created:
    Streamlit only allows writing a widget's session-state value before the
    widget is instantiated in the current run.

    Args:
        query_model: id of the model selected for this run (a str).
    """
    prev_model = st.session_state.get('prev_model')
    # query_model is always a str, so a missing prev_model (None) also counts
    # as a change.
    if prev_model != query_model:
        print(f'clearing question {prev_model} {query_model}')
        if 'question_input' in st.session_state:
            st.session_state.question = st.session_state.question_input
            st.session_state.question_input = ''
        st.session_state.question_answered = False
        st.session_state.answer = ''
    st.session_state.prev_model = query_model
initial_query = ''
if 'question' not in st.session_state:
    st.session_state.question = ''

# The hosted (Spaces) deployment offers remote-API models; a local run offers
# locally served ones instead.
if __spaces__:
    _model_options = ('baseten/Camel-5b', 'cohere/command', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
else:
    _model_options = ('Local-Camel', 'HF-TKI', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
answer_model = st.radio(
    "Choose the model used for inference:",
    _model_options,
)

# Radio label -> (model id handed to the builder, cached engine builder,
# function that points the global openai client at the right endpoint).
_MODEL_DISPATCH = {
    'openai/text-davinci-003': ('text-davinci-003', build_kron_query_engine, set_openai),
    'hf/tiiuae/falcon-7b-instruct': ('tiiuae/falcon-7b-instruct', build_hf_query_engine, set_openai_local),
    'cohere/command': ('cohere/command', build_cohere_query_engine, set_openai_local),
    'baseten/Camel-5b': ('baseten/Camel-5b', build_baseten_query_engine, set_openai_local),
    'Local-Camel': ('Writer/camel-5b-hf', build_kron_query_engine, set_openai_local),
    'HF-TKI': ('allenai/tk-instruct-3b-def-pos-neg-expl', build_hf_query_engine, set_openai_local),
}

_selection = _MODEL_DISPATCH.get(answer_model)
if _selection is None:
    # Unreachable with the radio options above. Stop the run here instead of
    # failing later with a NameError on query_model / query_engine.
    print('This is a bug.')
    st.stop()
query_model, _engine_builder, _openai_setup = _selection
print(answer_model)
clear_question(query_model)
# Re-apply the OpenAI configuration on every run. The cached builders only
# call set_openai_local() on a cache miss, so switching back from davinci to
# an already-cached engine would otherwise keep the DAVINCI settings.
# NOTE: openai config is process-global, so concurrent sessions using
# different models can still interfere with each other.
_openai_setup()
query_engine = _engine_builder(query_model, persist_path)
def submit():
    """on_change callback of the question box: make the typed text the active question.

    Moves the input box contents into ``question``, empties the box, and
    marks the question as not yet answered so the next run queries the model.
    """
    state = st.session_state
    state.question, state.question_input = state.question_input, ''
    state.question_answered = False
st.write(f'Model, question, answer and rating are logged to help with the improvement of this application.')
question = st.text_input("Enter a question, e.g. What benchmarks can we use for QA?", key='question_input', on_change=submit)
if st.session_state.question:
    col1, col2 = st.columns([2, 2])
    with col1:
        st.write(f'Answering: {st.session_state.question} with {query_model}.')
    # Pre-set so the finally block never hits a NameError if the run is
    # interrupted before an answer string is produced.
    answer_str = ''
    try:
        # Query once per submitted question. Later runs (e.g. triggered by a
        # star-rating click) reuse the answer kept in session state.
        if not st.session_state.question_answered:
            answer = query_engine.query(st.session_state.question)
            st.session_state.answer = answer
            st.session_state.question_answered = True
        else:
            answer = st.session_state.answer
        answer_str = format_response(answer)
        st.write(answer_str)
        with col1:
            if answer_str:
                st.write(f' Please rate this answer.')
        with col2:
            from streamlit_star_rating import st_star_rating
            stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating",
                                   )
            print(f"------stars {stars}")
    except Exception as e:
        # Keep the full traceback in the server log, not just str(e).
        logger.exception('Answering failed with model %s', query_model)
        answer_str = str(e)
        st.session_state.answer_rating = -1
    finally:
        if 'question' in st.session_state:
            req = st.session_state.question
            if __spaces__:
                # Request logging is best-effort: a failing log backend must
                # not replace the answer with an error page.
                try:
                    st.session_state.request_log.add_request_log_entry(
                        query_model, req, answer_str, st.session_state.get('answer_rating', -1))
                except Exception:
                    logger.exception('Failed to write request log entry')