import whisper |
import os |
import random |
import openai |
import yt_dlp |
from pytube import YouTube, extract |
import pandas as pd |
import plotly_express as px |
import nltk |
import plotly.graph_objects as go |
from optimum.onnxruntime import ORTModelForSequenceClassification |
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForSeq2SeqLM |
from sentence_transformers import SentenceTransformer, CrossEncoder, util |
import streamlit as st |
import en_core_web_lg |
import validators |
import re |
import itertools |
import numpy as np |
from bs4 import BeautifulSoup |
import base64, time |
from annotated_text import annotated_text |
import pickle, math |
import wikipedia |
from pyvis.network import Network |
import torch |
from pydub import AudioSegment |
from langchain.docstore.document import Document |
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings, HuggingFaceInstructEmbeddings |
from langchain.vectorstores import FAISS |
from langchain.text_splitter import RecursiveCharacterTextSplitter |
from langchain.chat_models import ChatOpenAI |
from langchain.chains import QAGenerationChain |
from langchain.callbacks import StreamlitCallbackHandler |
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor |
from langchain.agents.agent_toolkits import create_retriever_tool |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import ( |
AgentTokenBufferMemory, |
) |
from langchain.prompts.chat import ( |
ChatPromptTemplate, |
SystemMessagePromptTemplate, |
AIMessagePromptTemplate, |
HumanMessagePromptTemplate, |
) |
from langchain.schema import ( |
AIMessage, |
HumanMessage, |
SystemMessage |
) |
from langchain.prompts import PromptTemplate |
nltk.download('punkt') |
from nltk import sent_tokenize |
# OpenAI key read from the environment; None when OPEN_AI_KEY is unset.
OPEN_AI_KEY = os.getenv('OPEN_AI_KEY')
# Timestamp taken when this module executes; used to name downloaded files.
time_str = time.strftime("%d%m%Y-%H%M%S")
# Bordered, horizontally scrollable container; `{}` receives the rendered HTML.
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
margin-bottom: 2.5rem">{}</div> """
@st.cache_resource
def load_models():
    '''Load every model/pipeline used by the app once and keep it cached.

    Returns, in order: sentiment pipeline, summarization pipeline, NER pipeline,
    cross-encoder, REBEL seq2seq model, REBEL tokenizer, flan-t5-xl tokenizer
    and the SBERT sentence encoder.
    '''
    finbert_repo = "nickmuchi/quantized-optimum-finbert-tone"
    ner_repo = "xlm-roberta-large-finetuned-conll03-english"
    rebel_repo = "Babelscape/rebel-large"

    finbert_model = ORTModelForSequenceClassification.from_pretrained(finbert_repo)
    xlm_ner_model = AutoModelForTokenClassification.from_pretrained(ner_repo)
    rebel_model = AutoModelForSeq2SeqLM.from_pretrained(rebel_repo)
    rebel_tokenizer = AutoTokenizer.from_pretrained(rebel_repo)
    finbert_tokenizer = AutoTokenizer.from_pretrained(finbert_repo)
    xlm_ner_tokenizer = AutoTokenizer.from_pretrained(ner_repo)
    flan_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-xl')

    sentiment = pipeline("text-classification", model=finbert_model, tokenizer=finbert_tokenizer)
    summarizer = pipeline("summarization", model="philschmid/flan-t5-base-samsum",
                          clean_up_tokenization_spaces=True)
    ner = pipeline("ner", model=xlm_ner_model, tokenizer=xlm_ner_tokenizer, grouped_entities=True)
    reranker = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1')
    sentence_encoder = SentenceTransformer('all-MiniLM-L6-v2')

    return (sentiment, summarizer, ner, reranker,
            rebel_model, rebel_tokenizer, flan_tokenizer, sentence_encoder)
@st.cache_resource
def get_spacy():
    '''Return the cached large English spaCy pipeline (en_core_web_lg).'''
    return en_core_web_lg.load()
# Module-level handles to the cached models; the cache decorators make
# repeated script runs reuse the same objects.
nlp = get_spacy()
(sent_pipe, sum_pipe, ner_pipe, cross_encoder,
 kg_model, kg_tokenizer, emb_tokenizer, sbert) = load_models()
@st.cache_data
def get_yt_audio(url):
    '''Download a YouTube video and return ``(local_file_path, video_title)``.

    Picks the highest-resolution progressive mp4 stream (pytube "progressive"
    streams carry both audio and video) and downloads it to the working
    directory.

    Raises:
        ValueError: if the video exposes no progressive mp4 stream; previously
            this surfaced as an opaque ``AttributeError`` on ``None``.
    '''
    yt = YouTube(url)
    title = yt.title
    stream = (
        yt.streams
        .filter(progressive=True, file_extension='mp4')
        .order_by('resolution')
        .desc()
        .first()
    )
    if stream is None:
        raise ValueError(f"No progressive mp4 stream available for {url}")
    audio_stream = stream.download()
    return audio_stream, title
