madoss commited on
Commit
1581c72
·
1 Parent(s): 2014880

Upload query_index.py

Browse files
Files changed (1) hide show
  1. query_index.py +43 -0
query_index.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+
4
+ import datasets
5
+ import sentence_transformers
6
+
7
+ import utils
8
+
9
+ logging.disable(logging.CRITICAL)
10
+
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument("--query", type=str, required=True)
13
+ parser.add_argument("--k", type=int, default=5)
14
+ args = parser.parse_args()
15
+
16
+ model = sentence_transformers.SentenceTransformer(
17
+ "dangvantuan/sentence-camembert-large", device="cuda"
18
+ )
19
+
20
+ dataset = datasets.load_dataset("json", data_files=["./data/dataset.json"], split="train")
21
+ dataset.load_faiss_index("embeddings", "index.faiss")
22
+
23
+ query_embedding = model.encode(args.query)
24
+ _, retrieved_examples = dataset.get_nearest_examples(
25
+ "embeddings",
26
+ query_embedding,
27
+ k=args.k,
28
+ )
29
+
30
+
31
+ for text, start, end, title, url in zip(
32
+ retrieved_examples["text"],
33
+ retrieved_examples["start"],
34
+ retrieved_examples["end"],
35
+ retrieved_examples["title"],
36
+ retrieved_examples["url"],
37
+ ):
38
+ start = start
39
+ end = end
40
+ print(f"title: {title}")
41
+ print(f"transcript: [{str(start)+' ====> '+str(end)}] {text}")
42
+ print(f"link: {url}")
43
+ print("*" * 10)