import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
import faiss
from usearch.index import Index
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from urllib.parse import quote

token = os.environ["HF_TOKEN"]

# Generative LLM that writes the final answer in `talk`. It keeps the name `model`;
# the query embedder below uses a different name so it does not overwrite this one.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.float16,
    token=token,
)
tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda")
model = model.to(device)

# Load titles and texts
title_text_dataset = load_dataset(
    "mixedbread-ai/wikipedia-data-en-2023-11", split="train", num_proc=4
).select_columns(["title", "text"])

# Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we
# never actually perform search with it.
int8_view = Index.restore("wikipedia_int8_usearch_50m.index", view=True)
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
    "wikipedia_ubinary_faiss_50m.index"
)
binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary(
    "wikipedia_ubinary_ivf_faiss_50m.index"
)

# Load the SentenceTransformer model for embedding the queries. Named
# `embedding_model` (not `model`) so it does not shadow the Gemma LLM above.
embedding_model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    prompts={
        "retrieval": "Represent this sentence for searching relevant passages: ",
    },
    default_prompt_name="retrieval",
)


def search(
    query, top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
):
    """Retrieve the passages most similar to ``query``.

    Candidates are fetched from a binary (ubinary) faiss index and then rescored
    with the float32 query embedding against the stored int8 document embeddings.

    Args:
        query: Free-text search query.
        top_k: Number of results to return.
        rescore_multiplier: ``top_k * rescore_multiplier`` binary candidates are
            fetched before rescoring.
        use_approx: Search the IVF binary index instead of the flat one.

    Returns:
        A dict with keys "Score" (list of rounded scores), "Title" and "Text"
        (tuples), all aligned and sorted by descending rescored score. The
        sequences are empty when the index returns no candidates.
    """
    # 1. Embed the query as float32
    query_embedding = embedding_model.encode(query)

    # 2. Quantize the query to ubinary
    query_embedding_ubinary = quantize_embeddings(
        query_embedding.reshape(1, -1), "ubinary"
    )

    # 3. Search the binary index (either exact or approximate)
    index = binary_ivf if use_approx else binary_index
    _scores, binary_ids = index.search(
        query_embedding_ubinary, top_k * rescore_multiplier
    )
    binary_ids = binary_ids[0]
    # faiss pads the label array with -1 when it finds fewer than k neighbours
    # (possible with the IVF index); those are not valid document ids.
    binary_ids = binary_ids[binary_ids >= 0]
    if binary_ids.size == 0:
        return {"Score": [], "Title": (), "Text": ()}

    # 4. Load the corresponding int8 embeddings
    int8_embeddings = int8_view[binary_ids].astype(int)

    # 5. Rescore the candidates using the float32 query embedding and the int8
    #    document embeddings
    scores = query_embedding @ int8_embeddings.T

    # 6. Sort the scores (descending) and keep the top_k
    indices = scores.argsort()[::-1][:top_k]
    top_k_indices = binary_ids[indices]
    top_k_scores = scores[indices]

    # Fetch each dataset row once rather than once per column.
    rows = [title_text_dataset[idx] for idx in top_k_indices.tolist()]
    return {
        "Score": [round(value, 2) for value in top_k_scores],
        "Title": tuple(row["title"] for row in rows),
        "Text": tuple(row["text"] for row in rows),
    }


def prepare_prompt(query, df):
    """Build the LLM user prompt from the query and the search results.

    Args:
        query: The user's question.
        df: The dict returned by ``search`` (uses its "Title" and "Text" columns).

    Returns:
        The query followed by one "Title: ..., Text: ..." line per search result.
    """
    parts = [
        f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    ]
    # `df` is a dict of parallel columns, so walk the columns row by row
    # (iterating the dict itself would only yield its key names).
    for title, text in zip(df["Title"], df["Text"]):
        parts.append(f"Title: {title}, Text: {text}\n")
    return "".join(parts)


@spaces.GPU
def talk(message, history):
    """Answer ``message`` with retrieval-augmented generation, streaming the reply.

    Args:
        message: The latest user message.
        history: Previous turns as ``[user, assistant]`` pairs (assistant may be
            None). Assistant turns have their RESOURCES footer stripped before
            being sent back to the model.

    Yields:
        The growing generated answer. The final yield appends a RESOURCES footer
        that links the first three retrieved articles.
    """
    df = search(message)
    message = prepare_prompt(message, df)
    # Link each of the first three retrieved titles to its Wikipedia article.
    # quote() escapes characters such as ")" that would otherwise break the
    # markdown link syntax.
    resources = "\nRESOURCES:\n" + ", ".join(
        f"[{title}](https://en.wikipedia.org/wiki/{quote(title.replace(' ', '_'))})"
        for title in df["Title"][:3]
    )

    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            # Don't feed the RESOURCES footer of earlier answers back to the LLM.
            cleaned_past = item[1].split("\nRESOURCES:\n")[0]
            chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": message})

    prompt = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # The rendered chat template already starts with the BOS token, so don't let
    # the tokenizer prepend a second one.
    model_inputs = tok([prompt], return_tensors="pt", add_special_tokens=False).to(
        device
    )

    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # generate() runs in a worker thread so tokens can be yielded as they arrive.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # The streamer is exhausted only after generation ends; reap the worker.
    worker.join()

    partial_text += resources
    yield partial_text


TITLE = "RAG"

DESCRIPTION = """
## Resources used to build this project
* https://huggingface.co/learn/cookbook/rag_with_hugging_face_gemma_mongodb
* https://huggingface.co/spaces/sentence-transformers/quantized-retrieval

## Retrieval parameters
```python
top_k: int = 10,
rescore_multiplier: int = 1,
use_approx: bool = False
```

## Models
The models used in this space are:
* google/gemma-7b-it
* mixedbread-ai/mxbai-embed-large-v1

## Dataset
* mixedbread-ai/wikipedia-data-en-2023-11
"""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["Write me a poem about Machine Learning."]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()