@st.cache_data
def load_whisper_api(audio):
    '''Send the audio file at path ``audio`` to OpenAI's hosted "whisper-1" model.

    Note: this calls ``openai.Audio.translate`` (the translation endpoint, which
    outputs English), not ``transcribe``. Callers read the result's ``['text']``.
    The file handle is now closed after the request; previously it leaked.
    '''
    with open(audio, "rb") as file:
        transcript = openai.Audio.translate("whisper-1", file)
    return transcript
@st.cache_resource
def load_asr_model(model_name):
    '''Load the open-source Whisper checkpoint ``model_name``.

    Used as a fallback when the Whisper API is not working. This is cached with
    ``st.cache_resource``, like ``load_models``, so the model object is built once
    and shared. ``st.cache_data`` would pickle and copy the whole model on every
    cache hit.
    '''
    return whisper.load_model(model_name)
def _transcribe_in_chunks(audio_path, chunk_suffix='', audio_format=None):
    '''Split audio into 20-minute pieces, transcribe each via the Whisper API and join with ",".'''
    # The export target directory may not exist on a fresh deployment.
    os.makedirs('output', exist_ok=True)
    song = AudioSegment.from_file(audio_path, format=audio_format)
    twenty_minutes = 20 * 60 * 1000  # pydub slices are in milliseconds
    transcriptions = []
    for i, chunk in enumerate(song[::twenty_minutes]):
        chunk_path = f'output/chunk_{i}{chunk_suffix}.mp4'
        chunk.export(chunk_path, format='mp4')
        transcriptions.append(load_whisper_api(chunk_path)['text'])
    return ','.join(transcriptions)


@st.cache_data
def inference(link, upload, _asr_model):
    '''Convert a YouTube video or an uploaded audio file to text.

    Args:
        link: URL string. It is used when ``validators.url(link)`` accepts it.
        upload: filesystem path of an uploaded audio file, passed to
            ``os.path.getsize`` / ``AudioSegment.from_file``. It is used only
            when ``link`` is not a valid URL.
        _asr_model: local Whisper model, used as a fallback if the API path
            raises. The leading underscore keeps it out of Streamlit's cache key.

    Returns:
        ``(text, title)``. Returns None when neither input is usable.

    Files up to 25 MB go to the API whole. Larger files are split into
    20-minute chunks under ``output/``. If the API path fails after the audio
    is available locally, the local Whisper model transcribes that same file.
    If no audio could be obtained at all, the original error is re-raised.
    '''
    audio_path = None
    title = None
    try:
        if validators.url(link):
            st.info("`Downloading YT audio...`")
            audio_path, title = get_yt_audio(link)
            print(f'audio_file:{audio_path}')
            st.session_state['audio'] = audio_path
            print(f"audio_file_session_state:{st.session_state['audio'] }")
            audio_size = round(os.path.getsize(audio_path)/(1024*1024),1)
            if audio_size <= 25:
                st.info("`Transcribing YT audio...`")
                results = load_whisper_api(audio_path)['text']
            else:
                st.warning('File size larger than 24mb, applying chunking and transcription',icon="⚠️")
                video_id = extract.video_id(link)
                results = _transcribe_in_chunks(audio_path, chunk_suffix=f'_{video_id}', audio_format='mp4')
            st.info("`YT Video transcription process complete...`")
            return results, title
        elif upload:
            audio_path, title = upload, "Transcribed Earnings Audio"
            audio_size = round(os.path.getsize(upload)/(1024*1024),1)
            if audio_size <= 25:
                st.info("`Transcribing uploaded audio...`")
                results = load_whisper_api(upload)['text']
            else:
                st.write('File size larger than 24mb, applying chunking and transcription')
                st.info("`Transcribing uploaded audio...`")
                results = _transcribe_in_chunks(upload)
            st.info("`Uploaded audio transcription process complete...`")
            return results, title
    except Exception as e:
        if audio_path is None:
            # Nothing was downloaded/uploaded, so there is nothing to fall back on.
            raise
        st.error(f'''Whisper API Error: {e},
Using Whisper module from GitHub, might take longer than expected''',icon="🚨")
        results = _asr_model.transcribe(audio_path, task='transcribe', language='en')
        return results['text'], title
