RegBotBeta / models /bloom.py
Zwea Htet
added codes
0809507
raw
history blame
2.24 kB
import os
from json import dumps, loads
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from llama_index import (Document, GPTVectorStoreIndex, LLMPredictor,
PromptHelper, ServiceContext, StorageContext,
load_index_from_storage)
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.customLLM import CustomLLM
# Read variables from a local .env file (if any) before anything else runs.
load_dotenv()

# get model
model_name = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# No explicit `config` argument: from_pretrained resolves the configuration
# that belongs to `model_name` on its own.
model = AutoModelForCausalLM.from_pretrained(model_name)

# define prompt helper
# set maximum input size
max_input_size = 2048
# set number of output tokens
num_output = 525
# set maximum chunk overlap
max_chunk_overlap = 20
prompt_helper = PromptHelper(max_input_size, num_output, max_chunk_overlap)

# define llm
# Wrap the local model/tokenizer pair in the project's CustomLLM so that
# indexes built with this service context use it as their predictor.
llm_predictor = LLMPredictor(llm=CustomLLM(model, tokenizer))
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
def prepare_data(file_path: str):
    """Read a JSON file of records and turn each complete record into a Document.

    Empty-string cells are treated as missing, and any row with a missing
    value is discarded before conversion.

    Args:
        file_path: Path of the JSON file, read with ``pandas.read_json``.

    Returns:
        A list with one ``Document`` per remaining record. Each is built from
        the record's ``paragraphText`` and ``_id.$oid`` (presumably the text
        and the document id -- confirm against the ``Document`` signature),
        with ``chapter``, ``article`` and ``title`` passed as ``extra_info``.
    """
    frame = pd.read_json(file_path)
    # Blank strings become NaN so that dropna removes those rows as well.
    cleaned = frame.replace(to_replace="", value=np.nan).dropna(axis=0)
    # Round-trip through JSON to obtain plain Python dicts, one per row.
    records = loads(cleaned.to_json(orient="records"))

    return [
        Document(
            record['paragraphText'],
            record['_id']['$oid'],
            extra_info={
                "chapter": record['chapter'],
                "article": record['article'],
                "title": record['title'],
            },
        )
        for record in records
    ]
def initialize_index(index_name):
    """Return the vector index called *index_name*, building it on first use.

    A persisted index is looked up under ``./vectorStores/<index_name>``.
    When that path exists the index is reloaded from it; otherwise a new
    index is built from ``./assets/regItems.json`` and persisted to that
    path so later calls can reload it.

    Args:
        index_name: Name of the entry below ``./vectorStores`` that holds
            (or will hold) the persisted index.

    Returns:
        The reloaded or newly built index.
    """
    file_path = f"./vectorStores/{index_name}"

    if os.path.exists(file_path):
        # rebuild storage context
        storage_context = StorageContext.from_defaults(persist_dir=file_path)
        # Return the index restored from storage. The module-level service
        # context is passed along so a reloaded index is configured the same
        # way as a freshly built one.
        return load_index_from_storage(storage_context, service_context=service_context)

    documents = prepare_data(r"./assets/regItems.json")
    index = GPTVectorStoreIndex.from_documents(documents, service_context=service_context)
    index.storage_context.persist(file_path)
    return index