File size: 712 Bytes
c09ea67
9fbf14c
38d4932
 
 
 
c09ea67
38d4932
 
 
52e8dae
 
 
38d4932
52e8dae
38d4932
52e8dae
f892f46
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from functools import lru_cache
from typing import Any, Tuple

import gradio as gr
from transformers import pipeline, LongformerForSequenceClassification, LongformerTokenizer, Trainer


@lru_cache(maxsize=1)
def _get_classifier():
    """Build the Longformer classification pipeline once and reuse it.

    Loading the fine-tuned weights from the local ``model`` directory and the
    ``allenai/longformer-base-4096`` tokenizer is expensive. The original code
    repeated it on every request; caching the pipeline keeps it in memory after
    the first call.
    """
    model = LongformerForSequenceClassification.from_pretrained("model")
    tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
    return pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)


# Weight applied to each label's score when computing the overall stance score:
# Dovish pulls toward -100, Hawkish toward +100, Neutral contributes nothing.
_LABEL_WEIGHTS = {"Dovish": -100, "Neutral": 0, "Hawkish": 100}


def predict_fn(text: str) -> Tuple[Any, Any]:
    """Classify ``text`` and compute a weighted Dovish/Hawkish score.

    Args:
        text: Input text to classify.

    Returns:
        A pair ``(label, score)``:

        - ``label`` is the first (top-ranked) label returned by the pipeline.
        - ``score`` is the sum of ``label_score * weight`` over the top 3
          labels, rounded to the nearest whole number (a float, as produced
          by ``round(x, 0)``).

    Raises:
        ValueError: If the model emits a label with no entry in
            ``_LABEL_WEIGHTS``. For example, a config without custom
            ``id2label`` names produces labels like ``LABEL_0``.
    """
    results = _get_classifier()(text, top_k=3)

    score = 0.0
    for entry in results:
        label = entry["label"]
        try:
            weight = _LABEL_WEIGHTS[label]
        except KeyError:
            # Previously .get() returned None here and the multiply failed with
            # an opaque TypeError; name the offending label instead.
            raise ValueError(
                f"unexpected label {label!r}; expected one of {sorted(_LABEL_WEIGHTS)}"
            ) from None
        score += entry["score"] * weight

    return results[0]["label"], round(score, 0)


# Gradio UI: a text box in; the predicted label and the weighted score out.
# The interface is kept in a module-level ``demo`` and only launched when run as
# a script, so importing this module does not start a blocking web server.
demo = gr.Interface(predict_fn, "textbox", ["label", "label"])

if __name__ == "__main__":
    demo.launch()