# NOTE: "Spaces: / Runtime error / Runtime error" — Hugging Face Spaces status
# banner captured during page extraction; not part of the program source.
import gradio as gr | |
from functools import partial | |
from transformers import pipeline, pipelines | |
from sentence_transformers import SentenceTransformer, util | |
from scipy.special import softmax | |
import os | |
import json | |
###################### | |
##### INFERENCE ###### | |
###################### | |
class SentenceSimilarity:
    """Semantic search over a fixed-size result window using SentenceTransformer."""

    def __init__(self, model: str):
        # Load the transformer once at construction; reused for every query.
        self.model = SentenceTransformer(model)

    def __call__(self, query: str, corpus: list[str]):
        """Return the top-5 semantic-search hits of `query` against `corpus`."""
        embedded_query = self.model.encode(query)
        embedded_corpus = self.model.encode(corpus)
        hits = util.semantic_search(embedded_query, embedded_corpus, top_k=5)
        # semantic_search returns one hit list per query; exactly one query here.
        return hits[0]
# Sentence Similarity
def sentence_similarity(
    query: str,
    texts: list[str],
    titles: list[str],
    urls: list[str],
    pipe: "SentenceSimilarity",
):
    """Run semantic search and format each hit as a one-column HTML-link row.

    Args:
        query: Free-text search query.
        texts: Corpus documents the query is matched against.
        titles: Display titles, parallel to `texts`.
        urls: Link targets, parallel to `texts`.
        pipe: Callable returning hit dicts whose `corpus_id` indexes the corpus.

    Returns:
        List of single-element rows (one HTML anchor per hit) for a gr.DataFrame.
    """
    answer = pipe(query=query, corpus=texts)
    # BUG FIX: the original f-string was missing the closing quote after the
    # href URL, emitting malformed HTML like <a href='url target='_blank'>.
    df = [
        [
            f"<a href='{urls[ans['corpus_id']]}' target='_blank'>{titles[ans['corpus_id']]}</a>"
        ]
        for ans in answer
    ]
    return df
# Text Analysis
def cls_inference(input: str, pipe: "pipeline") -> dict:
    """Classify one text and map every label to its score.

    Note: `input` must be a single string. The original `list[str]` annotation
    was wrong — with a list, the pipeline returns one result list per item and
    the comprehension below would fail on `x["label"]`. The `pipeline`
    annotation is quoted because it names a factory function, not a type.
    """
    # top_k=None asks the pipeline for scores over all labels.
    results = pipe(input, top_k=None)
    return {x["label"]: x["score"] for x in results}
# POSP
def tagging(text: str, pipe: "pipeline"):
    """Run a token-classification pipe and package output for gr.HighlightedText.

    Returns the raw text together with the pipe's entity spans — the dict
    shape gr.HighlightedText consumes. The `pipeline` annotation is quoted
    because it names the transformers factory function, not a class.
    """
    output = pipe(text)
    return {"text": text, "entities": output}
# Text Analysis
def text_analysis(text, pipes: list[pipeline]):
    """Run `text` through every pipe, dispatching on the pipe's kind.

    Token-classification pipes go through `tagging`; any other pipe is
    treated as a classifier and goes through `cls_inference`.
    """
    token_cls = pipelines.token_classification.TokenClassificationPipeline
    return [
        tagging(text, p) if isinstance(p, token_cls) else cls_inference(text, p)
        for p in pipes
    ]
###################### | |
##### INTERFACE ###### | |
###################### | |
def text_interface(
    pipe: pipeline, examples: list[str], output_label: str, title: str, desc: str
):
    """Build a simple classification demo: one textbox in, labelled scores out."""
    query_box = gr.Textbox(lines=5, label="Input Text")
    score_label = gr.Label(label=output_label)
    return gr.Interface(
        fn=partial(cls_inference, pipe=pipe),
        inputs=[query_box],
        outputs=[score_label],
        title=title,
        description=desc,
        examples=examples,
        allow_flagging="never",
    )
def search_interface(
    pipe: SentenceSimilarity,
    examples: list[str],
    output_label: str,
    title: str,
    desc: str,
    sample: str,
):
    """Build the semantic-search tab: corpus table, query box, and results grid.

    Args:
        pipe: SentenceSimilarity instance used to rank the corpus.
        examples: Unused here; kept for signature parity with the other builders.
        output_label: Unused here; kept for signature parity.
        title: Markdown heading shown at the top of the tab.
        desc: Markdown description shown under the heading.
        sample: Path to a JSON file with parallel "id"/"title"/"url"/"text" lists.
    """
    # BUG FIX: the sample file was opened without ever being closed; a context
    # manager releases the handle even if json.load raises.
    with open(sample) as f:
        data = json.load(f)
    with gr.Blocks() as sentence_similarity_interface:
        gr.Markdown(title)
        gr.Markdown(desc)
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(lines=5, label="Query")
                # Loop vars renamed: `id` shadowed the builtin and `title`
                # shadowed this function's parameter.
                df = gr.DataFrame(
                    [
                        [doc_id, f"<a href='{doc_url}' target='_blank'>{doc_title}</a>"]
                        for doc_id, doc_title, doc_url in zip(
                            data["id"], data["title"], data["url"]
                        )
                    ],
                    headers=["ID", "Title"],
                    wrap=True,
                    datatype=["markdown", "html"],
                    interactive=False,
                    height=300,
                )
        button = gr.Button("Search...")
        output = gr.DataFrame(
            headers=["Title"],
            wrap=True,
            datatype=["html"],
            interactive=False,
        )
        # Corpus/titles/urls are bound once via partial; only the query varies.
        button.click(
            fn=partial(
                sentence_similarity,
                pipe=pipe,
                texts=data["text"],
                titles=data["title"],
                urls=data["url"],
            ),
            inputs=[input_text],
            outputs=[output],
        )
    return sentence_similarity_interface
def token_classification_interface(
    pipe: pipeline, examples: list[str], output_label: str, title: str, desc: str
):
    """Build a tagging demo: textbox in, highlighted entity spans out."""
    text_input = gr.Textbox(
        placeholder="Masukan kalimat di sini...", label="Input Text"
    )
    highlighted = gr.HighlightedText(label=output_label)
    return gr.Interface(
        fn=partial(tagging, pipe=pipe),
        inputs=[text_input],
        outputs=[highlighted],
        title=title,
        description=desc,
        examples=examples,
        allow_flagging="never",
    )
def text_analysis_interface(
    pipe: list, examples: list[str], output_label: str, title: str, desc: str
):
    """Build the combined-analysis tab: one input, one output widget per pipe."""
    token_cls = pipelines.token_classification.TokenClassificationPipeline
    # Local renamed from `text_analysis_interface`, which shadowed this function.
    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(desc)
        input_text = gr.Textbox(lines=5, label="Input Text")
        with gr.Row():
            # Tagging pipes render as highlighted spans; classifiers as labels.
            outputs = [
                gr.HighlightedText(label=label)
                if isinstance(p, token_cls)
                else gr.Label(label=label)
                for label, p in zip(output_label, pipe)
            ]
        btn = gr.Button("Analyze")
        btn.click(
            fn=partial(text_analysis, pipes=pipe),
            inputs=[input_text],
            outputs=outputs,
        )
        gr.Examples(
            examples=examples,
            inputs=input_text,
            outputs=outputs,
        )
    return demo
# Summary | |
# summary_interface = gr.Interface.from_pipeline( | |
# pipes["summarization"], | |
# title="Summarization", | |
# examples=details["summarization"]["examples"], | |
# description=details["summarization"]["description"], | |
# allow_flagging="never", | |
# ) | |