@st.cache_data
def clean_text(text):
    '''Normalize transcript text.

    Drops non-ASCII characters, blanks out URL-like tokens, @mentions and
    #hashtags, then squeezes any run of 2+ whitespace characters to one space.
    '''
    cleaned = text.encode("ascii", "ignore").decode()
    for noise_pattern in (r"https*\S+", r"@\S+", r"#\S+"):
        cleaned = re.sub(noise_pattern, " ", cleaned)
    return re.sub(r"\s{2,}", " ", cleaned)
@st.cache_data
def chunk_long_text(text,threshold,window_size=3,stride=2):
    '''Preprocess text and chunk it into passages for sentiment analysis.

    1. Split ``text`` into sentences.
    2. Keep each sentence shorter than ``threshold`` words as-is. Break longer
       sentences into consecutive pieces of ``threshold`` words.
    3. Join a sliding window of ``window_size`` pieces, advancing ``stride``
       pieces each step.

    Returns the list of passage strings.
    '''
    pieces = []
    for sentence in sent_tokenize(text):
        words = sentence.split()
        if len(words) < threshold:
            pieces.append(sentence)
        else:
            # Stop at len(words). The old upper bound num*threshold+1 emitted
            # an empty piece whenever the word count was an exact multiple of
            # threshold.
            pieces.extend(' '.join(words[i:i + threshold])
                          for i in range(0, len(words), threshold))
    passages = []
    for start_idx in range(0, len(pieces), stride):
        end_idx = min(start_idx + window_size, len(pieces))
        passages.append(" ".join(pieces[start_idx:end_idx]))
    return passages
@st.cache_data
def sentiment_pipe(earnings_text):
    '''Classify the sentiment of ``earnings_text``.

    The text is first split into single-sentence passages (150-word cap, window
    1, stride 1) via ``chunk_long_text``. Returns ``(predictions, passages)``,
    where ``predictions`` is the raw output of ``sent_pipe`` on ``passages``.
    '''
    passages = chunk_long_text(earnings_text, 150, 1, 1)
    predictions = sent_pipe(passages)
    return predictions, passages
@st.cache_data
def chunk_and_preprocess_text(text, model_name= 'philschmid/flan-t5-base-samsum'):
    '''Greedily pack whole sentences into chunks that fit the summarizer's input limit.

    A sentence is added to the current chunk while the running token count
    (per ``model_name``'s tokenizer) stays within
    ``tokenizer.max_len_single_sentence``. Otherwise a new chunk is started.
    A single sentence longer than the limit becomes its own chunk.

    Two fixes over the previous version:
    - The final chunk is always emitted. It used to be dropped when the last
      sentence overflowed.
    - No empty chunk is emitted when the first sentence overflows.

    Returns a list of chunk strings, each ending with a space.
    '''
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    max_tokens = tokenizer.max_len_single_sentence
    chunks = []
    chunk = ""
    length = 0
    for sentence in sent_tokenize(text):
        sentence_length = len(tokenizer.tokenize(sentence))
        if length + sentence_length <= max_tokens:
            chunk += sentence + " "
            length += sentence_length
        else:
            if chunk:
                chunks.append(chunk)
            chunk = sentence + " "
            length = sentence_length
    if chunk:
        chunks.append(chunk)
    return chunks
@st.cache_data
def summarize_text(text_to_summarize,max_len,min_len):
    '''Summarize with the cached HF summarization pipeline.

    Uses deterministic 4-beam search with early stopping. Returns every
    produced ``summary_text`` joined by single spaces.
    '''
    generation_args = dict(
        max_length=max_len,
        min_length=min_len,
        do_sample=False,
        early_stopping=True,
        num_beams=4,
    )
    outputs = sum_pipe(text_to_summarize, **generation_args)
    return ' '.join(item['summary_text'] for item in outputs)
@st.cache_data
def get_all_entities_per_sentence(text):
    '''Return one list of entity strings per spaCy sentence.

    Each list holds the spaCy entities first, followed by the ``word`` of every
    hit from the XLM-R NER pipeline. ``text`` is concatenated with ``''.join``,
    so a plain string passes through unchanged.
    '''
    doc = nlp(''.join(text))
    per_sentence = []
    for sent in doc.sents:
        spacy_entities = [str(ent) for ent in sent.ents]
        xlm_entities = [str(hit["word"]) for hit in ner_pipe(str(sent))]
        per_sentence.append(spacy_entities + xlm_entities)
    return per_sentence
@st.cache_data
def get_all_entities(text):
    '''Flatten the per-sentence entities of ``text`` into a single list, preserving order.'''
    return [entity
            for sentence_entities in get_all_entities_per_sentence(text)
            for entity in sentence_entities]
