douglasfaisal commited on
Commit
16959be
·
1 Parent(s): 156d9bb

Upload 7 files

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ flagged/
2
+ .idea/
3
+ */__pycache__/
main.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is a sample Python script.
2
+
3
+ # Press Shift+F10 to execute it or replace it with your code.
4
+ # Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
5
+
6
+ import gradio as gr
7
+ import os
8
+
9
+ from reranker.reranker import CrossEncReranker
10
+ from retriever.es_retriever import ESRetriever
11
+ from utils.preprocessing import question_to_statement
12
+
13
+
14
+ ES_HOST = os.environ["ES_HOST"]
15
+ ES_INDEX_NAME = os.environ["ES_INDEX_NAME"]
16
+ ES_USERNAME = os.environ["ES_USERNAME"]
17
+ ES_PASSWORD = os.environ["ES_PASSWORD"]
18
+
19
+ RERANKER_MODEL_NAME = "douglasfaisal/granularity-legal-reranker-cross-encoder-indobert-base-p2"
20
+
21
+ es_retriever_client = ESRetriever(ES_HOST, ES_INDEX_NAME, ES_USERNAME, ES_PASSWORD)
22
+ cross_enc_reranker = CrossEncReranker(RERANKER_MODEL_NAME, 512)
23
+
24
+ def retrieve_and_rerank(question: str):
25
+
26
+ query = question_to_statement(question)
27
+ retrieval_results = es_retriever_client.retrieve(query)
28
+ reranker_results = cross_enc_reranker.rerank(query, retrieval_results)
29
+
30
+ return reranker_results[0].text
31
+
32
+
33
+ demo = gr.Interface(fn=retrieve_and_rerank, inputs="text", outputs="text")
34
+
35
+ # Press the green button in the gutter to run the script.
36
+ demo.launch()
37
+
38
+ # See PyCharm help at https://www.jetbrains.com/help/pycharm/
models/law_component.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ LAW_PATTERN = '(?P<law_type>[A-Za-z]+) (?P<law_number>\d{1,})/(?P<law_year>\d{4})'
4
+
5
+ class LawComponent:
6
+ law_type: str = None
7
+ law_year: int = 0
8
+ law_number: int = 0
9
+ component_type: str = None
10
+ chapter: int = None
11
+ article: str = None
12
+ subsection: int = None
13
+ letter: str = None
14
+ text: str = None
15
+
16
+ def __init__(self):
17
+ pass
18
+
19
+ def from_uri(uri: str):
20
+ lc = LawComponent()
21
+ uri_split = uri.split('/')
22
+ lc.law_type = uri_split[4]
23
+ lc.law_year = int(uri_split[5]) if int(uri_split[5]) != 0 else None
24
+ lc.law_number = int(uri_split[6]) if int(uri_split[6]) != 0 else None
25
+ if len(uri_split) < 8:
26
+ return lc
27
+ if (uri_split[7] == 'bab'):
28
+ lc.component_type = 'chapter'
29
+ lc.chapter = int(uri_split[8]) if int(uri_split[8]) != 0 else None
30
+ else:
31
+ lc.article = str(int(uri_split[8])) if int(uri_split[8]) != 0 else None
32
+ lc.component_type = 'article'
33
+ if (len(uri_split) > 9 and uri_split[9] == "versi"):
34
+ if (len(uri_split) > 11 and uri_split[11] == "ayat"):
35
+ lc.subsection = int(uri_split[12]) if int(uri_split[12]) != 0 else None
36
+ lc.component_type = 'subsection'
37
+ if (len(uri_split) > 13 and uri_split[13] == "huruf"):
38
+ lc.component_type = 'letter'
39
+ try:
40
+ lc.letter = str(int(uri_split[14])) if int(uri_split[14]) != 0 else None
41
+ except:
42
+ lc.letter = uri_split[14]
43
+ elif (len(uri_split) > 11 and uri_split[11] == "huruf"):
44
+ lc.component_type = 'letter'
45
+ try:
46
+ lc.letter = str(int(uri_split[12])) if int(uri_split[12]) != 0 else None
47
+ except:
48
+ lc.letter = uri_split[12]
49
+
50
+ return lc
51
+
52
+ # def from_answer_granularity_row(row_dict: dict):
53
+ # lc = LawComponent()
54
+ #
55
+ # law = row_dict['Law']
56
+ # law_search = re.search(LAW_PATTERN, law)
57
+ #
58
+ # if (law_search != None):
59
+ # lc.law_type = law_search.group('law_type').lower()
60
+ # lc.law_number = int(law_search.group('law_number')) if int(law_search.group('law_number')) != 0 else None
61
+ # lc.law_year = int(law_search.group('law_year')) if int(law_search.group('law_year')) != 0 else None
62
+ #
63
+ # lc.component_type = row_dict['Answer Granularity'].lower()
64
+ # lc.chapter = int(row_dict['Chapter']) if int(row_dict['Chapter']) != 0 else None
65
+ # try:
66
+ # lc.article = str(int(row_dict['Article'])) if int(row_dict['Article']) != 0 else None
67
+ # except:
68
+ # pass
69
+ # try:
70
+ # lc.subsection = int(row_dict['Subsection']) if int(row_dict['Subsection']) != 0 else None
71
+ # except:
72
+ # pass
73
+ # try:
74
+ # lc.letter = str(int(row_dict['Letter (1st level)'])) if int(
75
+ # row_dict['Letter (1st level)']) != 0 else None
76
+ # except:
77
+ # if (pd.isnull(row_dict['Letter (1st level)'])):
78
+ # lc.letter = None
79
+ # else:
80
+ # lc.letter = row_dict['Letter (1st level)']
81
+ #
82
+ # return lc
83
+
84
+ def set_text(self, text):
85
+ self.text = text
86
+
87
+ def __eq__(self, other):
88
+
89
+ if not (self.law_type == other.law_type and self.law_year == other.law_year
90
+ and self.law_number == other.law_number):
91
+ return False
92
+
93
+ if self.component_type != other.component_type:
94
+ return False
95
+
96
+ # if self.component_type == 'chapter':
97
+ # if self.chapter != other.chapter:
98
+ # return False
99
+ if self.article is None and other.article is None:
100
+ if self.chapter != other.chapter:
101
+ return False
102
+ else:
103
+ if self.article != other.article:
104
+ return False
105
+
106
+ if self.component_type == 'article':
107
+ return True
108
+
109
+ if self.subsection != other.subsection:
110
+ return False
111
+
112
+ if self.component_type == 'subsection':
113
+ return True
114
+
115
+ if self.letter != other.letter:
116
+ return False
117
+
118
+ if self.component_type == 'letter':
119
+ return True
120
+
121
+ return True
122
+
123
+ def is_article_equal(self, other):
124
+
125
+ if not (self.law_type == other.law_type and self.law_year == other.law_year
126
+ and self.law_number == other.law_number):
127
+ return False
128
+
129
+ if self.component_type == 'chapter':
130
+ return False
131
+ else:
132
+ if self.article != other.article:
133
+ return False
134
+
135
+ return True
136
+
137
+ def __repr__(self):
138
+ return "LawComponent({}, {}, {}, {}, {}, {}, {}, {})".format(
139
+ self.law_type,
140
+ self.law_number,
141
+ self.law_year,
142
+ self.component_type,
143
+ self.chapter,
144
+ self.article,
145
+ self.subsection,
146
+ self.letter
147
+ )
148
+
149
+ def __str__(self):
150
+ return self.__repr__()
151
+
152
+ def copy(self):
153
+ return copy.deepcopy(self)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio~=3.28.1
2
+ numpy==1.21.4
3
+ requests==2.26.0
reranker/reranker.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from models.law_component import LawComponent
4
+ from sentence_transformers.cross_encoder import CrossEncoder
5
+
6
+ class CrossEncReranker:
7
+
8
+ def __init__(self, model_name, max_length=512):
9
+ self.model_name = model_name
10
+ self.reranker = CrossEncoder(self.model_name)
11
+ self.reranker.max_length = max_length
12
+
13
+ def rerank(self, query_text: str, candidates: list[LawComponent]):
14
+ sentence_combinations = [[query_text, c.text] for c in candidates]
15
+ similarity_scores = self.reranker.predict(sentence_combinations)
16
+ index = np.argsort(similarity_scores)[::-1]
17
+
18
+ reranked_candidates = np.array(candidates)[index]
19
+ return reranked_candidates
20
+
retriever/es_retriever.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ import json
4
+ import requests
5
+ from requests.auth import HTTPBasicAuth
6
+ from models.law_component import LawComponent
7
+
8
+ base_query = {
9
+ "query": {
10
+ "bool": {
11
+ "should": [
12
+ {
13
+ "match": {
14
+ "text": {
15
+ "query": None,
16
+ "boost": 1.0
17
+ }
18
+ }
19
+ },
20
+ {
21
+ "match": {
22
+ "chapterTitle": {
23
+ "query": None,
24
+ "boost": 1.0
25
+ }
26
+ }
27
+ },
28
+ {
29
+ "match_phrase": {
30
+ "text": {
31
+ "query": None,
32
+ "boost": 1.0
33
+ }
34
+ }
35
+ },
36
+ {
37
+ "match_phrase": {
38
+ "chapterTitle": {
39
+ "query": None,
40
+ "boost": 1.0
41
+ }
42
+ }
43
+ }
44
+ ],
45
+ "minimum_should_match": 1
46
+ }
47
+ }
48
+ }
49
+
50
+ class ESRetriever:
51
+
52
+ def __init__(self, es_host, es_index_name, es_username="", es_password=""):
53
+ self.es_host = es_host
54
+ self.es_index_name = es_index_name
55
+ self.es_username = es_username
56
+ self.es_password = es_password
57
+
58
+ if (es_username != "" and es_password != ""):
59
+ self.auth = HTTPBasicAuth(es_username, es_password)
60
+ else:
61
+ self.auth = None
62
+
63
+ # Returns LawComponent
64
+ def retrieve(self, query_text: str):
65
+ query = copy.deepcopy(base_query)
66
+ query['query']['bool']['should'][0]['match']['text']['query'] = query_text
67
+ query['query']['bool']['should'][1]['match']['chapterTitle']['query'] = query_text
68
+ query['query']['bool']['should'][2]['match_phrase']['text']['query'] = query_text
69
+ query['query']['bool']['should'][3]['match_phrase']['chapterTitle']['query'] = query_text
70
+
71
+ # try:
72
+ response = requests.get(
73
+ self.es_host + self.es_index_name + '/_search',
74
+ headers={'Content-Type': 'application/json'},
75
+ data=json.dumps(query),
76
+ auth=self.auth
77
+ )
78
+
79
+ if response.ok:
80
+ results = response.json()["hits"]["hits"]
81
+ retrieval_results = []
82
+ for result in results:
83
+ lc = LawComponent.from_uri(result["_source"]["uri"])
84
+ lc.set_text(result["_source"]["text"])
85
+ retrieval_results.append(lc)
86
+ return retrieval_results
87
+
88
+ #
89
+ # response.content
utils/preprocessing.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ def question_to_statement(question: str):
5
+ pattern = "(?P<qWord>kapan|(ber|meng|si)?apa|(di|ke) ?(mana)|bagaimana) ?(saja)? ?(kah)?"
6
+
7
+ result = re.sub(pattern, "", question.lower())
8
+ result = result.replace("?", "")
9
+ result = re.sub("\s+", " ", result)
10
+ result = result.strip()
11
+ return result