# -*- coding: utf-8 -*-
"""
Created on Sat Oct  5 16:41:22 2024

@author: Admin
"""

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer


from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.postprocessor import SimilarityPostprocessor

# Configure LlamaIndex: use a local Hugging Face embedding model and disable the
# built-in LLM, so the query engine only retrieves context (generation is done
# separately with the Phi-3.5 model loaded below).
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
Settings.llm = None
Settings.chunk_size = 256
Settings.chunk_overlap = 25

# Load the documents in ./test and build an in-memory vector index over them.
documents = SimpleDirectoryReader("./test").load_data()
index = VectorStoreIndex.from_documents(documents)
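# Optional (sketch): persist the index so it is not rebuilt on every start; it
# could then be restored with llama_index.core.load_index_from_storage from the
# same directory.
# index.storage_context.persist(persist_dir="./storage")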

top_k = 6

# Configure the retriever to return the top_k most similar chunks.
retriever = VectorIndexRetriever(
    index=index,
    similarity_top_k=top_k,
)

# Wrap the retriever in a query engine and drop chunks scoring below 0.5 similarity.
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.5)],
)
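# Quick retrieval sanity check (sketch; the query string is illustrative and
# assumes ./test already contains documents):
# example_response = query_engine.query("What topics do these documents cover?")
# for node in example_response.source_nodes:
#     print(node.score, node.text[:80])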

# Load the generator model (Phi-3.5-mini-instruct).

model_name = "microsoft/Phi-3.5-mini-instruct"
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             device_map="auto",
                                             trust_remote_code=False,
                                             revision="main")
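# Optional (sketch; assumes a GPU with fp16 support is available): load in half
# precision to roughly halve memory use.
# model = AutoModelForCausalLM.from_pretrained(
#     model_name, device_map="auto", torch_dtype=torch.float16, revision="main"
# )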


# Load the matching tokenizer and put the model in inference mode.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model.eval()

# Earlier experiments (kept as a note): gated checkpoints such as
# meta-llama/Llama-3.2-1B-Instruct required login(token=os.getenv("HF_TOKEN")),
# and facebook/blenderbot-400M-distill was tried via a transformers pipeline.
# Only the Phi-3.5 model above is used now.

# Prompt template: retrieved context first, then the user's comment to respond to.
prompt_template_w_context = lambda context, comment: f"""{context}
Please respond to the following comment. Use the context above if it is helpful.
{comment}
[/INST]
"""


def vanilla_chatbot(message, history):
    # Retrieve the most relevant document chunks for the user's message.
    response = query_engine.query(message)

    # Concatenate the retrieved chunks into a single context block.
    context = "Context:\n"
    for node in response.source_nodes:
        context += node.text + "\n\n"
    print(context)

    # Build the prompt and generate a reply with the Phi-3.5 model.
    prompt = prompt_template_w_context(context, message)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=100)

    # Return only the newly generated tokens; slicing by token count avoids the
    # off-by-a-few errors of slicing the decoded string by character length.
    prompt_length = inputs["input_ids"].shape[1]
    new_sentence = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return new_sentence

# Build the Gradio chat UI around the RAG chatbot function and launch it
# (note: the positional True binds to launch()'s first parameter, inline; pass
# share=True or debug=True by keyword if one of those was intended).
demo_chatbot = gr.ChatInterface(vanilla_chatbot, title="Vanilla Chatbot", description="Enter text to start chatting.")

demo_chatbot.launch(True)
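
# Usage (sketch): place the documents to index in ./test and run this script
# (e.g. `python app.py`, assuming that file name); Gradio prints a local URL for
# the chat interface. On Hugging Face Spaces the app is started automatically.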