@st.cache_data
def get_and_compare_entities(article_content,summary_output):
    '''Split the summary's entities into those supported by the article and those not.

    A summary entity counts as matched if either condition holds:
    - it is a case-insensitive substring of some article entity, or
    - its SBERT embedding has an inner product > 0.9 with some article
      entity's embedding.

    Both result lists are de-duplicated (first occurrence kept). Any entity
    that is a case-insensitive substring of a different entity in the same
    list is then dropped.

    Each article entity is now encoded at most once. Previously every article
    entity was re-encoded for every summary entity; the results are unchanged.

    Returns ``(matched_entities, unmatched_entities)``.
    '''
    entities_article = list(itertools.chain.from_iterable(
        get_all_entities_per_sentence(article_content)))
    entities_summary = list(itertools.chain.from_iterable(
        get_all_entities_per_sentence(summary_output)))
    article_lower = [art_entity.lower() for art_entity in entities_article]
    article_embeddings = {}  # article entity text -> SBERT embedding, filled lazily

    def _article_embedding(art_entity):
        '''Return (and memoize) the SBERT embedding of one article entity.'''
        if art_entity not in article_embeddings:
            article_embeddings[art_entity] = sbert.encode(art_entity, show_progress_bar=False)
        return article_embeddings[art_entity]

    def _has_semantic_match(entity):
        '''True if ``entity`` is embedding-similar (> 0.9) to any article entity.'''
        if not entities_article:
            return False
        entity_embedding = sbert.encode(entity, show_progress_bar=False)
        return any(np.inner(entity_embedding, _article_embedding(art_entity)) > 0.9
                   for art_entity in entities_article)

    def _drop_contained(entities):
        '''Remove entities that are a case-insensitive substring of another, different entity.'''
        return [entity for entity in entities
                if not any(entity != other and entity.lower() in other.lower()
                           for other in entities)]

    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        entity_lower = entity.lower()
        if any(entity_lower in art for art in article_lower):
            matched_entities.append(entity)
        elif _has_semantic_match(entity):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)

    matched_entities = _drop_contained(list(dict.fromkeys(matched_entities)))
    unmatched_entities = _drop_contained(list(dict.fromkeys(unmatched_entities)))
    return matched_entities, unmatched_entities
@st.cache_data
def highlight_entities(article_content,summary_output):
    '''Wrap summary entities in coloured ``<mark>`` tags and return the wrapped HTML.

    Entities found in the article (per ``get_and_compare_entities``) are
    highlighted green; the rest are highlighted red.

    Entities are now passed through ``re.escape``. Previously names containing
    regex metacharacters ("(", "+", ".") raised ``re.error`` or matched the
    wrong text. A function replacement keeps backslashes in entity text
    literal. The negative lookahead is meant to skip text inside an
    already-inserted ``rgb(...)`` style value.
    '''
    markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
    markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
    markdown_end = "</mark>"
    matched_entities, unmatched_entities = get_and_compare_entities(article_content,summary_output)
    for entities, markdown_start in ((matched_entities, markdown_start_green),
                                     (unmatched_entities, markdown_start_red)):
        for entity in entities:
            pattern = rf'({re.escape(entity)})(?![^rgb\(]*\))'
            summary_output = re.sub(
                pattern,
                lambda match: markdown_start + match.group(1) + markdown_end,
                summary_output,
            )
    soup = BeautifulSoup(summary_output, features="html.parser")
    return HTML_WRAPPER.format(soup)
def summary_downloader(raw_text):
    '''Render a heading and an HTML link that downloads ``raw_text`` as a .txt file.

    The text is embedded in the link as a base64 data URI. The file name
    includes the module-level ``time_str`` timestamp.
    '''
    encoded = base64.b64encode(raw_text.encode()).decode()
    filename = f"new_text_file_{time_str}_.txt"
    link = f'<a href="data:file/txt;base64,{encoded}" download="{filename}">Click to Download!!</a>'
    st.markdown("#### Download Summary as a File ###")
    st.markdown(link, unsafe_allow_html=True)
