import streamlit as st

from langchain.document_loaders import PyPDFLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

from transformers import AutoModelForCausalLM, AutoTokenizer
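
# A minimal Streamlit chat front end around stabilityai/stablelm-zephyr-3b.
# Launch it with `streamlit run <this_file>.py` (any file name works).
# Requires streamlit, transformers, torch, langchain, and accelerate
# (accelerate is needed for device_map="auto").
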
def respond_to_question(question, model, tokenizer):
    # Wrap the question in the model's chat template and tokenize it.
    prompt = [{'role': 'user', 'content': question}]
    inputs = tokenizer.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        return_tensors='pt'
    )

    # Sample a reply; the input ids are moved onto the same device as the model.
    tokens = model.generate(
        inputs.to(model.device),
        max_new_tokens=1024,
        temperature=0.8,
        do_sample=True
    )

    # Full decoded sequence (prompt, reply and special tokens) for debugging.
    print(tokenizer.decode(tokens[0], skip_special_tokens=False))

    # Return only the newly generated tokens, without special tokens, so the
    # prompt is not echoed back into the chat window.
    return tokenizer.decode(tokens[0][inputs.shape[-1]:], skip_special_tokens=True)
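

# The LangChain imports above are not exercised by the chat loop in main().
# The function below is only a rough, illustrative sketch of how they are
# typically wired together, not part of the original demo: `pdf_path` is a
# hypothetical document, `hf_pipeline` is assumed to be a transformers
# text-generation pipeline, and the chunk sizes and HuggingFacePipeline
# wrapper are assumptions.
def build_pdf_qa_chain(pdf_path, hf_pipeline):
    # Wrap the transformers pipeline so LangChain can drive it as an LLM.
    from langchain.llms import HuggingFacePipeline  # assumed import, not in the original script

    # Load the PDF, split it into chunks, embed them and build a vector index.
    # VectorstoreIndexCreator uses Chroma by default, so chromadb must be installed.
    index = VectorstoreIndexCreator(
        embedding=HuggingFaceEmbeddings(),
        text_splitter=RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50),
    ).from_loaders([PyPDFLoader(pdf_path)])

    # Answer questions by retrieving relevant chunks and stuffing them into the prompt.
    return RetrievalQA.from_chain_type(
        llm=HuggingFacePipeline(pipeline=hf_pipeline),
        retriever=index.vectorstore.as_retriever(),
    )
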
def main():
    st.title("LangChain Demo")

    # Note: Streamlit re-runs this whole script on every interaction, so in a
    # real app you would cache this load (for example with @st.cache_resource)
    # rather than reloading the 3B-parameter model on every rerun.
    tokenizer = AutoTokenizer.from_pretrained('stabilityai/stablelm-zephyr-3b')
    model = AutoModelForCausalLM.from_pretrained(
        'stabilityai/stablelm-zephyr-3b',
        trust_remote_code=True,
        device_map="auto"
    )

    # Keep the chat history across reruns in Streamlit's session state.
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    prompt = st.text_input("Enter your question here:")
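
    # Replay the conversation stored in session state so earlier turns stay visible.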
    for message in st.session_state.messages:
        st.chat_message(message['role']).markdown(message['text'])

    if prompt:
        # Record and display the user's turn.
        st.session_state.messages.append({'role': 'user', 'text': prompt})
        st.chat_message("user").markdown(prompt)

        # Generate the model's reply and store it under the same role name
        # that is used when the history is replayed above.
        model_response = respond_to_question(prompt, model, tokenizer)
        st.session_state.messages.append({'role': 'assistant', 'text': model_response})
        st.chat_message("assistant").markdown(model_response)


if __name__ == "__main__":
    main()