|
import os |
|
import re |
|
import time |
|
from pathlib import Path |
|
|
|
from relik.retriever import GoldenRetriever |
|
|
|
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex |
|
from relik.retriever.indexers.document import DocumentStore |
|
from relik.retriever import GoldenRetriever |
|
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction |
|
import requests |
|
import streamlit as st |
|
from spacy import displacy |
|
from streamlit_extras.badges import badge |
|
from streamlit_extras.stylable_container import stylable_container |
|
|
|
|
|
|
|
import random |
|
|
|
from relik.inference.annotator import Relik |
|
from relik.inference.data.objects import ( |
|
AnnotationType, |
|
RelikOutput, |
|
Span, |
|
TaskType, |
|
Triples, |
|
) |
|
|
|
def get_random_color(ents):
    """Pair every entity label with a randomly drawn pastel colour.

    Args:
        ents: Collection of entity labels. It must support ``len()`` and
            iteration; its items become the keys of the result.

    Returns:
        dict: label -> HTML colour string taken from
        ``generate_pastel_colors(len(ents))``.
    """
    palette = generate_pastel_colors(len(ents))
    # Popping a random index draws without replacement, so no palette entry
    # is handed out twice.
    return {
        label: palette.pop(random.randint(0, len(palette) - 1)) for label in ents
    }
|
|
|
|
|
def floatrange(start, stop, steps):
    """Return ``steps`` evenly spaced values from ``start`` to ``stop`` inclusive.

    Args:
        start: First value of the sequence.
        stop: Last value of the sequence.
        steps: Number of values to produce.

    Returns:
        list: The spaced values. When ``int(steps)`` is 1 the result is the
        single-element list ``[stop]``.
    """
    if int(steps) == 1:
        return [stop]

    span = stop - start
    last_index = float(steps) - 1
    values = []
    for i in range(steps):
        values.append(start + float(i) * span / last_index)
    return values
|
|
|
|
|
def hsl_to_rgb(h, s, l):
    """Convert an HSL colour into an ``(r, g, b)`` tuple of ints.

    Args:
        h: Hue; offsets of it are wrapped into [0, 1] by the inner helper.
        s: Saturation; exactly 0.0 produces a grey.
        l: Lightness.

    Returns:
        tuple: Three ints, each channel being its computed fraction times
        255, rounded with ``round``.
    """

    def channel(low, high, hue):
        # Bring the hue back into [0, 1] before picking a segment.
        while hue < 0.0:
            hue += 1.0
        while hue > 1.0:
            hue -= 1.0

        if 6 * hue < 1.0:
            level = low + (high - low) * 6.0 * hue
        elif 2 * hue < 1.0:
            level = high
        elif 3 * hue < 2.0:
            level = low + (high - low) * ((2.0 / 3.0) - hue) * 6.0
        else:
            level = low
        return level

    # Without saturation all three channels equal the lightness.
    if s == 0.0:
        grey = int(round(l * 255))
        return grey, grey, grey

    upper = l * (1.0 + s) if l < 0.5 else (l + s) - (s * l)
    lower = 2.0 * l - upper

    red = 255 * channel(lower, upper, h + (1.0 / 3.0))
    green = 255 * channel(lower, upper, h)
    blue = 255 * channel(lower, upper, h - (1.0 / 3.0))
    return int(round(red)), int(round(green)), int(round(blue))
