|
import os |
|
import re |
|
import time |
|
from pathlib import Path |
|
|
|
from relik.retriever import GoldenRetriever |
|
|
|
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex |
|
from relik.retriever.indexers.document import DocumentStore |
|
from relik.retriever import GoldenRetriever |
|
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction |
|
import requests |
|
import streamlit as st |
|
from spacy import displacy |
|
from streamlit_extras.badges import badge |
|
from streamlit_extras.stylable_container import stylable_container |
|
import logging |
|
|
|
|
|
# Configure process-wide logging: INFO and above, each record prefixed with a
# timestamp. The application logs through the root logger bound below.
logging.basicConfig(
    format="%(asctime)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger()
|
|
|
|
|
|
|
import random |
|
|
|
from relik.inference.annotator import Relik |
|
from relik.inference.data.objects import ( |
|
AnnotationType, |
|
RelikOutput, |
|
Span, |
|
TaskType, |
|
Triples, |
|
) |
|
|
|
def get_random_color(ents):
    """Map every entry of ``ents`` to a randomly chosen pastel colour.

    One colour per entry is generated up front with
    ``generate_pastel_colors(len(ents))``; each entry then draws (and removes)
    a random colour from that pool, so no colour is handed out twice.

    Args:
        ents: Sized iterable of hashable labels.

    Returns:
        dict: ``{label: "#rrggbb"}``.
    """
    palette = generate_pastel_colors(len(ents))
    assignment = {}
    for label in ents:
        # Pick a random remaining slot and take it out of the pool.
        slot = random.randint(0, len(palette) - 1)
        assignment[label] = palette.pop(slot)
    return assignment
|
|
|
|
|
def floatrange(start, stop, steps):
    """Return ``steps`` evenly spaced values from ``start`` to ``stop`` inclusive.

    ``steps`` is coerced to an integer once and that count is used everywhere.
    Previously the raw value was handed to ``range()``, so an integral float
    such as ``3.0`` passed the ``int(steps)`` check but then raised
    ``TypeError``.

    Args:
        start: First value of the sequence.
        stop: Last value of the sequence.
        steps: Number of values to produce (anything ``int()`` accepts).

    Returns:
        list: ``[stop]`` when the count is 1, an empty list when the count is
        zero or negative, otherwise ``count`` floats from ``start`` to ``stop``.
    """
    count = int(steps)
    if count == 1:
        # A single sample cannot span the interval; the end point is returned.
        return [stop]
    return [
        start + float(i) * (stop - start) / (float(count) - 1) for i in range(count)
    ]
|
|
|
|
|
def hsl_to_rgb(h, s, l):
    """Convert an HSL colour to an ``(r, g, b)`` tuple of rounded integers.

    Each channel is computed as a fraction and scaled by 255, so ``s`` and
    ``l`` are used as fractions; the hue offset is wrapped into ``[0, 1]`` by
    repeatedly adding or subtracting 1.

    Args:
        h: Hue.
        s: Saturation; exactly ``0.0`` yields a grey whose channels all equal
            ``l * 255``.
        l: Lightness.

    Returns:
        tuple: ``(red, green, blue)`` as ints.
    """

    def channel(low, high, hue):
        # Bring the shifted hue back into the unit interval.
        while hue < 0.0:
            hue += 1.0
        while hue > 1.0:
            hue -= 1.0
        # Piecewise-linear ramp between the two lightness bounds.
        if 6 * hue < 1.0:
            return low + (high - low) * 6.0 * hue
        if 2 * hue < 1.0:
            return high
        if 3 * hue < 2.0:
            return low + (high - low) * ((2.0 / 3.0) - hue) * 6.0
        return low

    if not s != 0.0:
        # No saturation: all three channels share the lightness value.
        grey = int(round(l * 255))
        return grey, grey, grey

    if l < 0.5:
        upper = l * (1.0 + s)
    else:
        upper = (l + s) - (s * l)
    lower = 2.0 * l - upper

    red = 255 * channel(lower, upper, h + (1.0 / 3.0))
    green = 255 * channel(lower, upper, h)
    blue = 255 * channel(lower, upper, h - (1.0 / 3.0))
    return int(round(red)), int(round(green)), int(round(blue))
|
|
|
|
|
def generate_pastel_colors(n):
    """Return ``n`` pastel colours as HTML hex strings.

    Hues are spread evenly over one full turn starting at 0.0, with fixed
    saturation 1.0 and lightness 0.9. ``n + 1`` hues are sampled and the last
    one is discarded, because it lands on the same hue as the first.

    Args:
        n (int): Number of colours to return.

    Returns:
        list[str]: Colours in ``#rrggbb`` notation; empty when ``n`` is 0.

    Example:
        >>> generate_pastel_colors(2)
        ['#ffcccc', '#ccffff']
    """
    if n == 0:
        return []

    start_hue, saturation, lightness = 0.0, 1.0, 0.9
    hues = floatrange(start_hue, start_hue + 1, n + 1)

    palette = []
    # Skip the final hue: it closes the circle and would repeat the first colour.
    for hue in hues[:-1]:
        palette.append("#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness))
    return palette
|
|
|
|
|
def set_sidebar(css):
    """Populate the Streamlit sidebar.

    Injects ``css`` inside a ``<style>`` tag, shows the World Bank logo and
    adds the two organisation headings.

    Args:
        css: Stylesheet text to embed in the page.
    """
    logo_url = "https://upload.wikimedia.org/wikipedia/commons/8/87/The_World_Bank_logo.svg"
    headings = ("### World Bank", "### DIME")

    with st.sidebar:
        # Raw HTML is required for the <style> tag to take effect.
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
        st.image(logo_url, use_column_width=True)
        for heading in headings:
            st.markdown(heading)
|
|
|
def get_el_annotations(response):
    """Build the displaCy "manual" entity payload for an entity-linking response.

    Every span in ``response.spans`` becomes an entity whose label is an HTML
    link to the 3ie taxonomy page of the span's label. The intervention
    template is used when ``io_map`` classifies the label as
    ``"intervention"``; every other label gets the outcome template.

    Args:
        response: Object exposing ``text`` and ``spans``; each span provides
            ``start``, ``end`` and ``label``.

    Returns:
        tuple: ``(dict_of_ents, options)`` where ``dict_of_ents`` holds the
        text and entity dicts, and ``options`` holds the set of labels plus a
        random colour per label.

    Raises:
        KeyError: If a span label is not present in ``io_map``.
    """
    i_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://developmentevidence.3ieimpact.org/taxonomy-search-detail/intervention/disaggregated-intervention/{}' style='color: #414141'> <span style='font-size: 1.0em; font-family: monospace'> Intervention {}</span></a>"
    o_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://developmentevidence.3ieimpact.org/taxonomy-search-detail/intervention/disaggregated-outcome/{}' style='color: #414141'><span style='font-size: 1.0em; font-family: monospace'> Outcome: {}</span></a>"

    # Escapes applied to everything after the first character of the label.
    # NOTE(review): "/" maps to "%2", not the standard "%2F" -- confirm this is
    # what the target site expects before changing it.
    url_escapes = str.maketrans({"/": "%2", " ": "%20", "&": "%26"})

    ents = []
    for span in response.spans:
        name = span.label
        if io_map[name] == "intervention":
            template = i_link_wrapper
        else:
            template = o_link_wrapper
        ents.append(
            {
                "start": span.start,
                "end": span.end,
                # First character is upper-cased and left unescaped.
                "label": template.format(
                    name[0].upper() + name[1:].translate(url_escapes), name
                ),
            }
        )

    dict_of_ents = {"text": response.text, "ents": ents}
    label_in_text = {ent["label"] for ent in ents}
    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
    return dict_of_ents, options
|
|
|
|
|
|
|
def get_retriever_annotations(response):
    """Collect the retrieved span candidates of a response.

    The previously assigned but never used Wikipedia link template has been
    removed; the returned data is unchanged.

    Args:
        response: Object exposing ``text`` and a ``candidates`` mapping keyed
            by ``TaskType``; each span candidate provides ``text``.

    Returns:
        tuple: ``(dict_of_ents, options)`` where ``dict_of_ents["ents"]`` is
        the list of candidate texts in retrieval order, and ``options`` holds
        the set of distinct texts plus a random colour per text.
    """
    ents = [candidate.text for candidate in response.candidates[TaskType.SPAN]]
    dict_of_ents = {"text": response.text, "ents": ents}
    label_in_text = set(ents)
    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
    return dict_of_ents, options
|
|
|
|
|
def get_retriever_annotations_candidates(text, ents):
    """Package already-retrieved candidate strings for display.

    The previously assigned but never used Wikipedia link template has been
    removed; the returned data is unchanged.

    Args:
        text: The query text the candidates belong to.
        ents: Iterable of hashable candidate labels.

    Returns:
        tuple: ``(dict_of_ents, options)`` where ``dict_of_ents`` is
        ``{"text": text, "ents": ents}`` and ``options`` holds the set of
        distinct labels plus a random colour per label.
    """
    dict_of_ents = {"text": text, "ents": ents}
    label_in_text = set(ents)
    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
    return dict_of_ents, options
|
|
|
|
|
import json

# io_map: document text -> metadata "type" of that document. Other code in
# this file compares the value against "intervention" to choose between the
# intervention and outcome renderings.
io_map = {}
with open(
    "/home/user/app/models/retriever/document_index/documents.jsonl",
    "r",
    encoding="utf-8",  # decode explicitly instead of relying on the locale
) as documents_file:
    for line in documents_file:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing one) instead of failing
            # inside json.loads.
            continue
        element = json.loads(line)
        io_map[element["text"]] = element["metadata"]["type"]
|
|
|
|
|
import json

# db_set: texts of every document in the intervention and outcome "db"
# document indexes, used for membership tests on retrieved candidates.
# NOTE(review): these paths point at ".../gpt/db/..." while load_model() reads
# ".../gpt+llama/db/..." -- confirm both refer to the intended indexes.
db_set = set()
for documents_path in (
    "models/retriever/intervention/gpt/db/document_index/documents.jsonl",
    "models/retriever/outcome/gpt/db/document_index/documents.jsonl",
):
    # Decode explicitly as UTF-8 instead of relying on the locale default.
    with open(documents_path, "r", encoding="utf-8") as documents_file:
        for line in documents_file:
            if not line.strip():
                # Tolerate blank lines instead of failing inside json.loads.
                continue
            db_set.add(json.loads(line)["text"])
|
|
|
|
|
@st.cache_resource()
def load_model():
    """Instantiate the retrievers and the entity-linking pipeline.

    Five ``GoldenRetriever`` instances are built (question, intervention
    taxonomy, intervention db, outcome taxonomy, outcome db), followed by the
    span-extraction reader and a CPU ``Relik`` pipeline combining the reader
    with the question retriever.

    Returns:
        list: ``[relik_question, intervention_db, outcome_db,
        intervention_taxonomy, outcome_taxonomy]`` -- callers index into this
        list by position.
    """
    # (question_encoder, document_index) per retriever, listed in the order
    # in which the retrievers are constructed.
    retriever_paths = {
        "question": (
            "/home/user/app/models/retriever/question_encoder",
            "/home/user/app/models/retriever/document_index/questions",
        ),
        "intervention_taxonomy": (
            "models/retriever/intervention/gpt+llama/taxonomy/question_encoder",
            "models/retriever/intervention/gpt+llama/taxonomy/document_index",
        ),
        "intervention_db": (
            "models/retriever/intervention/gpt+llama/db/question_encoder",
            "models/retriever/intervention/gpt+llama/db/document_index",
        ),
        "outcome_taxonomy": (
            "models/retriever/outcome/gpt+llama/taxonomy/question_encoder",
            "models/retriever/outcome/gpt+llama/taxonomy/document_index",
        ),
        "outcome_db": (
            "models/retriever/outcome/gpt+llama/db/question_encoder",
            "models/retriever/outcome/gpt+llama/db/document_index",
        ),
    }
    retrievers = {}
    for name, (encoder_path, index_path) in retriever_paths.items():
        retrievers[name] = GoldenRetriever(
            question_encoder=encoder_path,
            document_index=index_path,
        )

    reader = RelikReaderForSpanExtraction(
        "/home/user/app/models/small-extended-large-batch",
        dataset_kwargs={"use_nme": True},
    )

    relik_question = Relik(
        reader=reader,
        retriever=retrievers["question"],
        window_size="none",
        top_k=100,
        task="span",
        device="cpu",
        document_index_device="cpu",
    )

    return [
        relik_question,
        retrievers["intervention_db"],
        retrievers["outcome_db"],
        retrievers["intervention_taxonomy"],
        retrievers["outcome_taxonomy"],
    ]
|
|
|
def set_intro(css):
    """Render the page header: title, logo image and subtitle.

    Args:
        css: Accepted for signature parity with ``set_sidebar``; not used here.
    """
    logo_url = "http://35.237.102.64/public/logo.png"
    subtitle = "### 3ie taxonomy level 4 Intervention/Outcome candidate retriever with Entity Linking"

    st.markdown("# ImpactAI")
    st.image(logo_url)
    st.markdown(subtitle)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from datetime import datetime |
|
from pathlib import Path |
|
from huggingface_hub import HfApi, CommitScheduler |
|
from uuid import uuid4 |
|
|
|
|
|
# Hub client authenticated with the token from the HF_TOKEN environment
# variable. The token is passed to the constructor: HfApi.set_access_token()
# is deprecated/removed in huggingface_hub releases that provide
# CommitScheduler, so calling it fails at import time there.
api = HfApi(token=os.getenv("HF_TOKEN"))

# Local folder that collects the saved selections. Each process writes to its
# own uniquely named file inside it.
JSON_DATASET_DIR = Path("json_demo_selected_io")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
JSON_DATASET_PATH = JSON_DATASET_DIR / f"train-{uuid4()}.json"

# Background scheduler that pushes JSON_DATASET_DIR to the "data" folder of
# the dataset repo. Writers must hold `scheduler.lock` while appending.
scheduler = CommitScheduler(
    repo_id="demo-retriever",
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="data",
    hf_api=api,
)
|
|
|
|
|
def write_candidates_to_file(text, candidates, selected_candidates):
    """Log and append one selection record to the JSON-lines dataset file.

    The record is written while holding ``scheduler.lock`` so the background
    commit never uploads a half-written line.

    Args:
        text: The query text the candidates were retrieved for.
        candidates: Iterable of all candidates that were offered.
        selected_candidates: Iterable of the candidates the user ticked.

    Returns:
        bool: ``True`` once the record has been written. The function used to
        return ``None``, so a caller testing the result could never report
        success. Failures still surface as exceptions.
    """
    logger.info(f"Text: {text}\tCandidates: {str(candidates)}\tSelected Candidates: {str(selected_candidates)}\n")
    with scheduler.lock:
        with JSON_DATASET_PATH.open("a") as f:
            record = {
                "text": text,
                "Candidates": list(candidates),
                "Selected Candidates": list(selected_candidates),
                "datetime": datetime.now().isoformat(),
            }
            json.dump(record, f)
            f.write("\n")
    return True
|
|
|
def run_client():
    """Build the Streamlit page and handle one script run.

    The page offers two modes:

    * "Retriever": retrieve up to 10 candidates for the entered text with the
      selected retriever, show them as checkboxes and let the user save the
      ticked ones.
    * "Entity Linking": run the ReLiK pipeline, render the linked spans with
      displaCy and list the first 10 retrieved candidates.

    Fixes relative to the previous version:

    * The "Top-k ... Outcome" option was spelled "Taxonmy" where the model was
      chosen but "Taxonomy" where the retrieval strategy was chosen, so that
      option never applied the top-50 + DB filter. The label is now defined
      once (spelled correctly) and a single flag drives both decisions.
    * The Entity Linking branch no longer overwrites ``text`` with HTML, which
      could otherwise be passed on to the candidate/save section.
    * Selections are reset when a new candidate list is retrieved, so ticks
      from a previous query are not saved with the new text.
    * The success message is shown whenever the save call returns without
      raising, rather than depending on the helper's return value.
    """
    with open(Path(__file__).parent / "style.css") as f:
        css = f.read()

    st.set_page_config(
        page_title="ImpactAI",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    analysis_type = st.radio(
        "Choose analysis type:",
        options=["Retriever", "Entity Linking"],
        index=0,
    )

    # Options that retrieve from a taxonomy index and keep only DB entries.
    top_k_options = (
        "Top-k DB in Taxonomy Intervention",
        "Top-k DB in Taxonomy Outcome",
    )
    selection_options = [
        "DB Intervention",
        "DB Outcome",
        "Taxonomy Intervention",
        "Taxonomy Outcome",
        *top_k_options,
    ]

    # Only the Retriever mode needs a retriever choice.
    selection_list = None
    if analysis_type == "Retriever":
        selection_list = st.selectbox(
            "Select an option:",
            options=selection_options,
        )

    text = st.text_area(
        "Enter Text Below:",
        value="",
        height=200,
        max_chars=1500,
    )

    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #a8ebff;
                color: black;
                border-radius: 25px;
            }
        """,
    ):
        submit = st.button("Annotate")

    # load_model() returns [relik pipeline, intervention db, outcome db,
    # intervention taxonomy, outcome taxonomy].
    if "relik_model" not in st.session_state:
        st.session_state["relik_model"] = load_model()
    models = st.session_state["relik_model"]

    if "candidates" not in st.session_state:
        st.session_state["candidates"] = []
    if "selected_candidates" not in st.session_state:
        st.session_state["selected_candidates"] = []

    if submit:
        is_entity_linking = analysis_type == "Entity Linking"
        is_top_k = (not is_entity_linking) and selection_list in top_k_options

        if is_entity_linking:
            relik_model = models[0]
        else:
            model_idx = selection_options.index(selection_list)
            # Plain options 0..3 map to models 1..4; the two top-k options
            # (4, 5) reuse the taxonomy retrievers at positions 3 and 4.
            relik_model = models[model_idx - 1] if is_top_k else models[model_idx + 1]

        text = text.strip()
        if text:
            st.markdown("####")
            with st.spinner(text="In progress"):
                if is_entity_linking:
                    response = relik_model(text)

                    dict_of_ents, options = get_el_annotations(response=response)
                    dict_of_ents_candidates, _ = get_retriever_annotations(response=response)

                    st.markdown("#### Entity Linking")

                    display = displacy.render(
                        dict_of_ents, manual=True, style="ent", options=options
                    )
                    # Keep the rendered HTML on one line and stop entity
                    # badges from wrapping.
                    display = display.replace("\n", " ")
                    display = display.replace(
                        "border-radius: 0.35em;",
                        "border-radius: 0.35em; white-space: nowrap;",
                    )

                    with st.container():
                        st.write(display, unsafe_allow_html=True)

                    candidate_items = "".join(
                        f"<li style='color: black;'>Intervention: {candidate}</li>"
                        if io_map[candidate] == "intervention"
                        else f"<li style='color: black;'>Outcome: {candidate}</li>"
                        for candidate in dict_of_ents_candidates["ents"][0:10]
                    )
                    # Built without leading indentation so Markdown does not
                    # treat it as a code block; kept separate from `text`.
                    candidates_html = (
                        "\n<h2 style='color: black;'>Possible Candidates:</h2>\n"
                        "<ul style='color: black;'>\n" + candidate_items + "</ul>"
                    )
                    st.markdown(candidates_html, unsafe_allow_html=True)
                else:
                    if is_top_k:
                        # Retrieve more, then keep only entries present in the DB.
                        response = relik_model.retrieve(text, k=50, batch_size=400, progress_bar=False)
                        candidates_text = [
                            pred.document.text
                            for pred in response[0]
                            if pred.document.text in db_set
                        ]
                    else:
                        response = relik_model.retrieve(text, k=10, batch_size=400, progress_bar=False)
                        candidates_text = [pred.document.text for pred in response[0]]

                    # A new retrieval invalidates ticks made for an older query.
                    st.session_state.selected_candidates = []
                    if candidates_text:
                        st.session_state.candidates = candidates_text[:10]
                    else:
                        st.session_state.candidates = []
                        st.markdown(
                            "<h2 style='color: black;'>No Candidates Found</h2>",
                            unsafe_allow_html=True,
                        )
        else:
            st.error("Please enter some text.")

    # Rendered on every run (not only right after "Annotate") so the
    # checkboxes survive the reruns triggered by ticking them.
    if st.session_state.candidates:
        dict_of_ents_candidates, _ = get_retriever_annotations_candidates(
            text, st.session_state.candidates
        )
        st.markdown("<h2 style='color: black;'>Possible Candidates:</h2>", unsafe_allow_html=True)
        for candidate in dict_of_ents_candidates["ents"]:
            checked = candidate in st.session_state.selected_candidates
            if st.checkbox(candidate, key=candidate, value=checked):
                if candidate not in st.session_state.selected_candidates:
                    st.session_state.selected_candidates.append(candidate)
            elif candidate in st.session_state.selected_candidates:
                st.session_state.selected_candidates.remove(candidate)

        if st.button("Save Selected Candidates"):
            # Any failure raises; reaching the next line means the record
            # was written.
            write_candidates_to_file(
                text,
                dict_of_ents_candidates["ents"],
                st.session_state.selected_candidates,
            )
            st.success("Selected candidates have been saved to file.")
|
|
|
|
|
# Script entry point: only build the page when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    run_client()
|
|