hassano94 commited on
Commit
462a119
·
verified ·
1 Parent(s): 8a74019

Update RAG_class.py

Browse files
Files changed (1) hide show
  1. RAG_class.py +0 -24
RAG_class.py CHANGED
@@ -8,44 +8,20 @@ class RAG_1177:
8
  def __init__(self):
9
  self.db_name = "RAG_1177"
10
 
11
- self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=2500,chunk_overlap=500,length_function=len)
12
  self.model = SentenceTransformer('KBLab/sentence-bert-swedish-cased')
13
  self.client = chromadb.PersistentClient(path="RAG_1177_db")
14
  self.db = self.client.get_or_create_collection(self.db_name)
15
 
16
- self.url_list_path = "all_urls_list.txt"
17
- self.text_folder = "scraped_texts/"
18
-
19
- def chunk_text_file(self, file_name):
20
- file_name = self.text_folder + file_name
21
- with open(file_name, 'r', encoding='utf-8') as f:
22
- text = f.read()
23
- chunks = self.text_splitter.create_documents([text])
24
- #append chunks as elements in a list
25
- chunks = [chunk.page_content for chunk in chunks]
26
- return chunks
27
-
28
- def get_file_names(self, folder_path):
29
- doc_list = os.listdir(folder_path)
30
- doc_list = sorted(doc_list, key=lambda x: int(x.split('-')[-1].split('.')[0]))
31
- return doc_list
32
 
33
  def get_embeddings(self, text):
34
  embeddings = self.model.encode(text)
35
  return (embeddings.tolist())
36
 
37
- def get_url(self, url_index):
38
- with open(self.url_list_path, 'r') as f:
39
- urls = f.readlines()
40
- return urls[url_index].strip()
41
 
42
  def get_ids(self, num_ids):
43
  ids = [str(uuid.uuid4()) for _ in range(num_ids)]
44
  return ids
45
 
46
- def get_url_dict(self, url, integer):
47
- url_list = [{"url": url} for _ in range(integer)]
48
- return url_list
49
 
50
  def delete_collection(self):
51
  self.client.delete_collection(self.db_name)
 
8
  def __init__(self):
9
  self.db_name = "RAG_1177"
10
 
 
11
  self.model = SentenceTransformer('KBLab/sentence-bert-swedish-cased')
12
  self.client = chromadb.PersistentClient(path="RAG_1177_db")
13
  self.db = self.client.get_or_create_collection(self.db_name)
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def get_embeddings(self, text):
17
  embeddings = self.model.encode(text)
18
  return (embeddings.tolist())
19
 
 
 
 
 
20
 
21
  def get_ids(self, num_ids):
22
  ids = [str(uuid.uuid4()) for _ in range(num_ids)]
23
  return ids
24
 
 
 
 
25
 
26
  def delete_collection(self):
27
  self.client.delete_collection(self.db_name)