playmak3r commited on
Commit
0c09017
·
1 Parent(s): 11b3ae2

add csv batch test

Browse files
Files changed (1) hide show
  1. tests/test_csv.py +37 -0
tests/test_csv.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import csv
3
+ import requests
4
+
5
+
6
+ def main(csv_path: str, target_col: int = 0, source_col: int = 1):
7
+ target_texts = []
8
+ source_texts = []
9
+ with open(csv_path, 'r', newline='', encoding='utf-8') as csvfile:
10
+ reader = csv.reader(csvfile)
11
+ for row in reader:
12
+ target_texts.append(row[target_col])
13
+ source_texts.append(row[source_col])
14
+
15
+ similarities = get_similarity(target_texts, source_texts)
16
+ with open('./tests/output.csv', mode="w", newline="", encoding="utf-8") as new_file:
17
+ writer = csv.writer(new_file)
18
+ for i in range(0, len(target_texts)):
19
+ writer.writerow([ target_texts[i], source_texts[i], similarities[i] ])
20
+
21
+
22
+ def get_similarity(texts1: List[str], texts2: List[str]):
23
+ response = requests.post("http://localhost:8000/api/similarity", json={
24
+ "texts1": texts1,
25
+ "texts2": texts2,
26
+ })
27
+
28
+ response_body = response.json()
29
+ similarities = list(map(lambda i: i['similarity'], response_body))
30
+ return similarities
31
+
32
+
33
+
34
+ if __name__ == "__main__":
35
+ main('./tests/input.csv')
36
+
37
+