Commit
·
16959be
1
Parent(s):
156d9bb
Upload 7 files
Browse files- .gitignore +3 -0
- main.py +38 -0
- models/law_component.py +153 -0
- requirements.txt +3 -0
- reranker/reranker.py +20 -0
- retriever/es_retriever.py +89 -0
- utils/preprocessing.py +11 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
flagged/
|
2 |
+
.idea/
|
3 |
+
*/__pycache__/
|
main.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is a sample Python script.
|
2 |
+
|
3 |
+
# Press Shift+F10 to execute it or replace it with your code.
|
4 |
+
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import os
|
8 |
+
|
9 |
+
from reranker.reranker import CrossEncReranker
|
10 |
+
from retriever.es_retriever import ESRetriever
|
11 |
+
from utils.preprocessing import question_to_statement
|
12 |
+
|
13 |
+
|
14 |
+
ES_HOST = os.environ["ES_HOST"]
|
15 |
+
ES_INDEX_NAME = os.environ["ES_INDEX_NAME"]
|
16 |
+
ES_USERNAME = os.environ["ES_USERNAME"]
|
17 |
+
ES_PASSWORD = os.environ["ES_PASSWORD"]
|
18 |
+
|
19 |
+
RERANKER_MODEL_NAME = "douglasfaisal/granularity-legal-reranker-cross-encoder-indobert-base-p2"
|
20 |
+
|
21 |
+
es_retriever_client = ESRetriever(ES_HOST, ES_INDEX_NAME, ES_USERNAME, ES_PASSWORD)
|
22 |
+
cross_enc_reranker = CrossEncReranker(RERANKER_MODEL_NAME, 512)
|
23 |
+
|
24 |
+
def retrieve_and_rerank(question: str):
|
25 |
+
|
26 |
+
query = question_to_statement(question)
|
27 |
+
retrieval_results = es_retriever_client.retrieve(query)
|
28 |
+
reranker_results = cross_enc_reranker.rerank(query, retrieval_results)
|
29 |
+
|
30 |
+
return reranker_results[0].text
|
31 |
+
|
32 |
+
|
33 |
+
demo = gr.Interface(fn=retrieve_and_rerank, inputs="text", outputs="text")
|
34 |
+
|
35 |
+
# Press the green button in the gutter to run the script.
|
36 |
+
demo.launch()
|
37 |
+
|
38 |
+
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
|
models/law_component.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
LAW_PATTERN = '(?P<law_type>[A-Za-z]+) (?P<law_number>\d{1,})/(?P<law_year>\d{4})'
|
4 |
+
|
5 |
+
class LawComponent:
|
6 |
+
law_type: str = None
|
7 |
+
law_year: int = 0
|
8 |
+
law_number: int = 0
|
9 |
+
component_type: str = None
|
10 |
+
chapter: int = None
|
11 |
+
article: str = None
|
12 |
+
subsection: int = None
|
13 |
+
letter: str = None
|
14 |
+
text: str = None
|
15 |
+
|
16 |
+
def __init__(self):
|
17 |
+
pass
|
18 |
+
|
19 |
+
def from_uri(uri: str):
|
20 |
+
lc = LawComponent()
|
21 |
+
uri_split = uri.split('/')
|
22 |
+
lc.law_type = uri_split[4]
|
23 |
+
lc.law_year = int(uri_split[5]) if int(uri_split[5]) != 0 else None
|
24 |
+
lc.law_number = int(uri_split[6]) if int(uri_split[6]) != 0 else None
|
25 |
+
if len(uri_split) < 8:
|
26 |
+
return lc
|
27 |
+
if (uri_split[7] == 'bab'):
|
28 |
+
lc.component_type = 'chapter'
|
29 |
+
lc.chapter = int(uri_split[8]) if int(uri_split[8]) != 0 else None
|
30 |
+
else:
|
31 |
+
lc.article = str(int(uri_split[8])) if int(uri_split[8]) != 0 else None
|
32 |
+
lc.component_type = 'article'
|
33 |
+
if (len(uri_split) > 9 and uri_split[9] == "versi"):
|
34 |
+
if (len(uri_split) > 11 and uri_split[11] == "ayat"):
|
35 |
+
lc.subsection = int(uri_split[12]) if int(uri_split[12]) != 0 else None
|
36 |
+
lc.component_type = 'subsection'
|
37 |
+
if (len(uri_split) > 13 and uri_split[13] == "huruf"):
|
38 |
+
lc.component_type = 'letter'
|
39 |
+
try:
|
40 |
+
lc.letter = str(int(uri_split[14])) if int(uri_split[14]) != 0 else None
|
41 |
+
except:
|
42 |
+
lc.letter = uri_split[14]
|
43 |
+
elif (len(uri_split) > 11 and uri_split[11] == "huruf"):
|
44 |
+
lc.component_type = 'letter'
|
45 |
+
try:
|
46 |
+
lc.letter = str(int(uri_split[12])) if int(uri_split[12]) != 0 else None
|
47 |
+
except:
|
48 |
+
lc.letter = uri_split[12]
|
49 |
+
|
50 |
+
return lc
|
51 |
+
|
52 |
+
# def from_answer_granularity_row(row_dict: dict):
|
53 |
+
# lc = LawComponent()
|
54 |
+
#
|
55 |
+
# law = row_dict['Law']
|
56 |
+
# law_search = re.search(LAW_PATTERN, law)
|
57 |
+
#
|
58 |
+
# if (law_search != None):
|
59 |
+
# lc.law_type = law_search.group('law_type').lower()
|
60 |
+
# lc.law_number = int(law_search.group('law_number')) if int(law_search.group('law_number')) != 0 else None
|
61 |
+
# lc.law_year = int(law_search.group('law_year')) if int(law_search.group('law_year')) != 0 else None
|
62 |
+
#
|
63 |
+
# lc.component_type = row_dict['Answer Granularity'].lower()
|
64 |
+
# lc.chapter = int(row_dict['Chapter']) if int(row_dict['Chapter']) != 0 else None
|
65 |
+
# try:
|
66 |
+
# lc.article = str(int(row_dict['Article'])) if int(row_dict['Article']) != 0 else None
|
67 |
+
# except:
|
68 |
+
# pass
|
69 |
+
# try:
|
70 |
+
# lc.subsection = int(row_dict['Subsection']) if int(row_dict['Subsection']) != 0 else None
|
71 |
+
# except:
|
72 |
+
# pass
|
73 |
+
# try:
|
74 |
+
# lc.letter = str(int(row_dict['Letter (1st level)'])) if int(
|
75 |
+
# row_dict['Letter (1st level)']) != 0 else None
|
76 |
+
# except:
|
77 |
+
# if (pd.isnull(row_dict['Letter (1st level)'])):
|
78 |
+
# lc.letter = None
|
79 |
+
# else:
|
80 |
+
# lc.letter = row_dict['Letter (1st level)']
|
81 |
+
#
|
82 |
+
# return lc
|
83 |
+
|
84 |
+
def set_text(self, text):
|
85 |
+
self.text = text
|
86 |
+
|
87 |
+
def __eq__(self, other):
|
88 |
+
|
89 |
+
if not (self.law_type == other.law_type and self.law_year == other.law_year
|
90 |
+
and self.law_number == other.law_number):
|
91 |
+
return False
|
92 |
+
|
93 |
+
if self.component_type != other.component_type:
|
94 |
+
return False
|
95 |
+
|
96 |
+
# if self.component_type == 'chapter':
|
97 |
+
# if self.chapter != other.chapter:
|
98 |
+
# return False
|
99 |
+
if self.article is None and other.article is None:
|
100 |
+
if self.chapter != other.chapter:
|
101 |
+
return False
|
102 |
+
else:
|
103 |
+
if self.article != other.article:
|
104 |
+
return False
|
105 |
+
|
106 |
+
if self.component_type == 'article':
|
107 |
+
return True
|
108 |
+
|
109 |
+
if self.subsection != other.subsection:
|
110 |
+
return False
|
111 |
+
|
112 |
+
if self.component_type == 'subsection':
|
113 |
+
return True
|
114 |
+
|
115 |
+
if self.letter != other.letter:
|
116 |
+
return False
|
117 |
+
|
118 |
+
if self.component_type == 'letter':
|
119 |
+
return True
|
120 |
+
|
121 |
+
return True
|
122 |
+
|
123 |
+
def is_article_equal(self, other):
|
124 |
+
|
125 |
+
if not (self.law_type == other.law_type and self.law_year == other.law_year
|
126 |
+
and self.law_number == other.law_number):
|
127 |
+
return False
|
128 |
+
|
129 |
+
if self.component_type == 'chapter':
|
130 |
+
return False
|
131 |
+
else:
|
132 |
+
if self.article != other.article:
|
133 |
+
return False
|
134 |
+
|
135 |
+
return True
|
136 |
+
|
137 |
+
def __repr__(self):
|
138 |
+
return "LawComponent({}, {}, {}, {}, {}, {}, {}, {})".format(
|
139 |
+
self.law_type,
|
140 |
+
self.law_number,
|
141 |
+
self.law_year,
|
142 |
+
self.component_type,
|
143 |
+
self.chapter,
|
144 |
+
self.article,
|
145 |
+
self.subsection,
|
146 |
+
self.letter
|
147 |
+
)
|
148 |
+
|
149 |
+
def __str__(self):
|
150 |
+
return self.__repr__()
|
151 |
+
|
152 |
+
def copy(self):
|
153 |
+
return copy.deepcopy(self)
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
gradio~=3.28.1
|
2 |
+
numpy==1.21.4
|
3 |
+
requests==2.26.0
|
reranker/reranker.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
from models.law_component import LawComponent
|
4 |
+
from sentence_transformers.cross_encoder import CrossEncoder
|
5 |
+
|
6 |
+
class CrossEncReranker:
|
7 |
+
|
8 |
+
def __init__(self, model_name, max_length=512):
|
9 |
+
self.model_name = model_name
|
10 |
+
self.reranker = CrossEncoder(self.model_name)
|
11 |
+
self.reranker.max_length = max_length
|
12 |
+
|
13 |
+
def rerank(self, query_text: str, candidates: list[LawComponent]):
|
14 |
+
sentence_combinations = [[query_text, c.text] for c in candidates]
|
15 |
+
similarity_scores = self.reranker.predict(sentence_combinations)
|
16 |
+
index = np.argsort(similarity_scores)[::-1]
|
17 |
+
|
18 |
+
reranked_candidates = np.array(candidates)[index]
|
19 |
+
return reranked_candidates
|
20 |
+
|
retriever/es_retriever.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
import json
|
4 |
+
import requests
|
5 |
+
from requests.auth import HTTPBasicAuth
|
6 |
+
from models.law_component import LawComponent
|
7 |
+
|
8 |
+
base_query = {
|
9 |
+
"query": {
|
10 |
+
"bool": {
|
11 |
+
"should": [
|
12 |
+
{
|
13 |
+
"match": {
|
14 |
+
"text": {
|
15 |
+
"query": None,
|
16 |
+
"boost": 1.0
|
17 |
+
}
|
18 |
+
}
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"match": {
|
22 |
+
"chapterTitle": {
|
23 |
+
"query": None,
|
24 |
+
"boost": 1.0
|
25 |
+
}
|
26 |
+
}
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"match_phrase": {
|
30 |
+
"text": {
|
31 |
+
"query": None,
|
32 |
+
"boost": 1.0
|
33 |
+
}
|
34 |
+
}
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"match_phrase": {
|
38 |
+
"chapterTitle": {
|
39 |
+
"query": None,
|
40 |
+
"boost": 1.0
|
41 |
+
}
|
42 |
+
}
|
43 |
+
}
|
44 |
+
],
|
45 |
+
"minimum_should_match": 1
|
46 |
+
}
|
47 |
+
}
|
48 |
+
}
|
49 |
+
|
50 |
+
class ESRetriever:
|
51 |
+
|
52 |
+
def __init__(self, es_host, es_index_name, es_username="", es_password=""):
|
53 |
+
self.es_host = es_host
|
54 |
+
self.es_index_name = es_index_name
|
55 |
+
self.es_username = es_username
|
56 |
+
self.es_password = es_password
|
57 |
+
|
58 |
+
if (es_username != "" and es_password != ""):
|
59 |
+
self.auth = HTTPBasicAuth(es_username, es_password)
|
60 |
+
else:
|
61 |
+
self.auth = None
|
62 |
+
|
63 |
+
# Returns LawComponent
|
64 |
+
def retrieve(self, query_text: str):
|
65 |
+
query = copy.deepcopy(base_query)
|
66 |
+
query['query']['bool']['should'][0]['match']['text']['query'] = query_text
|
67 |
+
query['query']['bool']['should'][1]['match']['chapterTitle']['query'] = query_text
|
68 |
+
query['query']['bool']['should'][2]['match_phrase']['text']['query'] = query_text
|
69 |
+
query['query']['bool']['should'][3]['match_phrase']['chapterTitle']['query'] = query_text
|
70 |
+
|
71 |
+
# try:
|
72 |
+
response = requests.get(
|
73 |
+
self.es_host + self.es_index_name + '/_search',
|
74 |
+
headers={'Content-Type': 'application/json'},
|
75 |
+
data=json.dumps(query),
|
76 |
+
auth=self.auth
|
77 |
+
)
|
78 |
+
|
79 |
+
if response.ok:
|
80 |
+
results = response.json()["hits"]["hits"]
|
81 |
+
retrieval_results = []
|
82 |
+
for result in results:
|
83 |
+
lc = LawComponent.from_uri(result["_source"]["uri"])
|
84 |
+
lc.set_text(result["_source"]["text"])
|
85 |
+
retrieval_results.append(lc)
|
86 |
+
return retrieval_results
|
87 |
+
|
88 |
+
#
|
89 |
+
# response.content
|
utils/preprocessing.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
|
4 |
+
def question_to_statement(question: str):
|
5 |
+
pattern = "(?P<qWord>kapan|(ber|meng|si)?apa|(di|ke) ?(mana)|bagaimana) ?(saja)? ?(kah)?"
|
6 |
+
|
7 |
+
result = re.sub(pattern, "", question.lower())
|
8 |
+
result = result.replace("?", "")
|
9 |
+
result = re.sub("\s+", " ", result)
|
10 |
+
result = result.strip()
|
11 |
+
return result
|