from llama_cpp import Llama
import streamlit as st
from langchain.llms.base import LLM
from llama_index import LLMPredictor, LangchainEmbedding, ServiceContext, PromptHelper
from typing import Optional, List, Mapping, Any
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

MODEL_NAME = 'TheBloke/MelloGPT-AWQ'
# Placeholder path to a local GGUF model file for llama.cpp -- point this at your own download
MODEL_PATH = 'models/mellogpt.gguf'

# Number of threads llama.cpp should use
NUM_THREADS = 8

# Prompt helper settings
max_input_size = 2048      # maximum input size in tokens
num_output = 256           # number of output tokens to reserve
chunk_overlap_ratio = 0.8  # fraction of overlap between text chunks

try:
    prompt_helper = PromptHelper(max_input_size, num_output, chunk_overlap_ratio)
except Exception:
    # Fall back to a smaller overlap ratio if the first value is rejected
    chunk_overlap_ratio = 0.2
    prompt_helper = PromptHelper(max_input_size, num_output, chunk_overlap_ratio)
    
embed_model = LangchainEmbedding(HuggingFaceEmbeddings())


class CustomLLM(LLM):
    model_name: str = MODEL_NAME

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        p = f"Human: {prompt} Assistant: "
        prompt_length = len(p)
        # Note: the model is loaded on every call; cache the Llama instance if this is too slow
        llm = Llama(model_path=MODEL_PATH, n_threads=NUM_THREADS)
        output = llm(p, max_tokens=512, stop=["Human:"], echo=True)['choices'][0]['text']
        # echo=True returns the prompt plus the completion, so slice off the prompt
        # to keep only the newly generated text
        response = output[prompt_length:]
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.session_state.messages.append({"role": "assistant", "content": response})
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"name_of_model": self.model_name}

    @property
    def _llm_type(self) -> str:
        return "custom"


# Wire the custom LLM, embedding model, and prompt helper into a llama_index service context
llm_predictor = LLMPredictor(llm=CustomLLM())
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, embed_model=embed_model)


def clear_convo():
    """Reset the stored chat history."""
    st.session_state['messages'] = []


def init():
    """Configure the Streamlit page and initialise the session state."""
    st.set_page_config(page_title='Local LLama', page_icon=':robot_face:')
    st.sidebar.title('Local LLama')
    if 'messages' not in st.session_state:
        st.session_state['messages'] = []


if __name__ == '__main__':
    init()


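    # Cache the LLM wrapper across Streamlit reruns (the llama.cpp model itself is still loaded per call)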
    @st.cache_resource
    def get_llm():
        llm = CustomLLM()
        return llm

    clear_button = st.sidebar.button("Clear Conversation", key="clear")
    if clear_button:
        clear_convo()

    user_input = st.chat_input("Say something")

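    # Generate a reply; _call appends both the user and assistant turns to session state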
    if user_input:
        llm = get_llm()
        llm._call(prompt=user_input)

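    # Replay the stored conversation so the full history stays visible across reruns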
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])