import os
import re
import time
from pathlib import Path
from relik.retriever import GoldenRetriever
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.indexers.document import DocumentStore
from relik.retriever import GoldenRetriever
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
import requests
import streamlit as st
from spacy import displacy
from streamlit_extras.badges import badge
from streamlit_extras.stylable_container import stylable_container
# RELIK = os.getenv("RELIK", "localhost:8000/api/entities")
import random
from relik.inference.annotator import Relik
from relik.inference.data.objects import (
AnnotationType,
RelikOutput,
Span,
TaskType,
Triples,
)
def get_random_color(ents):
    """Map every label in ``ents`` to its own pastel colour, picked at random.

    A palette of ``len(ents)`` pastel colours is generated and colours are
    drawn out of it without replacement, so no two labels share a colour.

    Args:
        ents: Iterable of labels (with a ``len``).

    Returns:
        Dict of label -> ``"#rrggbb"`` colour string.
    """
    palette = generate_pastel_colors(len(ents))
    mapping = {}
    for label in ents:
        # Remove the chosen colour from the palette so it cannot be reused.
        chosen = random.randint(0, len(palette) - 1)
        mapping[label] = palette.pop(chosen)
    return mapping
def floatrange(start, stop, steps):
    """Return ``steps`` evenly spaced values from ``start`` to ``stop`` inclusive.

    When ``steps`` is 1 only ``stop`` is returned; when it is 0 the result
    is an empty list.
    """
    if int(steps) == 1:
        return [stop]
    span = stop - start
    denominator = float(steps) - 1
    values = []
    for index in range(steps):
        values.append(start + float(index) * span / denominator)
    return values
def hsl_to_rgb(h, s, l):
    """Convert an HSL colour to an ``(r, g, b)`` tuple of ints in 0..255.

    Args:
        h: Hue as a fraction of a full turn; values outside [0, 1] are
            wrapped back into that range by whole turns.
        s: Saturation, expected in [0, 1] (not validated).
        l: Lightness, expected in [0, 1] (not validated).

    Returns:
        Tuple ``(red, green, blue)`` of rounded integers.
    """

    def channel(low, high, hue):
        """Return the piecewise-linear HSL channel value for ``hue``."""
        while hue < 0.0:
            hue += 1.0
        while hue > 1.0:
            hue -= 1.0
        if 6 * hue < 1.0:
            value = low + (high - low) * 6.0 * hue
        elif 2 * hue < 1.0:
            value = high
        elif 3 * hue < 2.0:
            value = low + (high - low) * ((2.0 / 3.0) - hue) * 6.0
        else:
            value = low
        return value

    if s == 0.0:
        # Zero saturation is a grey: every channel is the scaled lightness.
        red = green = blue = l * 255
    else:
        high = l * (1.0 + s) if l < 0.5 else (l + s) - (s * l)
        low = 2.0 * l - high
        red = 255 * channel(low, high, h + (1.0 / 3.0))
        green = 255 * channel(low, high, h)
        blue = 255 * channel(low, high, h - (1.0 / 3.0))
    return tuple(int(round(component)) for component in (red, green, blue))
def generate_pastel_colors(n):
    """Return ``n`` distinct pastel colours as HTML hex strings.

    Hues are spaced evenly around the colour wheel, starting at red (hue 0),
    with fixed HSL saturation 1.0 and lightness 0.9.

    Args:
        n: Number of colours to return. ``n <= 0`` yields an empty list.

    Returns:
        A list of ``"#rrggbb"`` strings, in increasing hue order.

    Example:
        >>> generate_pastel_colors(5)
        ['#ffcccc', '#f5ffcc', '#ccffe0', '#cce0ff', '#f5ccff']
    """
    if n == 0:
        return []
    # HSL colour space: see https://en.wikipedia.org/wiki/HSL_and_HSV
    start_hue = 0.0  # 0 = red, 1/3 = green, 2/3 = blue
    saturation = 1.0
    lightness = 0.9
    # Take n + 1 evenly spaced hues over a full turn and drop the last one,
    # because hue 1.0 is the same colour as hue 0.0.
    return [
        "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
        for hue in floatrange(start_hue, start_hue + 1, n + 1)
    ][:-1]
def set_sidebar(css):
    """Inject the page stylesheet and render the sidebar branding.

    Args:
        css: Raw stylesheet text; ``run_client`` passes the contents of
            ``style.css``.
    """
    with st.sidebar:
        # Apply the stylesheet. Previously an empty f-string was rendered
        # here, so ``css`` had no effect.
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
        st.image(
            "https://upload.wikimedia.org/wikipedia/commons/8/87/The_World_Bank_logo.svg",
            use_column_width=True,
        )
        st.markdown("### World Bank")
        st.markdown("### DIME")
def get_el_annotations(response):
    """Build displaCy manual-render input from the linked spans in ``response``.

    Each span in ``response.spans`` becomes a displaCy entity. Its label is
    prefixed with "Intervention" or "Outcome:" according to the type that the
    module-level ``io_map`` (loaded from documents.jsonl) records for it.

    Args:
        response: ReLiK output exposing ``text`` and ``spans``; each span
            has ``start``, ``end`` and ``label``.

    Returns:
        ``(dict_of_ents, options)``: ``dict_of_ents`` is
        ``{"text": ..., "ents": [...]}`` for ``displacy.render(manual=True)``;
        ``options`` holds the set of labels present and a colour per label.
    """
    i_link_wrapper = " Intervention {}"
    o_link_wrapper = " Outcome: {}"

    ents = []
    for span in response.spans:
        # Only labels typed "intervention" get the intervention prefix. Every
        # other type, and any label missing from io_map, is shown as an
        # outcome rather than raising KeyError and breaking the page.
        if io_map.get(span.label) == "intervention":
            wrapper = i_link_wrapper
        else:
            wrapper = o_link_wrapper
        ents.append(
            {
                "start": span.start,
                "end": span.end,
                # The wrapper has a single placeholder, so pass exactly the
                # readable label. Passing a percent-encoded copy first would
                # display that encoded text instead.
                "label": wrapper.format(span.label),
            }
        )

    dict_of_ents = {"text": response.text, "ents": ents}
    label_in_text = {ent["label"] for ent in ents}
    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
    return dict_of_ents, options
