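"""Streamlit chat demo around stabilityai/stablelm-zephyr-3b.

Run with (assuming this file is saved as app.py):

    streamlit run app.py

Requires streamlit, transformers, and accelerate (for device_map="auto");
langchain and a vector store backend such as chromadb are only needed for the
optional PDF-indexing sketch below.
"""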
import streamlit as st
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.indexes import VectorstoreIndexCreator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoModelForCausalLM, AutoTokenizer

def respond_to_question(question, model, tokenizer):
    # Format the question with the model's chat template and tokenize it.
    prompt = [{'role': 'user', 'content': question}]
    inputs = tokenizer.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        return_tensors='pt'
    )

    tokens = model.generate(
        inputs.to(model.device),
        max_new_tokens=1024,
        temperature=0.8,
        do_sample=True
    )

    # Decode only the newly generated tokens, skipping the echoed prompt and
    # special tokens such as <|assistant|>.
    return tokenizer.decode(tokens[0][inputs.shape[-1]:], skip_special_tokens=True)
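
# Load the tokenizer and model once and cache them across Streamlit reruns;
# without st.cache_resource the ~3B-parameter model would be reloaded on
# every user interaction.
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained('stabilityai/stablelm-zephyr-3b')
    model = AutoModelForCausalLM.from_pretrained(
        'stabilityai/stablelm-zephyr-3b',
        trust_remote_code=True,
        device_map="auto"
    )
    return tokenizer, model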
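
# Hypothetical sketch of the LangChain side of this demo: index a PDF so that
# retrieved passages could be prepended to the chat prompt. The pdf_path
# argument is a placeholder, nothing in main() calls this yet, and the default
# vector store (Chroma) needs chromadb installed.
def build_pdf_index(pdf_path):
    loader = PyPDFLoader(pdf_path)
    index = VectorstoreIndexCreator(
        embedding=HuggingFaceEmbeddings(),
        text_splitter=RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100),
    ).from_loaders([loader])
    # Return a retriever over the underlying vector store.
    return index.vectorstore.as_retriever()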

def main():
    st.title("LangChain Demo")

    tokenizer, model = load_model()

    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far on each rerun.
    for message in st.session_state.messages:
        st.chat_message(message['role']).markdown(message['text'])

    # st.chat_input returns None except on the run where the user submits, so
    # a prompt is handled exactly once (st.text_input would keep its value and
    # re-trigger generation on every rerun).
    prompt = st.chat_input("Enter your question here:")

    if prompt:
        st.session_state.messages.append({'role': 'user', 'text': prompt})
        st.chat_message("user").markdown(prompt)
        model_response = respond_to_question(prompt, model, tokenizer)
        st.session_state.messages.append({'role': 'assistant', 'text': model_response})
        st.chat_message("assistant").markdown(model_response)

if __name__ == "__main__":
    main()