|
|
|
|
|
def generate_pastel_colors(n):
    """Return ``n`` pastel colours whose hues are evenly spaced.

    Input:
        n (integer) : The number of colors to return

    Output:
        A list of colors in HTML notation (``'#rrggbb'``, e.g. ``'#ffcccc'``).
        The list is empty when ``n`` is 0.
    """
    if n == 0:
        return []

    start_hue = 0.0
    saturation = 1.0
    lightness = 0.9

    # n + 1 hues are sampled across one full turn and the final sample
    # (hue start_hue + 1) is discarded, leaving exactly n hues.
    hues = floatrange(start_hue, start_hue + 1, n + 1)[:-1]
    return ["#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) for hue in hues]
|
|
|
|
|
def set_sidebar(css):
    """Draw the sidebar: injected stylesheet, World Bank logo and two headings.

    Args:
        css: CSS source text; it is embedded verbatim inside a ``<style>``
            tag and rendered with ``unsafe_allow_html=True``.
    """
    style_tag = f"<style>{css}</style>"
    logo_url = "https://upload.wikimedia.org/wikipedia/commons/8/87/The_World_Bank_logo.svg"

    with st.sidebar:
        st.markdown(style_tag, unsafe_allow_html=True)
        st.image(logo_url, use_column_width=True)
        for heading in ("### World Bank", "### DIME"):
            st.markdown(heading)
|
|
|
def get_el_annotations(response):
    """Turn the linked spans of ``response`` into displaCy "manual" entity data.

    Each span label found in the module-level ``io_map`` is rendered as an
    HTML link to the matching 3ie taxonomy page: the intervention wrapper when
    ``io_map[label] == "intervention"``, the outcome wrapper otherwise.

    Args:
        response: Object exposing ``text`` and ``spans``; every span must
            provide ``start``, ``end`` and ``label``.

    Returns:
        tuple: ``(dict_of_ents, options)`` where ``dict_of_ents`` has the keys
        ``"text"`` and ``"ents"`` (list of ``start``/``end``/``label`` dicts)
        and ``options`` has ``"ents"`` (set of rendered labels) and
        ``"colors"`` (rendered label -> colour).
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import quote

    i_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://developmentevidence.3ieimpact.org/taxonomy-search-detail/intervention/disaggregated-intervention/{}' style='color: #414141'> <span style='font-size: 1.0em; font-family: monospace'> Intervention {}</span></a>"
    o_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://developmentevidence.3ieimpact.org/taxonomy-search-detail/intervention/disaggregated-outcome/{}' style='color: #414141'><span style='font-size: 1.0em; font-family: monospace'> Outcome: {}</span></a>"

    ents = []
    for span in response.spans:
        label = span.label
        kind = io_map.get(label)
        if kind is None:
            # The label is not a known taxonomy entry, so there is no page to
            # link to. Show it as plain text instead of raising KeyError.
            rendered = label
        else:
            wrapper = i_link_wrapper if kind == "intervention" else o_link_wrapper
            # The URL uses the label with its first letter upper-cased.
            # Slicing (label[:1]) keeps an empty label from raising IndexError.
            # quote(..., safe="") percent-encodes every reserved character;
            # the previous hand-rolled replace chain encoded "/" as the
            # invalid escape "%2" instead of "%2F".
            slug = quote(label[:1].upper() + label[1:], safe="")
            rendered = wrapper.format(slug, label)
        ents.append({"start": span.start, "end": span.end, "label": rendered})

    dict_of_ents = {"text": response.text, "ents": ents}
    label_in_text = {ent["label"] for ent in ents}
    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
    return dict_of_ents, options
|
|
|
|
|
|
|
def get_retriever_annotations(response):
    """Collect the texts of the span candidates attached to ``response``.

    Args:
        response: Object exposing ``text`` and ``candidates``;
            ``candidates[TaskType.SPAN]`` is iterated and the ``text``
            attribute of every item is read.

    Returns:
        tuple: ``(dict_of_ents, options)`` where ``dict_of_ents`` has the keys
        ``"text"`` and ``"ents"`` (candidate texts in iteration order) and
        ``options`` has ``"ents"`` (set of distinct candidate texts) and
        ``"colors"`` (candidate text -> colour).
    """
    candidate_texts = [doc.text for doc in response.candidates[TaskType.SPAN]]
    dict_of_ents = {"text": response.text, "ents": candidate_texts}

    distinct_texts = set(candidate_texts)
    options = {"ents": distinct_texts, "colors": get_random_color(distinct_texts)}
    return dict_of_ents, options
|
import json |
|
# Maps the "text" of every record in documents.jsonl to that record's
# metadata "type". get_el_annotations compares the value with "intervention"
# to choose which link template to use.
io_map = {}
with open(
    "/home/user/app/models/retriever/document_index/documents.jsonl",
    "r",
    encoding="utf-8",  # do not depend on the platform default encoding
) as documents_file:
    for line in documents_file:
        # json.loads rejects empty input, so skip blank lines (for example a
        # trailing newline at the end of the file).
        if not line.strip():
            continue
        element = json.loads(line)
        io_map[element["text"]] = element["metadata"]["type"]
|
|
|
@st.cache_resource()
def load_model():
    """Build the six Relik pipelines used by the app.

    One ``GoldenRetriever`` is created per document index. All of them use the
    same question encoder, and all pipelines share a single
    ``RelikReaderForSpanExtraction`` instance.

    Returns:
        list: Relik pipelines in this fixed order, which callers index into:
        ``[question, intervention, outcome, question_db, intervention_db,
        outcome_db]``.
    """
    models_dir = "/home/user/app/models"
    question_encoder = f"{models_dir}/retriever/question_encoder"
    # The index paths used to be written without the leading "/", which made
    # them relative to the working directory, unlike every other model path
    # in this file.
    index_root = f"{models_dir}/retriever/document_index"

    # The order here defines the order of the returned list.
    index_names = (
        "questions",
        "interventions",
        "outcomes",
        "question_db",
        "interventions_db",
        "outcomes_db",
    )

    retrievers = [
        GoldenRetriever(
            question_encoder=question_encoder,
            document_index=f"{index_root}/{index_name}",
        )
        for index_name in index_names
    ]

    reader = RelikReaderForSpanExtraction(
        f"{models_dir}/small-extended-large-batch",
        dataset_kwargs={"use_nme": True},
    )

    return [
        Relik(
            reader=reader,
            retriever=retriever,
            window_size="none",
            top_k=100,
            task="span",
            device="cuda",
            document_index_device="cpu",
        )
        for retriever in retrievers
    ]
|
|
|
def set_intro(css):
    """Draw the page header: title, logo image and tagline.

    Args:
        css: Not read by this function; the parameter exists so callers can
            keep invoking ``set_intro(css)``.
    """
    title = "# CausalAI"
    logo_url = "http://35.237.102.64/public/logo.png"
    tagline = "### 3ie taxonomy level 4 Intervention/Outcome candidate retriever with Entity Linking"

    st.markdown(title)
    st.image(logo_url)
    st.markdown(tagline)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _candidates_html(candidates):
    """Return the "Possible Candidates" heading and list as an HTML string."""
    items = "".join(
        f"<li style='color: black;'>{candidate}</li>" for candidate in candidates
    )
    # The markup is kept free of leading indentation so Markdown does not
    # treat it as a code block.
    return (
        "\n<h2 style='color: black;'>Possible Candidates:</h2>\n"
        "<ul style='color: black;'>\n" + items + "</ul>"
    )


def run_client():
    """Render the Streamlit page and annotate the submitted text.

    The page shows the sidebar and header, a radio button for the analysis
    type, a text area and an "Annotate" button. On submit the Relik pipeline
    matching the analysis type is run on the stripped text. For "question"
    the linked entities are rendered with displaCy; in every mode a list of
    ten retrieved candidates is shown. An error message is displayed when the
    text is empty.
    """
    with open(Path(__file__).parent / "style.css", encoding="utf-8") as f:
        css = f.read()

    st.set_page_config(
        page_title="CausalAI",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    analysis_type = st.radio(
        "Choose analysis type:",
        options=["intervention", "outcome", "question", "db intervention", "db outcome"],
        index=2,
    )

    text = st.text_area(
        "Enter Text Below:",
        value="How does unconditional cash transver affect to reduce poverty?",
        height=200,
        max_chars=1500,
    )

    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #a8ebff;
                color: black;
                border-radius: 25px;
            }
        """,
    ):
        submit = st.button("Annotate")

    if "relik_model" not in st.session_state:
        st.session_state["relik_model"] = load_model()

    if not submit:
        return

    # Position of each pipeline in the list returned by load_model().
    # "db outcome" used to fall through to the question pipeline (index 0);
    # it now selects the outcome_db pipeline. Index 3 (question_db) has no
    # radio option.
    pipeline_index = {
        "question": 0,
        "intervention": 1,
        "outcome": 2,
        "db intervention": 4,
        "db outcome": 5,
    }
    relik_model = st.session_state["relik_model"][pipeline_index.get(analysis_type, 0)]
    # Only the "question" mode renders linked entities.
    entity_linking_bool = analysis_type == "question"

    text = text.strip()
    if not text:
        st.error("Please enter some text.")
        return

    st.markdown("####")
    with st.spinner(text="In progress"):
        response = relik_model(text)

    dict_of_ents_candidates, _ = get_retriever_annotations(response=response)
    candidates = dict_of_ents_candidates["ents"]

    if entity_linking_bool:
        # Span annotations are only built when they are displayed, so the
        # other modes cannot fail on them.
        dict_of_ents, options = get_el_annotations(response=response)

        st.markdown("#### Entity Linking")
        display = displacy.render(
            dict_of_ents, manual=True, style="ent", options=options
        )
        display = display.replace("\n", " ")
        # Keep each entity badge on a single line.
        display = display.replace(
            "border-radius: 0.35em;",
            "border-radius: 0.35em; white-space: nowrap;",
        )
        with st.container():
            st.write(display, unsafe_allow_html=True)

        # NOTE(review): this view skips the first two candidates; the reason
        # is not visible in this file - confirm it is intentional.
        shown_candidates = candidates[2:12]
    else:
        shown_candidates = candidates[0:10]

    st.markdown(_candidates_html(shown_candidates), unsafe_allow_html=True)
|
|
|
|
|
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    run_client()
|
|