Commit 3648e12 · johaunh committed
1 Parent(s): 84e4786

Update pipeline

Files changed:
- README.md +14 -11
- chains.py +3 -3
- main.py +114 -48
- schema.yml +3 -3
- utils.py +31 -34
README.md CHANGED

@@ -1,6 +1,6 @@
 # Text2KG

-
+We introduce Text2KG, an intuitive, domain-independent tool that leverages the creative generative ability of GPT-3.5 in the KG construction process. Text2KG automates and accelerates the construction of KGs from unstructured plain text, reducing the need for traditionally-used human labor and computer resources. Our approach incorporates a novel, clause-based text simplification step, reducing the processing of even the most extensive corpora down to the order of minutes. With Text2KG, we aim to streamline the creation of databases from natural language, offering a robust, cost-effective, and user-friendly solution for KG construction.

 ## Usage

@@ -25,21 +25,24 @@ in the repository's directory.
 Import the primary pipeline method using

 ```python
->>> from main import
+>>> from main import extract_knowledge_graph
 ```

-**`
+**`extract_knowledge_graph` parameters**

 ```
 api_key (str)
     OpenAI API key
-
+
+batch_size (int)
     Number of sentences per forward pass
-
-
-
+
+modules (list)
+    Additional modules to add before main extraction process (triplet_extraction). Must be a valid name in schema.yml
+
 text (str)
     Input text to extract knowledge graph from
+
 progress
     Progress bar. The default is Gradio's progress bar;
     set `progress = tqdm` for implementations outside of Gradio
@@ -53,10 +56,10 @@ Read more [here](https://www.gradio.app/docs/python-client).

 ```
 chains.py
-    Converts schema.yml
+    Converts the items in schema.yml to LangChain modules

-
-    Contains packages required to run
+requirements.txt
+    Contains packages required to run Text2KG

 main.py
     Main pipeline/app code
@@ -65,7 +68,7 @@ README.md
     This file

 schema.yml
-    Contains definitions of prompts
+    Contains definitions of modules -- prompts + output parsers

 utils.py
     Contains helper functions
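To tie the documented parameters together, here is a minimal sketch of a call outside the Gradio app, assuming the repository modules are importable and a valid OpenAI key is available; the argument values are illustrative, and the `tqdm` module is passed as the progress handler as the updated README suggests.

```python
# Illustrative only: requires the repo's dependencies and a real OpenAI API key.
import tqdm

from main import extract_knowledge_graph

zip_path, kg = extract_knowledge_graph(
    api_key="sk-...",                   # placeholder OpenAI key
    batch_size=2,                       # sentences per forward pass
    modules=["Clause Deconstruction"],  # optional module; must map to a schema.yml entry
    text="All cells have a plasma membrane. The membrane separates the cell from its environment.",
    progress=tqdm,                      # the tqdm module, used in place of Gradio's progress bar
)

print(zip_path)   # path to a ZIP archive containing metadata.yml, sentences.txt, kg.txt
print(kg.head())  # DataFrame with sentence_id, subject, relation, object columns
```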
chains.py CHANGED

@@ -10,7 +10,7 @@ with open("./schema.yml") as f:
     schema = yaml.safe_load(f)


-class
+class ClauseParser(NumberedListOutputParser):

     def parse(self, text: str) -> str:
         axioms = super().parse(text=text)
@@ -34,13 +34,13 @@ class TripletParser(NumberedListOutputParser):
         return super().get_format_instructions()


-
+llm_chains = {}

 for scheme in schema:
     parser = schema[scheme]["parser"]
     prompts = schema[scheme]["prompts"]

-
+    llm_chains[scheme] = partial(
         LLMChain,
         output_parser=eval(f'{parser}()'),
         prompt=ChatPromptTemplate.from_messages(list(prompts.items()))
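For orientation, each top-level key in schema.yml ends up as an entry in `llm_chains`: a `functools.partial` over `LLMChain` that still lacks an `llm`, which `Text2KG.init` supplies later. The sketch below hand-builds one such entry under stated assumptions; the prompt text is a placeholder, and the `output_parser` that chains.py attaches via `eval(f'{parser}()')` is omitted to keep the example self-contained.

```python
# Rough equivalent of one llm_chains entry built by chains.py (illustrative values;
# the real loop also attaches output_parser=eval(f'{parser}()') from schema.yml).
from functools import partial

from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

prompts = {
    "system": "You are a sentence parsing agent helping to construct a knowledge graph.",
    "human": "{text}",
}

triplet_chain_factory = partial(
    LLMChain,
    prompt=ChatPromptTemplate.from_messages(list(prompts.items())),
)

# Text2KG.init later binds the model: llm_chains[step](llm=self.llm)
chain = triplet_chain_factory(llm=ChatOpenAI(openai_api_key="sk-...", temperature=0.3))
```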
main.py CHANGED

@@ -1,18 +1,40 @@
-import json
 import os
+import re
 import secrets
 import string
+import yaml
 from datetime import datetime
 from zipfile import ZipFile

 import gradio as gr
+import nltk
 import pandas as pd
+from langchain.embeddings import OpenAIEmbeddings
 from langchain.chains import SimpleSequentialChain
 from langchain.chat_models import ChatOpenAI
 from nltk.tokenize import sent_tokenize
+from pandas import DataFrame

 import utils
-from chains import
+from chains import llm_chains
+
+
+# download NLTK dependencies
+nltk.download("punkt")
+nltk.download("stopwords")
+
+# load stop words const.
+from nltk.corpus import stopwords
+STOP_WORDS = stopwords.words("english")
+
+# load global spacy model
+# try:
+#     SPACY_MODEL = spacy.load("en_core_web_sm")
+# except OSError:
+#     print("[spacy] Downloading model: en_core_web_sm")
+
+#     spacy.cli.download("en_core_web_sm")
+#     SPACY_MODEL = spacy.load("en_core_web_sm")


 class Text2KG:
@@ -20,10 +42,11 @@ class Text2KG:

     def __init__(self, api_key: str, **kwargs):

-        self.
+        self.llm = ChatOpenAI(openai_api_key=api_key, **kwargs)
+        self.embedding = OpenAIEmbeddings(openai_api_key=api_key)


-    def
+    def init(self, steps: list[str]):
         """Initialize Text2KG pipeline from passed steps.

         Args:
@@ -31,7 +54,7 @@ class Text2KG:
             the schema.yml file
         """
         self.pipeline = SimpleSequentialChain(
-            chains=[
+            chains=[llm_chains[step](llm=self.llm) for step in steps],
             verbose=False
         )

@@ -50,45 +73,92 @@ class Text2KG:
         return triplets


-    def
-
+    def clean(self, kg: DataFrame) -> DataFrame:
+        """Text2KG post-processing."""
+        drop_list = []
+
+        for i, row in kg.iterrows():
+            # drop stopwords (e.g. pronouns)
+            if (row.subject in STOP_WORDS) or (row.object in STOP_WORDS):
+                drop_list.append(i)
+
+            # drop broken triplets
+            elif row.hasnans:
+                drop_list.append(i)
+
+            # lowercase nodes/edges, drop articles
+            else:
+                article_pattern = r'^(the|a|an) (.+)'
+                be_pattern = r'^(are|is) (a )?(.+)'
+
+                kg.at[i, "subject"] = re.sub(article_pattern, r'\2', row.subject.lower())
+                kg.at[i, "relation"] = re.sub(be_pattern, r'\3', row.relation.lower())
+                kg.at[i, "object"] = re.sub(article_pattern, r'\2', row.object.lower())
+
+        return kg.drop(drop_list)
+
+
+    def normalize(self, kg: DataFrame, threshold: float=0.3) -> DataFrame:
+        """Reduce dimensionality of Text2KG output by merging cosine-similar nodes/edges."""
+
+        ents = pd.concat([kg["subject"], kg["object"]]).unique()
+        rels = kg["relation"].unique()
+
+        ent_map = utils.condense_labels(ents, self.embedding.embed_documents, threshold=threshold)
+        rel_map = utils.condense_labels(rels, self.embedding.embed_documents, threshold=threshold)
+
+        kg_normal = pd.DataFrame()
+
+        kg_normal["subject"] = kg["subject"].map(ent_map)
+        kg_normal["relation"] = kg["relation"].map(rel_map)
+        kg_normal["object"] = kg["object"].map(ent_map)
+
+        return kg_normal
+
+
+def extract_knowledge_graph(api_key: str, batch_size: int, modules: list[str], text: str, progress=gr.Progress()):
+    """Extract knowledge graph from text.

     Args:
         api_key (str): OpenAI API key
-
-
-            a pre-processing step. Doubles the amount of calls to ChatGPT
+        batch_size (int): Number of sentences per forward pass
+        modules (list): Additional modules to add before main extraction step
         text (str): Text from which Text2KG will extract knowledge graph from
         progress: Progress bar. The default is gradio's progress bar; for a
             command line progress bar, set `progress = tqdm`

     Returns:
-        knowledge_graph (DataFrame): The extracted knowledge graph
         zip_path (str): Path to ZIP archive containing outputs
+        knowledge_graph (DataFrame): The extracted knowledge graph
     """
     # init
     if api_key == "":
         raise ValueError("API key is required")

-
+    pipeline = Text2KG(api_key=api_key, temperature=0.3)  # low temp. -> low randomness

-
-
-
-        steps = ["extract_triplets"]
+    steps = []

-
+    for module in modules:
+        m = module.lower().replace(' ', '_')
+        steps.append(m)

-
+    if (len(steps) == 0) or (steps[-1] != "triplet_extraction"):
+        steps.append("triplet_extraction")
+
+    pipeline.init(steps)
+
+    # split text into batches
     sentences = sent_tokenize(text)
-
-
+    batches = [" ".join(sentences[i:i+batch_size])
+               for i in range(0, len(sentences), batch_size)]

     # create KG
     knowledge_graph = []

-    for i,
-
+    for i, batch in progress.tqdm(list(enumerate(batches)),
+                                  desc="Processing...", unit="batches"):
+        output = pipeline.run(batch)
         [triplet.update({"sentence_id": i}) for triplet in output]

     knowledge_graph.extend(output)
@@ -96,7 +166,7 @@ def create_knowledge_graph(api_key: str, ngram_size: int, axiomatize: bool, text

     # convert to df, post-process data
     knowledge_graph = pd.DataFrame(knowledge_graph)
-    knowledge_graph =
+    knowledge_graph = pipeline.clean(knowledge_graph)

     # rearrange columns
     knowledge_graph = knowledge_graph[["sentence_id", "subject", "relation", "object"]]
@@ -104,28 +174,29 @@ def create_knowledge_graph(api_key: str, ngram_size: int, axiomatize: bool, text
     # metadata
     now = datetime.now()
     date = str(now.date())
-    timestamp = now.strftime("%Y%m%d%H%M%S")

     metadata = {
-        "
-        "batch_size":
-        "
+        "_timestamp": now,
+        "batch_size": batch_size,
+        "modules": steps
     }

-    # unique identifier for saving
+    # unique identifier for local saving
     uid = ''.join(secrets.choice(string.ascii_letters)
                   for _ in range(6))

+    print(f"Run ID: {date}/{uid}")
+
     save_dir = os.path.join(".", "output", date, uid)
     os.makedirs(save_dir, exist_ok=True)


     # save metadata & data
-    with open(os.path.join(save_dir, "metadata.
-
+    with open(os.path.join(save_dir, "metadata.yml"), 'w') as f:
+        yaml.dump(metadata, f)

-
-
+    batches_df = pd.DataFrame(enumerate(batches), columns=["sentence_id", "text"])
+    batches_df.to_csv(os.path.join(save_dir, "sentences.txt"),
                       index=False)

     knowledge_graph.to_csv(os.path.join(save_dir, "kg.txt"),
@@ -137,37 +208,32 @@ def create_knowledge_graph(api_key: str, ngram_size: int, axiomatize: bool, text

     with ZipFile(zip_path, 'w') as zipObj:

-        zipObj.write(os.path.join(save_dir, "metadata.
+        zipObj.write(os.path.join(save_dir, "metadata.yml"))
         zipObj.write(os.path.join(save_dir, "sentences.txt"))
         zipObj.write(os.path.join(save_dir, "kg.txt"))

-    return
+    return zip_path, knowledge_graph


 class App:
     def __init__(self):
-        description = (
-            "# Text2KG\n\n"
-            "Text2KG is a framework that uses ChatGPT to automatically creates knowledge graphs from plain text.\n\n"
-            "**Usage:** (1) configure the pipeline; (2) add the text that will be processed"
-        )
         demo = gr.Interface(
-            fn=
-
+            fn=extract_knowledge_graph,
+            title="Text2KG",
             inputs=[
                 gr.Textbox(placeholder="API key...", label="OpenAI API Key", type="password"),
-                gr.Slider(minimum=1, maximum=10, step=1, label="Sentence Batch Size"
-                gr.
+                gr.Slider(minimum=1, maximum=10, step=1, label="Sentence Batch Size"),
+                gr.CheckboxGroup(choices=["Clause Deconstruction"], label="Optional Modules"),
                 gr.Textbox(lines=2, placeholder="Text Here...", label="Input Text"),
             ],
             outputs=[
-                gr.
+                gr.File(label="Knowledge Graph"),
+                gr.DataFrame(label="Preview",
                              headers=["sentence_id", "subject", "relation", "object"],
                              max_rows=10,
-                             overflow_row_behaviour="
-                gr.File(label="Knowledge Graph")
+                             overflow_row_behaviour="paginate")
             ],
-            examples=[[
+            examples=[[None, 1, [], ("All cells share four common components: "
                 "1) a plasma membrane, an outer covering that separates the "
                 "cell's interior from its surrounding environment; 2) cytoplasm, "
                 "consisting of a jelly-like cytosol within the cell in which "
@@ -182,7 +248,7 @@ class App:
             allow_flagging="never",
             cache_examples=False
         )
-        demo.launch(share=False)
+        demo.queue().launch(share=False)


 if __name__ == "__main__":
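As a quick illustration of what the new `Text2KG.clean` step does to raw triplets, the sketch below applies the same rules to a toy DataFrame: rows whose subject or object is a stop word, or that contain missing values, are dropped, and the remaining nodes/edges are lowercased with leading articles and copulas stripped. It re-implements the logic outside the class so it runs without the Gradio/LangChain dependencies; the stop-word set here is only a stand-in for NLTK's English list.

```python
# Standalone sketch of the cleaning rules in Text2KG.clean (toy data, toy stop-word list).
import re
import pandas as pd

STOP_WORDS = {"it", "they", "this"}  # stand-in for nltk.corpus.stopwords.words("english")

kg = pd.DataFrame([
    {"sentence_id": 0, "subject": "The cell", "relation": "has",      "object": "a plasma membrane"},
    {"sentence_id": 0, "subject": "it",       "relation": "contains", "object": "cytoplasm"},  # stop-word subject -> dropped
    {"sentence_id": 1, "subject": "Cytosol",  "relation": "is a",     "object": None},         # broken triplet -> dropped
])

article_pattern = r'^(the|a|an) (.+)'
be_pattern = r'^(are|is) (a )?(.+)'

drop_list = []
for i, row in kg.iterrows():
    if (row.subject in STOP_WORDS) or (row.object in STOP_WORDS):
        drop_list.append(i)
    elif row.hasnans:
        drop_list.append(i)
    else:
        kg.at[i, "subject"] = re.sub(article_pattern, r'\2', row.subject.lower())
        kg.at[i, "relation"] = re.sub(be_pattern, r'\3', row.relation.lower())
        kg.at[i, "object"] = re.sub(article_pattern, r'\2', row.object.lower())

print(kg.drop(drop_list))  # keeps a single row: cell | has | plasma membrane
```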
schema.yml CHANGED

@@ -1,5 +1,5 @@
-
-    parser:
+clause_deconstruction:
+    parser: ClauseParser
     prompts:
         system: |
             You are a sentence parsing agent helping to construct a knowledge graph.
@@ -18,7 +18,7 @@ text2axiom:
         human: |
             {text}

-
+triplet_extraction:
     parser: TripletParser
     prompts:
         system: |
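Note that the keys in this file are what the pipeline steps are looked up by: `extract_knowledge_graph` lower-cases the UI labels and swaps spaces for underscores before appending the mandatory `triplet_extraction` step, so a Gradio checkbox label has to line up with a key such as `clause_deconstruction`. A small sketch of that normalization (label value illustrative):

```python
# Sketch of the label -> schema-key normalization done in extract_knowledge_graph.
modules = ["Clause Deconstruction"]  # e.g. from the Gradio CheckboxGroup

steps = [m.lower().replace(' ', '_') for m in modules]

# the main extraction module always runs last
if not steps or steps[-1] != "triplet_extraction":
    steps.append("triplet_extraction")

print(steps)  # ['clause_deconstruction', 'triplet_extraction']
```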
utils.py CHANGED

@@ -1,34 +1,31 @@
-import
-    df.at[i, "object"] = re.sub(article_pattern, r'\2', row.object.lower())
-
-    return df.drop(drop_list)
+import numpy as np
+from typing import Callable
+
+from sklearn.cluster import AgglomerativeClustering
+
+
+def condense_labels(labels: np.ndarray, embedding_func: Callable, threshold: float=0.5):
+    """Combine cosine-similar labels under same name."""
+
+    embeddings = np.array(embedding_func(labels))
+
+    clustering = AgglomerativeClustering(
+        n_clusters=None,
+        distance_threshold=threshold
+    ).fit(embeddings)
+
+    clusters = [np.where(clustering.labels_ == l)[0]
+                for l in range(clustering.n_clusters_)]
+
+    clusters_reduced = []
+
+    for c in clusters:
+        embs = embeddings[c]
+        centroid = np.mean(embs)
+
+        idx = c[np.argmin(np.linalg.norm(embs - centroid, axis=1))]
+        clusters_reduced.append(idx)
+
+    old2new = {old_id: new_id for old_ids, new_id in zip(clusters, clusters_reduced) for old_id in old_ids}
+
+    return {labels[i]: labels[j] for i, j in old2new.items()}
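A hedged usage sketch for the new `condense_labels` helper: it accepts an array of labels and any embedding function that maps a list of strings to vectors (in the app this is `OpenAIEmbeddings.embed_documents`), clusters the embeddings, and returns a dict mapping every label to the representative label of its cluster. The toy embedding and threshold below are purely illustrative, and the snippet assumes the repo's utils.py is importable.

```python
# Standalone illustration of utils.condense_labels; the toy "embedding" stands in
# for OpenAIEmbeddings.embed_documents, which returns one vector per input string.
import numpy as np

from utils import condense_labels

def toy_embedding(labels):
    # one 2-d vector per label: [character count, vowel count]
    return [[len(l), sum(c in "aeiou" for c in l)] for l in labels]

labels = np.array(["cell", "cells", "plasma membrane"])
mapping = condense_labels(labels, toy_embedding, threshold=2.0)

print(mapping)  # e.g. {'cell': 'cell', 'cells': 'cell', 'plasma membrane': 'plasma membrane'}
```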