@st.cache_data
def generate_eval(raw_text, N, chunk):
    '''Generate sample question/answer pairs from random excerpts of ``raw_text``.

    Draws ``N`` random substrings of up to ``chunk`` characters and runs
    ``QAGenerationChain`` on each. Returns the concatenation of all
    successfully generated QA lists. Excerpts whose generation fails are
    skipped with a warning.

    The start index is now clamped at 0. Previously ``random.randint`` raised
    ValueError when ``raw_text`` was shorter than ``chunk``.
    '''
    update = st.empty()
    ques_update = st.empty()
    update.info("`Generating sample questions ...`")
    n = len(raw_text)
    max_start = max(n - chunk, 0)
    starting_indices = [random.randint(0, max_start) for _ in range(N)]
    sub_sequences = [raw_text[i:i+chunk] for i in starting_indices]
    chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
    eval_set = []
    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
            ques_update.info(f"Creating Question: {i+1}")
        except Exception as e:
            # Best effort: report and continue with the remaining excerpts.
            print(e)
            st.warning(f'Error in generating Question: {i+1}...', icon="⚠️")
            continue
    eval_set_full = list(itertools.chain.from_iterable(eval_set))
    update.empty()
    ques_update.empty()
    return eval_set_full
@st.cache_resource
def create_prompt_and_llm():
    '''Build the OpenAI-functions agent prompt and its streaming GPT-4 chat model.

    The prompt is the system message followed by a ``history`` placeholder
    for conversation memory.

    Fixes over the previous version:
    - ``MessagesPlaceholder`` is now imported. It was used without being
      imported, which raised NameError.
    - The system-message fragments now have the separating spaces they
      previously lacked (e.g. "about that.Do not use").

    Returns ``(prompt, llm)``.
    '''
    from langchain.prompts.chat import MessagesPlaceholder

    llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4")
    message = SystemMessage(
        content=(
            "You are a helpful chatbot who is tasked with answering questions accurately about the earnings call transcript provided. "
            "Unless otherwise explicitly stated, it is probably fair to assume that questions are about the earnings call transcript. "
            "If there is any ambiguity, you probably assume they are about that. "
            "Do not use any information not provided in the earnings context and remember you are to speak like a finance expert. "
            "If you don't know the answer, just say 'There is no relevant answer in the given earnings call transcript'; "
            "don't try to make up an answer."
        )
    )
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )
    return prompt, llm
@st.cache_resource
def gen_embeddings(embedding_model):
    '''Return a LangChain embeddings object for ``embedding_model``.

    The wrapper is chosen by substring of the model name:
    - 'hkunlp': instructor embeddings with finance-specific instructions.
    - 'mpnet': plain HuggingFace sentence-transformer embeddings.
    - 'FlagEmbedding': BGE embeddings with normalized vectors.

    Raises:
        ValueError: for any other model name. Previously this surfaced as
            UnboundLocalError.
    '''
    if 'hkunlp' in embedding_model:
        return HuggingFaceInstructEmbeddings(
            model_name=embedding_model,
            query_instruction='Represent the Financial question for retrieving supporting paragraphs: ',
            embed_instruction='Represent the Financial paragraph for retrieval: ',
        )
    if 'mpnet' in embedding_model:
        return HuggingFaceEmbeddings(model_name=embedding_model)
    if 'FlagEmbedding' in embedding_model:
        encode_kwargs = {'normalize_embeddings': True}
        return HuggingFaceBgeEmbeddings(model_name=embedding_model,
                                        encode_kwargs=encode_kwargs)
    raise ValueError(
        f"Unsupported embedding model {embedding_model!r}: expected a name containing "
        "'hkunlp', 'mpnet' or 'FlagEmbedding'"
    )
@st.cache_data
def create_vectorstore(corpus, title, embedding_model, chunk_size=1000, overlap=50):
    '''Index ``corpus`` in FAISS for semantic search.

    The corpus is split into overlapping character chunks. Each chunk gets
    ``{"source": <chunk index>}`` metadata. ``title`` is not used in the body;
    it only contributes to Streamlit's cache key.
    '''
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
    pieces = splitter.split_text(corpus)
    embedder = gen_embeddings(embedding_model)
    metadata = [{"source": idx} for idx, _ in enumerate(pieces)]
    return FAISS.from_texts(pieces, embedder, metadatas=metadata)
@st.cache_data
def create_memory_and_agent(query,_docsearch):
    '''Wire a retrieval tool over ``_docsearch`` into an OpenAI-functions agent.

    The retriever returns the top 4 documents. ``query`` is not used in the
    body; it only contributes to Streamlit's cache key. Returns
    ``(memory, agent_executor)``, where memory is a token-buffer memory bound
    to the same LLM.
    '''
    retriever = _docsearch.as_retriever(search_kwargs={"k": 4})
    search_tool = create_retriever_tool(
        retriever,
        "earnings_call_search",
        "Searches and returns documents using the earnings context provided as a source, relevant to the user input question.",
    )
    toolbox = [search_tool]
    prompt, llm = create_prompt_and_llm()
    executor = AgentExecutor(
        agent=OpenAIFunctionsAgent(llm=llm, tools=toolbox, prompt=prompt),
        tools=toolbox,
        verbose=True,
        return_intermediate_steps=True,
    )
    return AgentTokenBufferMemory(llm=llm), executor
