|
|
|
""" |
|
Created on Sat Oct 5 16:41:22 2024 |
|
|
|
@author: Admin |
|
""" |
|
|
|
import gradio as gr |
|
from transformers import pipeline |
|
import os |
|
from huggingface_hub import login |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
|
|
|
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding |
|
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex |
|
from llama_index.core.retrievers import VectorIndexRetriever |
|
from llama_index.core.query_engine import RetrieverQueryEngine |
|
from llama_index.core.postprocessor import SimilarityPostprocessor |
|
|
|
# Global llama_index configuration used when the index below is built.
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5",
)
# No LLM is registered with llama_index; answers are generated separately by
# the transformers model loaded further down in this file.
Settings.llm = None
# Chunk size and overlap applied when documents are split for indexing.
Settings.chunk_size, Settings.chunk_overlap = 256, 25

# Read everything under ./test and build the vector index from it.
documents = SimpleDirectoryReader("./test").load_data()
index = VectorStoreIndex.from_documents(documents)
|
|
|
# Number of candidate chunks requested from the vector index per query.
top_k = 6

retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)

# Retrieved nodes are passed through a similarity post-processor configured
# with a cutoff of 0.5 before the query engine returns them.
_similarity_filter = SimilarityPostprocessor(similarity_cutoff=0.5)
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[_similarity_filter],
)
|
|
|
|
|
|
|
from peft import PeftModel, PeftConfig |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
# Hugging Face model id shared by the model and tokenizer loaders below.
model_name = "microsoft/Phi-3.5-mini-instruct"

# Keyword options forwarded unchanged to AutoModelForCausalLM.from_pretrained.
_model_load_options = {
    "device_map": "auto",
    "trust_remote_code": False,
    "revision": "main",
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_options)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Switch the model to evaluation mode; this script only runs generation.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prompt_template_w_context(context: str, comment: str) -> str:
    """Build the generation prompt from retrieved context and a user comment.

    The prompt consists of the context, a fixed instruction line, the
    comment, and a closing ``[/INST]`` marker, each terminated by a newline.

    Args:
        context: Retrieved passages to place at the top of the prompt.
        comment: The user message the model should respond to.

    Returns:
        The assembled prompt string.
    """
    # NOTE(review): "[/INST]" looks like a Llama/Mistral-style instruction
    # delimiter, while this file loads "microsoft/Phi-3.5-mini-instruct" --
    # confirm the marker is intended, or build the prompt with the
    # tokenizer's chat template instead.
    return (
        f"{context}\n"
        "Please respond to the following comment. "
        "Use the context above if it is helpful.\n"
        f"{comment}\n"
        "[/INST]\n"
    )
|
|
|
|
|
# Module-level lists; nothing in the visible code appends to or reads them.
message_list, response_list = [], []
|
|
|
|
|
def vanilla_chatbot(message: str, history: list) -> str:
    """Answer a chat message using retrieved context and the loaded model.

    Queries the module-level ``query_engine`` for passages related to
    ``message``, builds a prompt with ``prompt_template_w_context``, and
    generates up to 100 new tokens with the module-level ``model``.

    Args:
        message: The user's latest chat message.
        history: Earlier conversation turns passed in by the chat UI.
            Accepted for interface compatibility; not used.

    Returns:
        The generated reply text only, without the prompt or special tokens.
    """
    response = query_engine.query(message)

    # Concatenate the text of every retrieved node under one header.
    context = "Context:\n" + "".join(
        f"{node.text}\n\n" for node in response.source_nodes
    )
    print(context)

    prompt = prompt_template_w_context(context, message)

    # Put the encoded prompt on the same device as the model, which was
    # loaded with device_map="auto" and may not live on the CPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    # Unpacking the encoding forwards the attention mask along with the ids.
    outputs = model.generate(**inputs, max_new_tokens=100)

    # The generated sequence starts with the prompt tokens. Slicing by token
    # count (rather than by decoded character count) isolates the reply
    # regardless of how the tokenizer renders BOS or whitespace, and
    # skip_special_tokens drops end-of-sequence markers from the reply.
    reply_ids = outputs[0][prompt_token_count:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Wrap the chatbot function in Gradio's chat UI.
demo_chatbot = gr.ChatInterface(
    fn=vanilla_chatbot,
    title="Vanilla Chatbot",
    description="Enter text to start chatting.",
)

# The original passed a bare positional True, which binds to launch()'s
# first parameter, `inline`; it is spelled out here.
# NOTE(review): if share=True or debug=True was intended, pass that keyword
# explicitly instead.
demo_chatbot.launch(inline=True)