import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = 'stabilityai/stablelm-zephyr-3b'


@st.cache_resource
def load_model():
    # Cache the tokenizer and model so they load once per server process
    # instead of on every Streamlit rerun.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        device_map="auto",
    )
    return model, tokenizer


def respond_to_question(question, model, tokenizer):
    # Wrap the question in the chat template the model was fine-tuned on.
    prompt = [{'role': 'user', 'content': question}]
    inputs = tokenizer.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        return_tensors='pt',
    )
    tokens = model.generate(
        inputs.to(model.device),
        max_new_tokens=1024,
        temperature=0.8,
        do_sample=True,
    )
    # Decode only the newly generated tokens so the echoed prompt and
    # special tokens such as <|assistant|> don't leak into the reply.
    response = tokenizer.decode(tokens[0][inputs.shape[-1]:], skip_special_tokens=True)
    print(response)  # server-side log of the raw model output
    return response


def main():
    st.title("LangChain Demo")
    model, tokenizer = load_model()

    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far.
    for message in st.session_state.messages:
        st.chat_message(message['role']).markdown(message['text'])

    # st.chat_input yields the submitted text exactly once per submission;
    # st.text_input keeps its value across reruns and would re-trigger
    # generation on every interaction.
    prompt = st.chat_input("Enter your question here")
    if prompt:
        st.session_state.messages.append({'role': 'user', 'text': prompt})
        st.chat_message("user").markdown(prompt)
        model_response = respond_to_question(prompt, model, tokenizer)
        st.session_state.messages.append({'role': 'assistant', 'text': model_response})
        st.chat_message("assistant").markdown(model_response)


if __name__ == "__main__":
    main()
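
# Usage sketch (assumes this script is saved as app.py; the filename is an
# assumption, not from the original):
#
#   streamlit run app.py
#
# Streamlit widgets need the Streamlit runtime, so launching with plain
# `python app.py` runs main() but logs "missing ScriptRunContext" warnings
# instead of rendering the chat UI.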