@st.cache_data
def gen_sentiment(text):
    '''Return the label of the first prediction the sentiment pipeline produces for ``text``.'''
    predictions = sent_pipe(text)
    top_prediction = predictions[0]
    return top_prediction['label']
@st.cache_data
def gen_annotated_text(df):
    '''Turn a (text, label) DataFrame into ``(text, label, colour)`` tuples for annotated_text.

    The first column holds the text and the second the label. 'Positive' maps
    to green, 'Negative' to red and any other label to black.
    '''
    label_colours = {'Positive': '#8fce00', 'Negative': '#f44336'}
    return [
        (row[1], row[2], label_colours.get(row[2], '#000000'))
        for row in df.itertuples()
    ]
def display_df_as_table(model,top_k,score='score'):
    '''Tabulate the first ``top_k`` search hits as a Score/Text DataFrame.

    Each hit is a mapping with a ``score`` key (name configurable) and a
    ``corpus_id`` index into the module-level ``passages`` list. Scores are
    rounded to 2 decimals.
    '''
    rows = [(hit[score], passages[hit['corpus_id']]) for hit in model[:top_k]]
    table = pd.DataFrame(rows, columns=['Score','Text'])
    table['Score'] = table['Score'].round(2)
    return table
def make_spans(text,results):
    '''Pair each sentence of ``text`` with the ``label`` of the corresponding result.

    ``zip`` truncates to the shorter of the two sequences.
    '''
    labels = [result['label'] for result in results]
    # NOTE(review): `sent_tokenizer` is not defined in the visible source (only
    # nltk's `sent_tokenize` is imported); confirm it exists elsewhere.
    return list(zip(sent_tokenizer(text), labels))
def fin_ext(text):
    '''Classify each sentence of ``text`` with ``remote_clx`` and return (sentence, label) pairs.'''
    # NOTE(review): `remote_clx` and `sent_tokenizer` are not defined in the
    # visible source; confirm they are provided elsewhere.
    sentences = sent_tokenizer(text)
    predictions = remote_clx(sentences)
    return make_spans(text, predictions)
@st.cache_data
def extract_relations_from_model_output(text):
    '''Parse REBEL's linearized output into a list of {'head', 'type', 'tail'} dicts.

    The decoded sequence has the form
    ``<triplet> head <subj> tail <obj> relation ...``. Words after
    ``<triplet>`` build the head, words after ``<subj>`` build the tail and
    words after ``<obj>`` build the relation type. A triplet is emitted
    whenever a new ``<triplet>``/``<subj>`` marker follows a non-empty
    relation, plus once at the end if head, relation and tail are all
    non-empty. ``<s>``, ``</s>`` and ``<pad>`` tokens are stripped first.
    '''
    triplets = []
    head, rel, tail = '', '', ''
    mode = 'x'  # 'x' = before first marker, 't' = head, 's' = tail, 'o' = relation
    cleaned = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for tok in cleaned.split():
        if tok == "<triplet>":
            mode = 't'
            if rel != '':
                triplets.append({'head': head.strip(), 'type': rel.strip(), 'tail': tail.strip()})
                rel = ''
            head = ''
        elif tok == "<subj>":
            mode = 's'
            if rel != '':
                triplets.append({'head': head.strip(), 'type': rel.strip(), 'tail': tail.strip()})
            tail = ''
        elif tok == "<obj>":
            mode = 'o'
            rel = ''
        elif mode == 't':
            head += ' ' + tok
        elif mode == 's':
            tail += ' ' + tok
        elif mode == 'o':
            rel += ' ' + tok
    if head != '' and rel != '' and tail != '':
        triplets.append({'head': head.strip(), 'type': rel.strip(), 'tail': tail.strip()})
    return triplets
