# RegBotBeta / models/llamaCustom.py
import os
import pickle
from json import dumps, loads
from typing import Any, List, Mapping, Optional
import numpy as np
import openai
import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import HfFileSystem, Repository
from llama_index import (
Document,
GPTVectorStoreIndex,
LLMPredictor,
PromptHelper,
ServiceContext,
SimpleDirectoryReader,
StorageContext,
load_index_from_storage,
)
from llama_index.llms import CompletionResponse, CompletionResponseGen, CustomLLM, LLMMetadata
# from langchain.llms.base import LLM
# from llama_index.prompts import Prompt
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline
# from utils.customLLM import CustomLLM
load_dotenv()
# openai.api_key = os.getenv("OPENAI_API_KEY")
fs = HfFileSystem()
# define prompt helper
# set maximum input size
CONTEXT_WINDOW = 2048
# set number of output tokens
NUM_OUTPUT = 525
# set chunk overlap as a ratio of the chunk size
CHUNK_OVERLAP_RATIO = 0.2
prompt_helper = PromptHelper(
context_window=CONTEXT_WINDOW,
num_output=NUM_OUTPUT,
chunk_overlap_ratio=CHUNK_OVERLAP_RATIO,
)
@st.cache_resource
def load_model(model_name: str):
    """Load the tokenizer and model for `model_name` and wrap them in a text-generation pipeline."""
    # llm_model_name = "bigscience/bloom-560m"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # let from_pretrained load the checkpoint's own config
    model = AutoModelForCausalLM.from_pretrained(model_name)
pipe = pipeline(
task="text-generation",
model=model,
tokenizer=tokenizer,
# device=0, # GPU device number
# max_length=512,
do_sample=True,
top_p=0.95,
top_k=50,
temperature=0.7,
)
return pipe
class OurLLM(CustomLLM):
    # CustomLLM is a pydantic model in this version of llama_index, so the
    # attributes must be declared as fields before they can be set in __init__.
    model_name: str = ""
    pipeline: Any = None
    def __init__(self, model_name: str, model_pipeline, **kwargs: Any) -> None:
        super().__init__(model_name=model_name, pipeline=model_pipeline, **kwargs)
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
context_window=CONTEXT_WINDOW,
num_output=NUM_OUTPUT,
model_name=self.model_name,
)
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
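        """Run the pipeline on the prompt and return only the newly generated text."""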
prompt_length = len(prompt)
response = self.pipeline(prompt, max_new_tokens=NUM_OUTPUT)[0]["generated_text"]
# only return newly generated tokens
text = response[prompt_length:]
return CompletionResponse(text=text)
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()
# def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
# prompt_length = len(prompt)
# response = self.pipeline(prompt, max_new_tokens=525)[0]["generated_text"]
# # only return newly generated tokens
# return response[prompt_length:]
# @property
# def _identifying_params(self) -> Mapping[str, Any]:
# return {"name_of_model": self.model_name}
# @property
# def _llm_type(self) -> str:
# return "custom"
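# Caching the class is intended to let Streamlit reuse one LlamaCustom instance per model name across reruns.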
@st.cache_resource
class LlamaCustom:
# define llm
# llm_predictor = LLMPredictor(llm=OurLLM())
# service_context = ServiceContext.from_defaults(
# llm_predictor=llm_predictor, prompt_helper=prompt_helper
# )
def __init__(self, model_name: str) -> None:
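        """Build the LLM pipeline, service context, and vector index for the given model."""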
        pipe = load_model(model_name=model_name)
llm = OurLLM(model_name=model_name, model_pipeline=pipe)
self.service_context = ServiceContext.from_defaults(
llm=llm, prompt_helper=prompt_helper
)
self.vector_index = self.initialize_index(model_name=model_name)
def initialize_index(self, model_name: str):
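        """Load a persisted vector index from ./vectorStores if one exists,
        otherwise build it from the PDF assets and persist it."""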
index_name = model_name.split("/")[-1]
file_path = f"./vectorStores/{index_name}"
if os.path.exists(path=file_path):
# rebuild storage context
storage_context = StorageContext.from_defaults(persist_dir=file_path)
# local load index access
            index = load_index_from_storage(
                storage_context, service_context=self.service_context
            )
# huggingface repo load access
# with fs.open(file_path, "r") as file:
# index = pickle.loads(file.readlines())
return index
else:
# documents = prepare_data(r"./assets/regItems.json")
documents = SimpleDirectoryReader(input_dir="./assets/pdf").load_data()
index = GPTVectorStoreIndex.from_documents(
documents, service_context=self.service_context
)
# local write access
index.storage_context.persist(file_path)
# huggingface repo write access
# with fs.open(file_path, "w") as file:
# file.write(pickle.dumps(index))
return index
def get_response(self, query_str):
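        """Answer a query against the vector index and return the response text."""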
print("query_str: ", query_str)
query_engine = self.vector_index.as_query_engine()
# query_engine = self.vector_index.as_query_engine(
# text_qa_template=text_qa_template, refine_template=refine_template
# )
response = query_engine.query(query_str)
print("metadata: ", response.metadata)
return str(response)
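
# Minimal usage sketch (an assumption, not part of the app's entry point): the
# checkpoint name and query string below are placeholders, and building the index
# requires either a persisted ./vectorStores entry or documents under ./assets/pdf.
if __name__ == "__main__":
    bot = LlamaCustom(model_name="bigscience/bloom-560m")
    print(bot.get_response("What does the document say about registration?"))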