# TI_demo_E2E / app.py
# Author: arjunanand13 · commit 7a542bf "Update app.py" (verified) · 4.23 kB
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import accelerate
import einops
import langchain
import xformers
import os
import bitsandbytes
import sentence_transformers
import huggingface_hub
import torch
from torch import cuda, bfloat16
from transformers import StoppingCriteria, StoppingCriteriaList
from langchain.llms import HuggingFacePipeline
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
# Hugging Face access token, read from the HF_TOKEN environment variable.
# It is passed explicitly to from_pretrained below because the Llama 3
# repository is gated. (Alternative: huggingface_hub.login(HF_TOKEN).)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# --- Loading of the Llama 3 model ---
model_id = 'meta-llama/Meta-Llama-3-8B-Instruct'
# Device for tensors this script creates itself: the current CUDA device when
# one is available, otherwise the CPU.
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

# Optional 4-bit quantization to load the large model with less GPU memory
# (requires the `bitsandbytes` library). To enable it, uncomment and pass
# quantization_config=bnb_config to AutoModelForCausalLM.from_pretrained.
# bnb_config = transformers.BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type='nf4',
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=bfloat16
# )

tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
# device_map="auto" lets accelerate decide where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", token=HF_TOKEN)
# Token ids that end a Llama 3 Instruct turn: the tokenizer's EOS token and
# the <|eot_id|> end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

# Stop list: generation also halts as soon as the output ends with one of
# these strings (a new "Human:" turn, or a closing code fence).
stop_list = ['\nHuman:', '\n```\n']
# add_special_tokens=False is essential: the Llama 3 tokenizer prepends
# <|begin_of_text|> by default, and a stop sequence starting with that BOS id
# can never match the tail of generated text, so the criterion would never fire.
# NOTE(review): a stop string tokenized on its own may still tokenize
# differently inside longer text -- verify against real generations.
stop_token_ids = [tokenizer(x, add_special_tokens=False)['input_ids'] for x in stop_list]
stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]


class StopOnTokens(StoppingCriteria):
    """Stop generation when the output ends with one of the stop sequences.

    Only the first sequence of the batch (``input_ids[0]``) is inspected; it is
    compared against every entry of the module-level ``stop_token_ids``.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        sequence = input_ids[0]
        for stop_ids in stop_token_ids:
            n = len(stop_ids)
            # An empty stop sequence would be compared against the whole
            # tensor, and one longer than the sequence cannot match (and
            # torch.eq would fail to broadcast), so skip both.
            if n == 0 or sequence.shape[0] < n:
                continue
            # .to() is a no-op when devices already agree; it guards against
            # device_map="auto" placing inputs somewhere other than `device`.
            if torch.eq(sequence[-n:], stop_ids.to(sequence.device)).all():
                return True
        return False


stopping_criteria = StoppingCriteriaList([StopOnTokens()])
# Text-generation pipeline wrapped for LangChain. The generation kwargs given
# here are stored by the pipeline and forwarded to model.generate() per call.
generate_text = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,  # langchain expects the full text
    task='text-generation',
    stopping_criteria=stopping_criteria,  # without this model rambles during chat
    # End generation on either id in `terminators` (the tokenizer EOS and
    # Llama 3's <|eot_id|> end-of-turn token) instead of running on past the
    # end of the assistant's turn.
    eos_token_id=terminators,
    # Sampling temperature; lower values make output more deterministic.
    # NOTE(review): do_sample is not set here, so whether sampling is enabled
    # comes from the model's generation config -- confirm.
    temperature=0.1,
    max_new_tokens=512,  # max number of tokens to generate in the output
    repetition_penalty=1.1,  # without this output begins repeating
)
llm = HuggingFacePipeline(pipeline=generate_text)
# Load the files under data/text/ as plain-text documents.
loader = DirectoryLoader('data/text/', loader_cls=TextLoader)
documents = loader.load()
print('len of documents are',len(documents))

# Split into chunks of up to 5000 characters with 250 characters of overlap.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
all_splits = text_splitter.split_documents(documents)
if not all_splits:
    # FAISS.from_documents fails on an empty list with an obscure error;
    # report the actual cause instead.
    raise RuntimeError(
        f"No text to index: loaded {len(documents)} document(s) from 'data/text/'"
    )

# Embedding model. It uses the same `device` as the rest of the script (with
# its CPU fallback) rather than a hard-coded "cuda", so the app still starts
# on machines without a GPU.
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# storing embeddings in the vector store
vectorstore = FAISS.from_documents(all_splits, embeddings)
# Retrieval QA chain over the vector store; source documents are returned
# alongside the answer.
chain = ConversationalRetrievalChain.from_llm(llm, vectorstore.as_retriever(), return_source_documents=True)
# Conversation history handed to the chain. Nothing in this file appends to
# it, so every question is answered without earlier turns.
# NOTE(review): this list is module-global and would be shared by all users
# of the Interface if it were ever appended to.
chat_history = []


def qa_infer(query):
    """Run *query* through the retrieval chain and return the answer text.

    The answer is also printed to stdout.
    """
    payload = {"question": query, "chat_history": chat_history}
    answer = chain(payload)['answer']
    print(answer)
    return answer
# Sample questions shown under the input box (they are not pre-computed).
EXAMPLES = [
    "What is the best TS pin configuration for BQ24040 in normal battery charge mode",
    "Can BQ25896 support I2C interface?",
    "Can you please provide me with Gerber/CAD file for UCC2897A",
]

# Single text-in / text-out UI around qa_infer, with flagging disabled.
demo = gr.Interface(
    fn=qa_infer,
    inputs="text",
    outputs="text",
    examples=EXAMPLES,
    cache_examples=False,
    allow_flagging='never',
)

# launch the app!
demo.launch()