File size: 4,618 Bytes
07ffad3
42df98c
1b7e4b0
07ffad3
 
 
 
 
eaca477
1b7e4b0
43ae797
07ffad3
 
1b7e4b0
 
 
 
 
 
eaca477
1b7e4b0
07ffad3
eaca477
7d9a21e
ec493d8
1b7e4b0
 
07ffad3
5ea07e3
1b7e4b0
eaca477
b0b771c
1b7e4b0
42df98c
eaca477
e4b2161
eaca477
 
 
 
1b7e4b0
 
eaca477
1b7e4b0
8577be5
07ffad3
1b7e4b0
eaca477
 
 
 
 
23e06d0
fdd8ddb
07ffad3
 
cc1edc1
07ffad3
e82c570
 
 
eaca477
e82c570
eaca477
ec493d8
e82c570
eaca477
1b7e4b0
e82c570
07ffad3
 
 
ec493d8
 
07ffad3
42df98c
 
 
3afa221
 
07ffad3
eaca477
07ffad3
eaca477
1b7e4b0
07ffad3
 
 
 
 
 
 
 
 
 
3afa221
07ffad3
 
43ae797
3afa221
 
89b12a6
 
 
 
 
 
 
 
8577be5
89b12a6
 
e4f812c
fdd8ddb
07ffad3
 
18b530b
07ffad3
 
e4b2161
 
31630cf
 
18b530b
 
 
 
42df98c
afb805b
07ffad3
 
18b530b
ef4a283
 
 
 
 
 
 
 
 
 
 
e1ca399
ef4a283
 
8b048b4
ef4a283
18b530b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import gradio as gr
from datasets import load_dataset

import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import time

# Hugging Face access token. Indexing os.environ (rather than .get) means the
# app fails at import time with KeyError when HF_TOKEN is not set.
token = os.environ["HF_TOKEN"]
# Instruction-tuned Gemma 7B used as the answer generator, loaded in float16.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    torch_dtype=torch.float16,
    token=token,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
# The device is hard-coded to CUDA, so model.to(device) requires a GPU.
# NOTE(review): presumably intentional for a GPU-backed Space — confirm.
device = torch.device("cuda")
model = model.to(device)
# Sentence-embedding model used by `search` to encode user queries.
RAG = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Default number of passages retrieved per query.
TOP_K = 1
# Marker placed between a generated answer and its appended source links;
# `talk` splits past assistant turns on it to strip the sources again.
HEADER = "\n# RESOURCES:\n"
# prepare data
# The full Wikipedia dump is too big, so load a pre-built ~3K-row subset
# (per the dataset name) that already carries an "embedding" column.

data = load_dataset("not-lain/wikipedia-small-3000-embedded", split="train")

# Build a FAISS index over the "embedding" column so that `search` can run
# nearest-neighbour lookups via data.get_nearest_examples("embedding", ...).
data.add_faiss_index("embedding")


def search(query: str, k: int = TOP_K):
    """Embed ``query`` and fetch its ``k`` nearest rows from ``data``.

    Args:
        query: Free-text question typed by the user.
        k: How many neighbours to return; defaults to ``TOP_K``.

    Returns:
        The examples half of ``data.get_nearest_examples(...)``: a mapping
        from column name (callers read ``"title"``, ``"text"`` and ``"url"``)
        to the neighbours' values. The accompanying scores are discarded.
    """
    query_vector = RAG.encode(query)
    _, neighbours = data.get_nearest_examples("embedding", query_vector, k=k)
    return neighbours


def prepare_prompt(query, retrieved_examples):
    """Build the RAG user prompt and the source links for ``query``.

    Args:
        query: The user's question.
        retrieved_examples: Column mapping as returned by ``search``; must
            contain equally long ``"title"``, ``"text"`` and ``"url"``
            sequences, one entry per retrieved passage.

    Returns:
        A ``(prompt, sources)`` pair. ``prompt`` is the instruction header
        followed by one ``"* <text>\\n"`` bullet per retrieved passage.
        ``sources`` is an iterator of ``(title, url)`` pairs in retrieval
        order, each title matched with its own URL.
    """
    prompt = (
        f"Query: {query}\nContinue to answer the query in short sentences by using the Search Results:\n"
    )
    # Passages are listed in reverse retrieval order. Presumably the intent is
    # to put the best hit closest to the generation point; this matches the
    # original layout. Iterate over what was actually retrieved rather than a
    # fixed TOP_K, so a different k or a short result neither drops passages
    # nor raises IndexError.
    texts = retrieved_examples["text"][::-1]
    prompt += "".join(f"* {text}\n" for text in texts)
    # Titles and URLs come from the same, untouched ordering, so every link
    # points at its own article.
    sources = zip(retrieved_examples["title"], retrieved_examples["url"])
    return prompt, sources