def get_retriever_annotations(response):
    """Collect the retriever's span candidates from ``response``.

    Returns:
        ``(dict_of_ents, options)``: ``dict_of_ents["ents"]`` lists each
        candidate's ``text`` in the order found in
        ``response.candidates[TaskType.SPAN]``; ``options`` holds the set of
        distinct candidate texts and a colour for each.
    """
    candidate_texts = [
        candidate.text for candidate in response.candidates[TaskType.SPAN]
    ]
    unique_texts = set(candidate_texts)
    colours = get_random_color(unique_texts)
    return (
        {"text": response.text, "ents": candidate_texts},
        {"ents": unique_texts, "colors": colours},
    )
import json

# Maps each knowledge-base entry's "text" to its metadata "type";
# get_el_annotations compares that type against "intervention".
io_map = {}
with open(
    "/home/user/app/models/retriever/document_index/documents.jsonl",
    "r",
    encoding="utf-8",
) as r:
    for line in r:
        # Skip blank lines (e.g. a trailing empty line), which json.loads rejects.
        if not line.strip():
            continue
        element = json.loads(line)
        io_map[element["text"]] = element["metadata"]["type"]
@st.cache_resource()
def load_model():
    """Build the ReLiK span-extraction pipeline, entirely on CPU.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the built pipeline
    instead of reloading the model weights on every rerun.

    Returns:
        A ``Relik`` annotator combining the GoldenRetriever (over the
        in-memory document index) with the span-extraction reader.
    """
    retriever = GoldenRetriever(
        question_encoder="/home/user/app/models/retriever/question_encoder",
        document_index=InMemoryDocumentIndex(
            documents=DocumentStore.from_file(
                "/home/user/app/models/retriever/document_index/documents.jsonl"
            ),
            metadata_fields=["definition"],
            separator=" ",
            device="cpu",
        ),
        # Keyword must be ``device``. It was misspelled ``devide``, so the
        # intended device was never passed to the retriever.
        device="cpu",
    )
    # NOTE(review): presumably encodes/populates the document index before
    # inference — confirm against the GoldenRetriever docs.
    retriever.index()
    reader = RelikReaderForSpanExtraction(
        "/home/user/app/models/small-extended-large-batch",
        dataset_kwargs={"use_nme": True},
    )
    return Relik(
        reader=reader,
        retriever=retriever,
        window_size=32,
        window_stride=16,
        top_k=100,
        task="span",
        device="cpu",
        document_index_device="cpu",
    )
def set_intro(css):
    """Render the page header: title, logo and subtitle.

    Args:
        css: Currently unused by this function.
    """
    title = "# CausalAI"
    logo_url = "http://35.237.102.64/public/logo.png"
    subtitle = (
        "### 3ie taxonomy level 4 Intervention/Outcome candidate retriever with Entity Linking"
    )
    st.markdown(title)
    st.image(logo_url)
    st.markdown(subtitle)
def _render_entity_linking(response):
    """Render the linked spans of ``response`` as displaCy entity HTML."""
    dict_of_ents, options = get_el_annotations(response=response)
    display = displacy.render(
        dict_of_ents, manual=True, style="ent", options=options
    )
    display = display.replace("\n", " ")
    # Heuristic: keep each entity badge on one line so its decoration is not split.
    display = display.replace(
        "border-radius: 0.35em;", "border-radius: 0.35em; white-space: nowrap;"
    )
    with st.container():
        st.write(display, unsafe_allow_html=True)


def _render_candidates(response):
    """Show a slice of the retriever's span candidates as a Markdown bullet list."""
    dict_of_ents_candidates, _ = get_retriever_annotations(response=response)
    # NOTE(review): the first two candidates are skipped and at most ten are
    # shown; the reason for the offset is not visible here — confirm.
    shown = dict_of_ents_candidates["ents"][2:12]
    # Join with real newlines so each candidate becomes its own bullet (the
    # old "\\n- " separator inserted a literal backslash-n instead).
    bullets = "\n".join(f"- {candidate}" for candidate in shown)
    st.markdown(f"## Possible Candidates:\n{bullets}", unsafe_allow_html=True)


def run_client():
    """Streamlit entry point: draw the page, annotate submitted text, show results."""
    with open(Path(__file__).parent / "style.css", encoding="utf-8") as f:
        css = f.read()
    st.set_page_config(
        page_title="CausalAI",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    text = st.text_area(
        "Enter Text Below:",
        value="How does unconditional cash transver affect to reduce poverty?",
        height=200,
        max_chars=1500,
    )

    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #a8ebff;
                color: black;
                border-radius: 25px;
            }
            """,
    ):
        submit = st.button("Annotate")

    # Keep one model handle per session (load_model itself is also cached).
    if "relik_model" not in st.session_state:
        st.session_state["relik_model"] = load_model()
    relik_model = st.session_state["relik_model"]

    if not submit:
        return
    text = text.strip()
    if not text:
        st.error("Please enter some text.")
        return

    st.markdown("####")
    st.markdown("#### Entity Linking")
    with st.spinner(text="In progress"):
        response = relik_model(text)
    _render_entity_linking(response)
    _render_candidates(response)


if __name__ == "__main__":
    run_client()