import streamlit as st

import os
import re
import sys
import logging

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
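
# Load settings from a local .env file; the AWS region is pinned here,
# presumably for the DynamoDB request log used when running on Spaces.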
from dotenv import load_dotenv

load_dotenv()

os.environ['AWS_DEFAULT_REGION'] = 'us-west-2'
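
# Dump the current Streamlit session state to stdout on each rerun (debugging aid).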
for key in st.session_state.keys():
    print(f'session state entry: {key} {st.session_state[key]}')
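
# __SPACES__ is expected to be set in the hosted (Hugging Face Spaces) deployment;
# when present, a DynamoDB-backed request log is attached to the session so model,
# question, answer and rating can be recorded.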
__spaces__ = os.environ.get('__SPACES__')

if __spaces__:
    from kron.persistence.dynamodb_request_log import get_request_log
    st.session_state.request_log = get_request_log()
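
# API tokens for the hosted inference back ends (Hugging Face Hub, Cohere, Baseten).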
hf_api_key = os.environ['HF_TOKEN']
ch_api_key = os.environ['COHERE_TOKEN']
bs_api_key = os.environ['BASETEN_TOKEN']
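
# The persisted knowledge-graph index was built with Writer/camel-5b-hf (the
# "no-coref" variant per INDEX_NAME) and is loaded from local storage below;
# MAX_LENGTH caps the completion length requested from the LLMs.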
index_model = "Writer/camel-5b-hf"
INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
persist_path = f"storage/{INDEX_NAME}"
MAX_LENGTH = 1024
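
# Log in to Baseten once; st.cache_resource prevents the login from being
# repeated on every Streamlit rerun.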
import baseten


@st.cache_resource
def set_baseten_key(bs_api_key):
    baseten.login(bs_api_key)


set_baseten_key(bs_api_key)
from llama_index import StorageContext
from llama_index import ServiceContext
from llama_index import load_index_from_storage
from llama_index.langchain_helpers.text_splitter import SentenceSplitter
from llama_index.node_parser import SimpleNodeParser
from llama_index import LLMPredictor

from langchain import HuggingFaceHub
from langchain.llms.cohere import Cohere
from langchain.llms import Baseten

import tiktoken
import openai

from kron.llm_predictor.KronOpenAILLM import KronOpenAI
from kron.llm_predictor.KronBasetenCamelLLM import KronBasetenCamelLLM
from kron.llm_predictor.KronLLMPredictor import KronLLMPredictor
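
# Override llama_index's global tokenizer with a GPT-2 tiktoken encoder that
# treats <|endoftext|> as an allowed special token.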
from llama_index.utils import globals_helper

enc = tiktoken.get_encoding("gpt2")
tokenizer = lambda text: enc.encode(text, allowed_special={"<|endoftext|>"})
globals_helper._tokenizer = tokenizer
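
# These helpers point the openai client (and the matching environment variables)
# either at what appears to be a local OpenAI-compatible endpoint or at the
# OpenAI text-davinci-003 endpoint, depending on the selected model.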
def set_openai_local():
    openai.api_key = os.environ['LOCAL_OPENAI_API_KEY']
    openai.api_base = os.environ['LOCAL_OPENAI_API_BASE']
    os.environ['OPENAI_API_KEY'] = os.environ['LOCAL_OPENAI_API_KEY']
    os.environ['OPENAI_API_BASE'] = os.environ['LOCAL_OPENAI_API_BASE']


def set_openai():
    openai.api_key = os.environ['DAVINCI_OPENAI_API_KEY']
    openai.api_base = os.environ['DAVINCI_OPENAI_API_BASE']
    os.environ['OPENAI_API_KEY'] = os.environ['DAVINCI_OPENAI_API_KEY']
    os.environ['OPENAI_API_BASE'] = os.environ['DAVINCI_OPENAI_API_BASE']
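
# Predictor factories: each wraps one back end (Hugging Face Hub, Cohere,
# Baseten, or an OpenAI-compatible endpoint) in a llama_index LLMPredictor.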
def get_hf_predictor(query_model):
    set_openai_local()
    llm = HuggingFaceHub(repo_id=query_model, task="text-generation",
                         model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH},
                         huggingfacehub_api_token=hf_api_key)
    llm_predictor = LLMPredictor(llm)
    return llm_predictor
def get_cohere_predictor(query_model):
    set_openai_local()
    llm = Cohere(model='command', temperature=0.01,
                 cohere_api_key=ch_api_key)
    llm_predictor = LLMPredictor(llm)
    return llm_predictor
def get_baseten_predictor(query_model):
    set_openai_local()
    # '3yd1ke3' appears to be the id of the Baseten-deployed Camel-5b model.
    llm = KronBasetenCamelLLM(model='3yd1ke3', temperature=0.01,
                              model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH,
                                            'frequency_penalty': 1})
    llm_predictor = LLMPredictor(llm)
    return llm_predictor
def get_kron_openai_predictor(query_model):
    llm = KronOpenAI(temperature=0.01, model=query_model)
    llm.max_tokens = MAX_LENGTH
    llm_predictor = KronLLMPredictor(llm)
    return llm_predictor
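
# Build a ServiceContext that pairs the chosen LLM predictor with a sentence-level
# node parser (192-token chunks, 48-token overlap).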
def get_service_context(llm_predictor):
    text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=48, paragraph_separator='\n')
    node_parser = SimpleNodeParser(text_splitter=text_splitter)
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, node_parser=node_parser)
    return service_context
def get_index(service_context, persist_path):
    print(f'Loading index from {persist_path}')
    storage_context = StorageContext.from_defaults(persist_dir=persist_path)
    index = load_index_from_storage(storage_context=storage_context,
                                    service_context=service_context,
                                    max_triplets_per_chunk=2,
                                    show_progress=False)
    return index
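
# 'accumulate' runs the query against each retrieved chunk and concatenates the
# per-chunk answers instead of synthesizing a single combined response.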
def get_query_engine(index):
    RESPONSE_MODE = 'accumulate'
    query_engine = index.as_query_engine(response_mode=RESPONSE_MODE)
    return query_engine
def load_query_engine(llm_predictor, persist_path):
    service_context = get_service_context(llm_predictor)
    index = get_index(service_context, persist_path)
    print(f'No query engine for {persist_path}; creating')
    query_engine = get_query_engine(index)
    return query_engine
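
# One cached query engine per (model, persist_path) pair; st.cache_resource keeps
# engines alive across Streamlit reruns, so switching back to a model is cheap.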
@st.cache_resource
def build_kron_query_engine(query_model, persist_path):
    llm_predictor = get_kron_openai_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine


@st.cache_resource
def build_hf_query_engine(query_model, persist_path):
    llm_predictor = get_hf_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine


@st.cache_resource
def build_cohere_query_engine(query_model, persist_path):
    llm_predictor = get_cohere_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine


@st.cache_resource
def build_baseten_query_engine(query_model, persist_path):
    llm_predictor = get_baseten_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine
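
# Strip long runs of dashes that models sometimes emit as filler; fall back to
# "None" when nothing is left.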
def format_response(answer):
    dashes = r'(\-{2,50})'
    answer.response = re.sub(dashes, '', answer.response)
    return answer.response or "None"
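
# When the selected model changes, move any pending input into the active
# question and reset the cached answer so the new model re-answers it.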
def clear_question(query_model):
    # Only act when the model selection actually changed (or on the first run).
    if st.session_state.get('prev_model') != query_model:
        print(f"clearing question {st.session_state.get('prev_model')} {query_model}")
        if 'question_input' in st.session_state:
            st.session_state.question = st.session_state.question_input
            st.session_state.question_input = ''
            st.session_state.question_answered = False
            st.session_state.answer = ''
        st.session_state.prev_model = query_model
initial_query = ''

if 'question' not in st.session_state:
    st.session_state.question = ''
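
# The model menu differs between the public Spaces deployment (hosted back ends
# only) and a local run, which also exposes a locally served Camel-5b and an
# HF-hosted tk-instruct model.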
if __spaces__:
    answer_model = st.radio(
        "Choose the model used for inference:",
        ('baseten/Camel-5b', 'cohere/command', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
    )
else:
    answer_model = st.radio(
        "Choose the model used for inference:",
        ('Local-Camel', 'HF-TKI', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
    )
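
# Build (or fetch from cache) the query engine that matches the selected model.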
if answer_model == 'openai/text-davinci-003':
    print(answer_model)
    query_model = 'text-davinci-003'
    clear_question(query_model)
    set_openai()
    query_engine = build_kron_query_engine(query_model, persist_path)
elif answer_model == 'hf/tiiuae/falcon-7b-instruct':
    print(answer_model)
    query_model = 'tiiuae/falcon-7b-instruct'
    clear_question(query_model)
    query_engine = build_hf_query_engine(query_model, persist_path)
elif answer_model == 'cohere/command':
    print(answer_model)
    query_model = 'cohere/command'
    clear_question(query_model)
    query_engine = build_cohere_query_engine(query_model, persist_path)
elif answer_model == 'baseten/Camel-5b':
    print(answer_model)
    query_model = 'baseten/Camel-5b'
    clear_question(query_model)
    query_engine = build_baseten_query_engine(query_model, persist_path)
elif answer_model == 'Local-Camel':
    query_model = 'Writer/camel-5b-hf'
    print(answer_model)
    clear_question(query_model)
    set_openai_local()
    query_engine = build_kron_query_engine(query_model, persist_path)
elif answer_model == 'HF-TKI':
    query_model = 'allenai/tk-instruct-3b-def-pos-neg-expl'
    clear_question(query_model)
    query_engine = build_hf_query_engine(query_model, persist_path)
else:
    print(f'Unexpected model selection: {answer_model}')
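
# Callback for the question box: stash the submitted text and mark it as not yet
# answered so the block below runs the query.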
def submit():
    st.session_state.question = st.session_state.question_input
    st.session_state.question_input = ''
    st.session_state.question_answered = False
st.write('Model, question, answer and rating are logged to help with the improvement of this application.')
question = st.text_input("Enter a question, e.g. What benchmarks can we use for QA?", key='question_input', on_change=submit)
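
# Main question/answer flow: query the selected engine once per submitted question,
# cache the answer in session state, collect a star rating, and (on Spaces) log the
# interaction.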
if st.session_state.question:
    col1, col2 = st.columns([2, 2])
    with col1:
        st.write(f'Answering: {st.session_state.question} with {query_model}.')

    try:
        if not st.session_state.question_answered:
            answer = query_engine.query(st.session_state.question)
            st.session_state.answer = answer
            st.session_state.question_answered = True
        else:
            answer = st.session_state.answer
        answer_str = format_response(answer)
        st.write(answer_str)
        with col1:
            if answer_str:
                st.write('Please rate this answer.')
        with col2:
            from streamlit_star_rating import st_star_rating
            stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating")
            print(f"------stars {stars}")
    except Exception as e:
        print(e)
        answer_str = str(e)
        st.session_state.answer_rating = -1
    finally:
        if 'question' in st.session_state:
            req = st.session_state.question
            if __spaces__:
                st.session_state.request_log.add_request_log_entry(
                    query_model, req, answer_str, st.session_state.answer_rating)