from typing import Any, List, Mapping, Optional

import streamlit as st
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from llama_cpp import Llama
from llama_index import LangchainEmbedding, LLMPredictor, PromptHelper, ServiceContext

# Hugging Face repo the checkpoint came from (kept for model identification).
MODEL_NAME = 'TheBloke/MelloGPT-AWQ'

# llama.cpp loads a local GGUF/GGML file rather than a Hugging Face repo id.
# Placeholder path (assumption): point this at the model file you downloaded.
MODEL_PATH = './models/mellogpt.gguf'

NUM_THREADS = 8

# PromptHelper settings: context window size, output token budget, and chunk overlap ratio.
max_input_size = 2048
num_output = 256
chunk_overlap_ratio = 0.8

try:
    prompt_helper = PromptHelper(max_input_size, num_output, chunk_overlap_ratio)
except Exception:
    # PromptHelper may reject a large overlap ratio; retry with a smaller one.
    chunk_overlap_ratio = 0.2
    prompt_helper = PromptHelper(max_input_size, num_output, chunk_overlap_ratio)

embed_model = LangchainEmbedding(HuggingFaceEmbeddings())


class CustomLLM(LLM):
    model_name: str = MODEL_NAME

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        p = f"Human: {prompt} Assistant: "
        prompt_length = len(p)
        # Note: the model is loaded from disk on every call; see the cached loader sketch below.
        llm = Llama(model_path=MODEL_PATH, n_threads=NUM_THREADS)
        output = llm(p, max_tokens=512, stop=["Human:"], echo=True)['choices'][0]['text']

        # echo=True returns the prompt as well, so slice it off to keep only the completion.
        response = output[prompt_length:]
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.session_state.messages.append({"role": "assistant", "content": response})
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"name_of_model": self.model_name}

    @property
    def _llm_type(self) -> str:
        return "custom"

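# A possible optimization (sketch, not part of the original script): cache the llama.cpp
# model with st.cache_resource so the weights load once per session instead of on every
# call. `load_llama_model` is a hypothetical helper name and is not called below.
@st.cache_resource
def load_llama_model() -> Llama:
    # Assumes MODEL_PATH points at a valid local GGUF file.
    return Llama(model_path=MODEL_PATH, n_threads=NUM_THREADS)
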
llm_predictor = LLMPredictor(llm=CustomLLM())
service_context = ServiceContext.from_defaults(
    llm_predictor=llm_predictor,
    prompt_helper=prompt_helper,
    embed_model=embed_model,
)

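# The service_context above is built but never used further in this script. With the same
# legacy llama_index API it would typically feed an index over local documents, roughly
# like the commented sketch below (assumes a ./data folder of documents exists):
#
#   from llama_index import SimpleDirectoryReader, VectorStoreIndex
#   documents = SimpleDirectoryReader('./data').load_data()
#   index = VectorStoreIndex.from_documents(documents, service_context=service_context)
#   query_engine = index.as_query_engine()
#   print(query_engine.query("What do these documents say?"))
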
def clear_convo():
    st.session_state['messages'] = []


def init():
    st.set_page_config(page_title='Local LLama', page_icon=':robot_face:')
    st.sidebar.title('Local LLama')
    if 'messages' not in st.session_state:
        st.session_state['messages'] = []


if __name__ == '__main__':
    init()


@st.cache_resource
def get_llm():
    # Cache the CustomLLM wrapper so Streamlit reruns reuse the same instance.
    llm = CustomLLM()
    return llm


clear_button = st.sidebar.button("Clear Conversation", key="clear")
if clear_button:
    clear_convo()

user_input = st.chat_input("Say something")

if user_input:
    llm = get_llm()
    # _call generates the reply and records both turns in st.session_state.messages.
    llm._call(prompt=user_input)

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
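
# To launch the UI (assuming this file is saved as app.py):
#   streamlit run app.py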