Yoxas committed
Commit 921072e · verified · 1 Parent(s): 079a471

Update app.py

Files changed (1): app.py +29 -153
app.py CHANGED
@@ -1,156 +1,32 @@
  import gradio as gr
- from datasets import load_dataset
-
- import os
- import spaces
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
- import torch
- from threading import Thread
- from sentence_transformers import SentenceTransformer
- import numpy as np
-
- token = os.environ["HF_TOKEN"]
- ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
-
- dataset = load_dataset("Yoxas/statistical_literacyv2")
-
- data = dataset["train"]
-
- # Convert string embeddings to numpy arrays and ensure they are 2D
- def convert_and_ensure_2d_embeddings(example):
-     embedding_str = example['embedding']
-     embedding_str = embedding_str.replace('\n', ' ').replace('...', '')
-     embedding_list = list(map(float, embedding_str.strip("[]").split()))
-     embeddings = np.array(embedding_list, dtype=np.float32)
-     # Ensure the embeddings are 2-dimensional
-     if embeddings.ndim == 1:
-         embeddings = embeddings.reshape(1, -1)
-     return {'embedding': embeddings}
-
- # Apply the function to ensure embeddings are 2-dimensional and of type float32
- data = data.map(convert_and_ensure_2d_embeddings)
-
- # Flatten embeddings if they are nested 2D arrays
- def flatten_embeddings(example):
-     embedding = np.array(example['embedding'], dtype=np.float32)
-     if embedding.ndim == 2 and embedding.shape[0] == 1:
-         embedding = embedding.flatten()
-     return {'embedding': embedding}
-
- data = data.map(flatten_embeddings)
-
- # Extract embeddings and convert to numpy array
- embeddings = np.vstack([example['embedding'] for example in data])
-
- # Add FAISS index
- data = data.add_faiss_index_from_external_arrays("embedding", embeddings)
-
- model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
-
- # Use quantization to lower GPU usage
- bnb_config = BitsAndBytesConfig(
-     load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
- )
-
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     torch_dtype=torch.bfloat16,
-     device_map="auto",
-     quantization_config=bnb_config,
-     token=token
  )
- terminators = [
-     tokenizer.eos_token_id,
-     tokenizer.convert_tokens_to_ids("<|eot_id|>")
- ]
-
- SYS_PROMPT = """You are an assistant for answering questions.
- You are given the extracted parts of a long document and a question. Provide a conversational answer.
- If you don't know the answer, just say "I do not know." Don't make up an answer."""
-
- def search(query: str, k: int = 3):
-     """A function that embeds a new query and returns the most probable results."""
-     embedded_query = ST.encode(query)  # Embed new query
-     scores, retrieved_examples = data.get_nearest_examples(  # Retrieve results
-         "embedding", embedded_query,  # Compare our new embedded query with the dataset embeddings
-         k=k  # Get only top k results
-     )
-     return scores, retrieved_examples
-
- def format_prompt(prompt, retrieved_documents, k):
-     """Using the retrieved documents we will prompt the model to generate our responses."""
-     PROMPT = f"Question:{prompt}\nContext:"
-     for idx in range(k):
-         PROMPT += f"{retrieved_documents['text'][idx]}\n"
-     return PROMPT

- @spaces.GPU(duration=150)
- def talk(prompt, history):
-     k = 1  # Number of retrieved documents
-     scores, retrieved_documents = search(prompt, k)
-     formatted_prompt = format_prompt(prompt, retrieved_documents, k)
-     formatted_prompt = formatted_prompt[:2000]  # To avoid GPU OOM
-     messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
-     # Tell the model to generate
-     input_ids = tokenizer.apply_chat_template(
-         messages,
-         add_generation_prompt=True,
-         return_tensors="pt"
-     ).to(model.device)
-     outputs = model.generate(
-         input_ids,
-         max_new_tokens=1024,
-         eos_token_id=terminators,
-         do_sample=True,
-         temperature=0.6,
-         top_p=0.9,
-     )
-     streamer = TextIteratorStreamer(
-         tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
-     )
-     generate_kwargs = dict(
-         input_ids=input_ids,
-         streamer=streamer,
-         max_new_tokens=1024,
-         do_sample=True,
-         top_p=0.95,
-         temperature=0.75,
-         eos_token_id=terminators,
-     )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         print(outputs)
-         yield "".join(outputs)
-
- TITLE = "# RAG"
-
- DESCRIPTION = """
- A RAG pipeline with a chatbot feature
- Resources used to build this project:
- * Embedding model: https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
- * Dataset: https://huggingface.co/datasets/not-lain/wikipedia
- * FAISS docs: https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
- * Chatbot: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
- """
-
- demo = gr.ChatInterface(
-     fn=talk,
-     chatbot=gr.Chatbot(
-         show_label=True,
-         show_share_button=True,
-         show_copy_button=True,
-         likeable=True,
-         layout="bubble",
-         bubble_full_width=False,
-     ),
-     theme="Soft",
-     examples=[["what's anarchy?"]],
-     title=TITLE,
-     description=DESCRIPTION,
- )
- demo.launch(debug=True)
 
  import gradio as gr
+ import pandas as pd
+ from transformers import pipeline
+
+ # Load CSV data
+ data = pd.read_csv('documents.csv')
+
+ # Load a transformer model (you can choose a suitable model from Hugging Face)
+ # For this example, we'll use a simple QA model
+ qa_model = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
+
+ # Function to retrieve the relevant document and generate a response
+ def retrieve_and_generate(question):
+     # Combine all abstracts into a single string (you can improve this by better retrieval methods)
+     abstracts = " ".join(data['Abstract'].fillna("").tolist())
+
+     # Retrieve the most relevant section from the combined abstracts
+     response = qa_model(question=question, context=abstracts)
+
+     return response['answer']
+
+ # Create a Gradio interface
+ interface = gr.Interface(
+     fn=retrieve_and_generate,
+     inputs=gr.inputs.Textbox(lines=2, placeholder="Ask a question about the documents..."),
+     outputs="text",
+     title="RAG Chatbot",
+     description="Ask questions about the documents in the CSV file."
  )

+ # Launch the Gradio app
+ interface.launch()
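
The new app.py passes every abstract to the question-answering pipeline as one long context, and its own comment notes that retrieval could be improved. Below is a minimal sketch of one such improvement, assuming the same documents.csv with an Abstract column as in the diff; the TF-IDF ranking step, the scikit-learn dependency, and the helper names are illustrative assumptions, not part of the commit.

# Sketch only: rank abstracts with TF-IDF before running the same QA pipeline.
# Assumes documents.csv with an 'Abstract' column, as in the commit above;
# scikit-learn and the names below are not part of the commit.
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline

data = pd.read_csv("documents.csv")
abstracts = data["Abstract"].fillna("").tolist()

# Build a TF-IDF index over the abstracts once at startup.
vectorizer = TfidfVectorizer(stop_words="english")
doc_matrix = vectorizer.fit_transform(abstracts)  # one row per abstract

qa_model = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")

def retrieve_and_generate(question, k=3):
    # Rank abstracts by cosine similarity to the question and keep the top k.
    query_vec = vectorizer.transform([question])
    scores = cosine_similarity(query_vec, doc_matrix)[0]
    top_idx = scores.argsort()[::-1][:k]
    context = " ".join(abstracts[i] for i in top_idx)
    # Extract the answer from the narrowed context only.
    return qa_model(question=question, context=context)["answer"]

The resulting retrieve_and_generate could be wired into gr.Interface exactly as in the new app.py, with only the fn argument unchanged in shape.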