import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from llama_index.prompts.prompts import SimpleInputPrompt
from llama_index.llms import HuggingFaceLLM
from llama_index.embeddings import LangchainEmbedding
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import (ServiceContext, VectorStoreIndex,
                         download_loader, set_global_service_context)
from pathlib import Path

name = "NousResearch/Llama-2-7b-chat-hf"
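# Note: the NousResearch repo above mirrors meta-llama/Llama-2-7b-chat-hf
# without the gated-access approval step (my reading of why it is used here).
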
@st.cache_resource  # cache the tokenizer/model so Streamlit reruns don't reload them
def get_tokenizer_model():
    # Tokenizer, cached locally under ./model/
    tokenizer = AutoTokenizer.from_pretrained(name, cache_dir='./model/')

    # Model in float16 with 8-bit quantization (requires bitsandbytes); dynamic
    # RoPE scaling stretches the usable context beyond the trained length
    model = AutoModelForCausalLM.from_pretrained(
        name,
        cache_dir='./model/',
        torch_dtype=torch.float16,
        rope_scaling={"type": "dynamic", "factor": 2},
        load_in_8bit=True,
    )

    return tokenizer, model


tokenizer, model = get_tokenizer_model()
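# Optional sanity check (an assumption, not part of the original app): generate
# straight from the raw model before wiring it into llama_index:
#
#   inputs = tokenizer("What is progressive overload?", return_tensors="pt").to(model.device)
#   print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0]))
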
# Llama-2 chat format: the system prompt sits between <<SYS>> and <</SYS>> tags
system_prompt = """<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as
helpfully as possible, while being safe. Your answers should not include
any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain
why instead of answering something not correct. If you don't know the answer
to a question, please don't share false information.

Your goal is to provide answers relating to workout science and the information in the document.
<</SYS>>
"""
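# Together with the query wrapper defined just below, the string the model
# finally sees follows the standard Llama-2 chat template:
#
#   <s>[INST] <<SYS>>
#   ...system prompt...
#   <</SYS>>
#   {query_str} [/INST]
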
# Wrap each incoming query in the closing half of the [INST] template
query_wrapper_prompt = SimpleInputPrompt("{query_str} [/INST]")

llm = HuggingFaceLLM(context_window=1024,
                     max_new_tokens=128,
                     system_prompt=system_prompt,
                     query_wrapper_prompt=query_wrapper_prompt,
                     model=model,
                     tokenizer=tokenizer)
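# Note: Llama-2's native context is 4,096 tokens; context_window=1024 is a
# conservative cap here (presumably to bound memory use), and the dynamic RoPE
# scaling applied at load time is what would stretch the context further.
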
# Sentence-transformers embeddings, wrapped for llama_index via LangChain
embeddings = LangchainEmbedding(
    HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
)

# Bundle the LLM and embedding model; documents are chunked to 1024 tokens
service_context = ServiceContext.from_defaults(
    chunk_size=1024,
    llm=llm,
    embed_model=embeddings
)
set_global_service_context(service_context)
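# Because the service context is registered globally, the index and query
# engine built below pick up this llm and embed_model automatically, without
# passing service_context explicitly.
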
# Pull the PyMuPDF loader from LlamaHub and load the workout PDF
PyMuPDFReader = download_loader("PyMuPDFReader")
loader = PyMuPDFReader()
documents = loader.load(file_path=Path('jeff_wo.pdf'), metadata=True)

# Build an in-memory vector index over the documents and expose a query engine
index = VectorStoreIndex.from_documents(documents)
query_engine = index.as_query_engine()
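# Hypothetical one-off smoke test outside Streamlit (not in the original app):
#
#   print(query_engine.query("Summarise the training programme").response)
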
st.title('🦙 Llama Banker')

prompt = st.text_input('Input your prompt here')

if prompt:
    response = query_engine.query(prompt)
    # Show the answer text up front, then the raw object and sources on demand
    st.write(response.response)

    with st.expander('Response Object'):
        st.write(response)

    with st.expander('Source Text'):
        st.write(response.get_formatted_sources())
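# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py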