def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
                    article_publish_date=None, verbose=False):
    '''Extract REBEL relations from ``text`` and collect them in a new KB.

    The tokenized text is cut into ``ceil(num_tokens / span_length)`` windows
    of ``span_length`` tokens. The windows are shifted left by an even overlap
    so the last one ends near the final token. Each window is decoded with
    3-beam search returning 3 sequences. Every relation parsed from a sequence
    is tagged with ``{article_url: {"spans": [window]}}`` and added to the KB.
    '''
    encoded = tokenizer([text], return_tensors="pt")
    num_tokens = len(encoded["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    overlap = math.ceil((num_spans * span_length - num_tokens) /
                        max(num_spans - 1, 1))
    # Window i starts at i*span_length and is then pulled back by i*overlap.
    spans_boundaries = [
        [span_length * i - overlap * i, span_length * (i + 1) - overlap * i]
        for i in range(num_spans)
    ]
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")
    batch = {
        key: torch.stack([encoded[key][0][lo:hi] for lo, hi in spans_boundaries])
        for key in ("input_ids", "attention_mask")
    }
    num_return_sequences = 3
    generated_tokens = model.generate(
        **batch,
        max_length=256,
        length_penalty=0,
        num_beams=3,
        num_return_sequences=num_return_sequences,
    )
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)
    kb = KB()
    for pred_idx, sentence_pred in enumerate(decoded_preds):
        for relation in extract_relations_from_model_output(sentence_pred):
            # Sequences are grouped per window, num_return_sequences at a time.
            window = spans_boundaries[pred_idx // num_return_sequences]
            relation["meta"] = {article_url: {"spans": [window]}}
            kb.add_relation(relation, article_title, article_publish_date)
    return kb
def get_article(url):
    '''Fetch and parse the article at ``url``, returning the ``Article`` object.

    NOTE(review): ``Article`` is not imported in the visible source; it is
    presumably newspaper3k's class. Confirm.
    '''
    article = Article(url)
    for step in (article.download, article.parse):
        step()
    return article
def from_url_to_kb(url, model, tokenizer):
    '''Build a KB from the article at ``url``, tagging relations with its title and publish date.'''
    article = get_article(url)
    title, published = article.title, article.publish_date
    return from_text_to_kb(
        article.text, model, tokenizer, article.url,
        article_title=title,
        article_publish_date=published,
    )
def get_news_links(query, lang="en", region="US", pages=1):
    '''Search Google News for ``query`` and return the unique links from the first ``pages`` pages.

    The order of the returned list is arbitrary (it comes from a set).
    NOTE(review): ``GoogleNews`` is not imported in the visible source.
    '''
    client = GoogleNews(lang=lang, region=region)
    client.search(query)
    links = set()
    for page in range(pages):
        client.get_page(page)
        links.update(client.get_links())
    return list(links)
def from_urls_to_kb(urls, model, tokenizer, verbose=False):
    '''Merge the KBs built from each URL into one KB.

    URLs whose article download raises ``ArticleException`` are skipped.
    '''
    combined = KB()
    if verbose:
        print(f"{len(urls)} links to visit")
    for url in urls:
        if verbose:
            print(f"Visiting {url}...")
        try:
            combined.merge_with_kb(from_url_to_kb(url, model, tokenizer))
        except ArticleException:
            if verbose:
                print(f" Couldn't download article at url {url}")
    return combined
def save_network_html(kb, filename="network.html"):
    '''Render ``kb`` as a directed pyvis graph and show/write it to ``filename``.

    Entities become green circular nodes; relations become edges labelled
    with their type.

    NOTE(review): a second ``save_network_html`` (grey background) is defined
    later in this module and rebinds this name, so this version is
    effectively unused.
    '''
    graph = Network(directed=True, width="700px", height="700px")
    node_colour = "#00FF00"
    for entity in kb.entities:
        graph.add_node(entity, shape="circle", color=node_colour)
    for rel in kb.relations:
        graph.add_edge(rel["head"], rel["tail"], title=rel["type"], label=rel["type"])
    graph.repulsion(node_distance=200, central_gravity=0.2, spring_length=200,
                    spring_strength=0.05, damping=0.09)
    graph.set_edge_smooth('dynamic')
    graph.show(filename)
def save_kb(kb, filename):
    '''Pickle ``kb`` into ``filename``, overwriting any existing file.'''
    with open(filename, "wb") as out_file:
        pickle.dump(kb, out_file)
class CustomUnpickler(pickle.Unpickler):
    '''Unpickler that maps any pickled class named ``KB`` to this module's ``KB``.

    This works whatever module the pickle recorded, e.g. ``__main__`` from a
    notebook. All other classes resolve normally.

    SECURITY: unpickling can execute arbitrary code; only load trusted files.
    '''

    def find_class(self, module, name):
        return KB if name == 'KB' else super().find_class(module, name)
def load_kb(filename):
    '''Load a pickled KB from ``filename`` using ``CustomUnpickler``.

    SECURITY: pickle can execute arbitrary code; only load files you created.
    '''
    with open(filename, "rb") as in_file:
        return CustomUnpickler(in_file).load()
class KB():
    '''In-memory knowledge base of Wikipedia-resolved entities, relations and source articles.

    Attributes:
        entities: maps a Wikipedia page title to its other fields ("url", "summary").
        relations: list of dicts with "head", "type", "tail" and "meta". "meta"
            maps an article URL to {"spans": [...]}.
        sources: maps an article URL to {"article_title", "article_publish_date"}.
    '''

    def __init__(self):
        self.entities = {}
        self.relations = []
        self.sources = {}

    def merge_with_kb(self, kb2):
        '''Add every relation of ``kb2`` to this KB, carrying over its source metadata.'''
        for r in kb2.relations:
            article_url = list(r["meta"].keys())[0]
            source_data = kb2.sources[article_url]
            self.add_relation(r, source_data["article_title"],
                              source_data["article_publish_date"])

    def are_relations_equal(self, r1, r2):
        '''True when ``r1`` and ``r2`` share head, type and tail ("meta" is ignored).'''
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        '''True if an equal relation (per ``are_relations_equal``) is already stored.'''
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r2):
        '''Fold ``r2``'s article spans into the stored relation equal to it.

        A new article URL is copied over. For a known URL, only spans not
        already present are appended. Raises IndexError if no equal relation is
        stored; ``add_relation`` checks ``exists_relation`` first.
        '''
        r1 = [r for r in self.relations
              if self.are_relations_equal(r2, r)][0]
        article_url = list(r2["meta"].keys())[0]
        if article_url not in r1["meta"]:
            r1["meta"][article_url] = r2["meta"][article_url]
        else:
            spans_to_add = [span for span in r2["meta"][article_url]["spans"]
                            if span not in r1["meta"][article_url]["spans"]]
            r1["meta"][article_url]["spans"] += spans_to_add

    def get_wikipedia_data(self, candidate_entity):
        '''Return {"title", "url", "summary"} for the exact Wikipedia page, or None on any lookup error.'''
        try:
            page = wikipedia.page(candidate_entity, auto_suggest=False)
            entity_data = {
                "title": page.title,
                "url": page.url,
                "summary": page.summary
            }
            return entity_data
        except Exception:
            # Best effort: unresolvable/disambiguation pages are skipped. This
            # no longer uses a bare except, which also swallowed
            # KeyboardInterrupt and SystemExit.
            return None

    def add_entity(self, e):
        '''Store entity ``e`` keyed by its "title"; the other fields become the value.'''
        self.entities[e["title"]] = {k:v for k,v in e.items() if k != "title"}

    def add_relation(self, r, article_title, article_publish_date):
        '''Add relation ``r`` if both endpoints resolve to Wikipedia pages.

        ``r`` is mutated in place: head and tail are replaced by the canonical
        page titles. The first article URL in ``r["meta"]`` is recorded in
        ``sources``. A duplicate relation is merged rather than appended.
        '''
        candidate_entities = [r["head"], r["tail"]]
        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]
        if any(ent is None for ent in entities):
            return
        for e in entities:
            self.add_entity(e)
        r["head"] = entities[0]["title"]
        r["tail"] = entities[1]["title"]
        article_url = list(r["meta"].keys())[0]
        if article_url not in self.sources:
            self.sources[article_url] = {
                "article_title": article_title,
                "article_publish_date": article_publish_date
            }
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def get_textual_representation(self):
        '''Return a markdown dump of entities (summaries cut to 100 chars + "..."), relations and sources.'''
        lines = ["### Entities"]
        for title, fields in self.entities.items():
            shortened = {k: (v[:100] + "..." if k == "summary" else v) for k, v in fields.items()}
            lines.append(f"- {(title, shortened)}")
        lines += ["", "### Relations"]
        lines.extend(f"- {r}" for r in self.relations)
        lines += ["", "### Sources"]
        lines.extend(f"- {s}" for s in self.sources.items())
        # Built with join instead of repeated string concatenation.
        return "\n".join(lines) + "\n"
def save_network_html(kb, filename="network.html"):
    '''Render ``kb`` as a directed pyvis graph on a light-grey canvas and show/write it to ``filename``.

    Entities become green circular nodes. Each relation becomes a head -> tail
    edge whose hover title and label are the relation type.
    '''
    graph = Network(directed=True, width="700px", height="700px", bgcolor="#eeeeee")
    node_colour = "#00FF00"
    for entity in kb.entities:
        graph.add_node(entity, shape="circle", color=node_colour)
    for rel in kb.relations:
        relation_type = rel["type"]
        graph.add_edge(rel["head"], rel["tail"], title=relation_type, label=relation_type)
    physics = dict(node_distance=200, central_gravity=0.2, spring_length=200,
                   spring_strength=0.05, damping=0.09)
    graph.repulsion(**physics)
    graph.set_edge_smooth('dynamic')
    graph.show(filename)