File size: 2,496 Bytes
894b24d
 
 
 
 
 
 
 
 
 
ed49033
1033026
894b24d
 
 
 
 
 
 
625b3d8
 
894b24d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
080ce23
894b24d
 
 
 
 
 
 
 
d3f1526
 
894b24d
d3f1526
d4df546
72fee02
 
 
 
57aa9a4
72fee02
953205d
1033026
72fee02
 
953205d
ed49033
 
57aa9a4
 
 
 
 
 
 
16e6449
72fee02
625b3d8
cc4118d
d4df546
52ded96
894b24d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import functools
import json
import os
import re

import gradio as gr
import huggingface_hub
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import pipeline
# from dotenv import load_dotenv, find_dotenv

from evaluate_data import store_sample_data, get_metrics_trf
from metrics import calc_metrics

# Authenticate with the Hugging Face Hub at import time.
# os.environ[...] raises KeyError if HF_TOKEN is not set in the environment.
hf_token= os.environ['HF_TOKEN']
huggingface_hub.login(hf_token)

# Token-classification pipeline backed by the "elshehawy/finer-ord-transformers"
# checkpoint (presumably fine-tuned on FiNER-ORD, judging by the name — confirm).
# aggregation_strategy="first": see the transformers TokenClassificationPipeline docs.
# NOTE(review): `pipe` is not referenced anywhere else in this file; transformer
# metrics come from get_metrics_trf instead. Confirm whether this eager model
# download at startup is still needed.
pipe = pipeline("token-classification", model="elshehawy/finer-ord-transformers", aggregation_strategy="first")


# Default chat model for get_completion; toggle the comments to switch models.
# llm_model = 'gpt-3.5-turbo-0125'
llm_model = 'gpt-4-0125-preview'
# openai.api_key = os.environ['OPENAI_API_KEY']

# OpenAI client shared by get_completion.
# .get() yields None when OPENAI_API_KEY is unset; the client will then
# presumably fail on construction or first request — verify.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)


def get_completion(prompt, model=llm_model):
    """Send ``prompt`` as one user message and return the model's reply text.

    Args:
        prompt: Full text of the user message.
        model: Chat model name; defaults to the module-level ``llm_model``.

    Returns:
        The ``content`` of the first choice in the chat completion response.
    """
    # A single-turn conversation: just the user's prompt, no system message.
    conversation = [dict(role="user", content=prompt)]
    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        # temperature 0 minimizes sampling randomness for evaluation runs.
        temperature=0,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


# Matches a reply wrapped in a Markdown code fence such as ```json\n[...]\n```
# (or ```[...]``` on one line) and captures only the payload between fences.
_CODE_FENCE_RE = re.compile(r"^```[A-Za-z0-9_-]*\s*(.*?)\s*```$", re.DOTALL)


def _parse_org_list(raw):
    """Decode a chat-model reply into a Python list of organization names.

    Args:
        raw: Reply text from the model. May be ``None`` or wrapped in a
            Markdown code fence even though the prompt asks for bare JSON.

    Returns:
        The decoded JSON list.

    Raises:
        json.JSONDecodeError: If the (unfenced) reply is not valid JSON.
        ValueError: If the reply decodes to something other than a list.
    """
    text = (raw or "").strip()
    # Chat models often ignore "output only a list" and add ``` fences;
    # json.loads would reject those, so unwrap them first.
    fenced = _CODE_FENCE_RE.match(text)
    if fenced:
        text = fenced.group(1)
    orgs = json.loads(text)
    if not isinstance(orgs, list):
        raise ValueError(f"Expected a JSON list of organizations, got: {raw!r}")
    return orgs


def find_orgs_gpt(sentence):
    """Ask the chat model for all organizations mentioned in ``sentence``.

    Args:
        sentence: Text to run NER over.

    Returns:
        The list of organization strings the model reported ([] if none).

    Raises:
        json.JSONDecodeError: If the model's reply is not valid JSON.
        ValueError: If the reply is valid JSON but not a list.
    """
    prompt = f"""
    In context of named entity recognition (NER), find all organizations in the text delimited by triple backticks.
    
    text:
    ```
    {sentence}
    ```
    You should output only a list of organizations and follow this output format exactly: ["org_1", "org_2", "org_3"], or [] if there are no organizations.
    """

    sent_orgs_str = get_completion(prompt)
    return _parse_org_list(sent_orgs_str)


# Sample input sentence, apparently kept for manual experimentation.
# NOTE(review): not referenced anywhere else in this file — confirm it is still needed.
example = """
My latest exclusive for The Hill : Conservative frustration over Republican efforts to force a House vote on reauthorizing the Export - Import Bank boiled over Wednesday during a contentious GOP meeting.

"""
@functools.lru_cache(maxsize=1)
def _get_sim_model():
    """Return the sentence-embedding model used when scoring GPT predictions.

    Cached so the model is loaded once per process rather than on every upload.
    """
    return SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')


def find_orgs(uploaded_file):
    """Evaluate GPT and the transformer model on an uploaded JSON dataset.

    Args:
        uploaded_file: Raw contents of the uploaded JSON file (presumably
            ``bytes``, given the upload button's ``type='binary'``;
            ``json.loads`` accepts both ``str`` and ``bytes``).

    Returns:
        dict with ``'gpt'`` -> metrics from ``calc_metrics`` comparing GPT
        predictions against the gold ``'orgs'`` of each sample, and
        ``'trf'`` -> metrics from ``get_metrics_trf`` on the full upload.
    """
    uploaded_data = json.loads(uploaded_file)
    sample_data = store_sample_data(uploaded_data)

    # Single pass: store_sample_data's return type isn't visible here, so
    # don't assume it can be iterated more than once.
    gpt_orgs, true_orgs = [], []
    for sent in tqdm(sample_data):
        gpt_orgs.append(find_orgs_gpt(sent['text']))
        true_orgs.append(sent['orgs'])

    # threshold=0.85 is forwarded to calc_metrics (assumed: minimum embedding
    # similarity for a predicted org to match a gold org — confirm in metrics.py).
    all_metrics = {
        'gpt': calc_metrics(true_orgs, gpt_orgs, _get_sim_model(), threshold=0.85),
    }
    all_metrics['trf'] = get_metrics_trf(uploaded_data)

    print(all_metrics)
    return all_metrics

# Gradio UI: one upload button whose file contents are passed to find_orgs;
# the returned metrics are shown in a text output. share=True asks Gradio
# for a public share link when launching.
upload_btn = gr.UploadButton(type='binary', label='Upload a json file.')

iface = gr.Interface(
    fn=find_orgs,
    inputs=upload_btn,
    outputs="text",
)
iface.launch(share=True)