@spaces.GPU(duration=150)
def talk(message, history):
    """Answer ``message`` with retrieval-augmented generation, streaming the reply.

    Retrieves passages for ``message`` via ``search``, wraps them into a
    prompt with ``prepare_prompt``, replays ``history`` through the
    tokenizer's chat template and streams the model's output. Once generation
    ends, markdown links to the retrieved sources are appended after
    ``HEADER``.

    Args:
        message: The latest user message.
        history: Previous turns, indexed as ``item[0]`` (user) and ``item[1]``
            (assistant). Assistant turns may carry the ``HEADER`` + sources
            suffix added by earlier calls; it is stripped before re-prompting.

    Yields:
        The accumulated response text after each streamed chunk, then the
        full response with the sources section appended.
    """
    import queue  # TextIteratorStreamer raises queue.Empty on its timeout

    print("history, ", history)
    print("message ", message)
    print("searching dataset ...")
    retrieved_examples = search(message)
    print("preparing prompt ...")
    prompt, metadata = prepare_prompt(message, retrieved_examples)
    print("preparing metadata ...")
    resources = HEADER + "".join(f"[{title}]({url}),  " for title, url in metadata)
    print("preparing chat template ...")
    chat = []
    for item in history:
        user_turn, assistant_turn = item[0], item[1]
        chat.append({"role": "user", "content": user_turn})
        # Drop the sources suffix appended on the previous turn so the model
        # only sees its own prose. A missing (None) reply is treated as empty.
        cleaned_past = (assistant_turn or "").split(HEADER)[0]
        chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": prompt})
    messages = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    print("chat template prepared, ", messages)
    print("tokenizing input ...")
    # Tokenize the messages string
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    print("initializing thread ...")
    # generate() runs in a worker thread and pushes decoded text into
    # `streamer`, which this generator drains.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    retries = 0
    # Read until the streamer's own end-of-stream marker instead of polling
    # t.is_alive(). A short generation can finish before the first poll, and
    # an exhausted streamer must not be re-entered, since it would block for
    # the whole timeout.
    while True:
        try:
            for new_text in streamer:
                if new_text is not None:
                    partial_text += new_text
                    yield partial_text
            break  # end-of-stream marker received: generation is complete
        except queue.Empty as e:
            # No text within the streamer timeout (e.g. a slow first token).
            # If the worker has died (generate raised), no end marker will
            # ever arrive, so stop instead of waiting forever.
            if not t.is_alive():
                print("generation thread exited before finishing the stream")
                break
            print(f"retry number {retries}\n LOGS:\n")
            retries += 1
            print(e, e.args)
    t.join()
    partial_text += resources
    yield partial_text


# Markdown heading rendered at the top of the Gradio app.
TITLE = "# RAG"

# Markdown description rendered under the title.
# NOTE(review): this text says the dataset's embedding column was built with
# mxbai-colbert-large-v1, while queries are embedded with
# mxbai-embed-large-v1 (the `RAG` model). Confirm the two vector spaces match.
DESCRIPTION = """
A rag pipeline with a chatbot feature

Resources used to build this project :

* embedding model : https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
* dataset : https://huggingface.co/datasets/not-lain/wikipedia-small-3000-embedded (used mxbai-colbert-large-v1 to create the embedding column )
* faiss docs : https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index 
* chatbot : https://huggingface.co/google/gemma-7b-it

If you want to support my work consider clicking on the heart react button ❤️🤗
"""


# Chat UI: Gradio drives the `talk` generator and renders each partial
# answer it yields into a bubble-style chatbot.
demo = gr.ChatInterface(
    talk,
    title=TITLE,
    description=DESCRIPTION,
    examples=[["what's anarchy ? "]],
    theme="Soft",
    chatbot=gr.Chatbot(
        layout="bubble",
        bubble_full_width=False,
        likeable=True,
        show_label=True,
        show_copy_button=True,
        show_share_button=True,
    ),
)
# Start the app; debug=True is passed through to launch().
demo.launch(debug=True)