import os
import json

import requests
import gradio as gr
from gradio import inputs, outputs

# Hugging Face Inference API endpoints for the two fine-tuned T5 paraphrasers.
# Index 0 backs the "quora-paws" panel, index 1 the "paws-msrp-opinosis" panel.
ENDPOINTS = tuple(
    f"https://api-inference.huggingface.co/models/ceshine/{model_name}"
    for model_name in (
        "t5-paraphrase-quora-paws",
        "t5-paraphrase-paws-msrp-opinosis",
    )
)


def _generation_params(temperature: float) -> dict:
    """Return the ``parameters`` payload sent to the Inference API.

    A positive *temperature* enables top-k sampling at that temperature;
    zero or a negative value falls back to deterministic beam search.
    Both modes ask for 10 candidate sequences with ``max_length`` 100.
    """
    if temperature > 0:
        return {
            "do_sample": True,
            "temperature": temperature,
            "top_k": 5,
            "num_return_sequences": 10,
            "max_length": 100,
        }
    return {"num_beams": 10, "num_return_sequences": 10, "max_length": 100}


def _format_paraphrases(results, source_text: str, limit: int = 3) -> str:
    """Render up to *limit* distinct generated texts as a numbered list.

    Args:
        results: Decoded API response; an iterable of dicts, each carrying a
            ``"generated_text"`` key.
        source_text: The user's original input. Candidates equal to it are
            dropped. Duplicate candidates are also dropped. Both comparisons
            ignore case and surrounding whitespace.
        limit: Maximum number of paraphrases to include.

    Returns:
        ``"1:  <text>\\n\\n2:  <text>\\n\\n..."``. The result is ``""`` when
        nothing survives the filtering.
    """
    # Normalize both sides the same way, so a candidate that differs from the
    # source (or from an earlier candidate) only by case or padding is not
    # shown as a "new" paraphrase.
    seen = {source_text.lower().strip()}
    candidates = []
    for item in results:
        if len(candidates) >= limit:
            break
        text = item["generated_text"]
        key = text.lower().strip()
        if key in seen:
            continue
        seen.add(key)
        candidates.append(text)
    return "".join(f"{i}:  {text}\n\n" for i, text in enumerate(candidates, start=1))


def get_fn(endpoint, timeout=60):
    """Build a Gradio callback that paraphrases text with the model at *endpoint*.

    Args:
        endpoint: URL of a Hugging Face Inference API model endpoint.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout, ``requests`` can block forever on a stalled
            connection.

    Returns:
        A ``paraphrase(source_text, temperature)`` function. It returns a
        human-readable string holding either the numbered paraphrases or an
        error message. It reads the API token from the ``TOKEN`` environment
        variable on every call, and raises ``KeyError`` if that is unset.
    """

    def paraphrase(source_text: str, temperature: float) -> str:
        headers = {"Authorization": f"Bearer {os.environ['TOKEN']}"}
        payload = {
            "inputs": "paraphrase: " + source_text,
            "parameters": _generation_params(temperature),
        }
        try:
            res = requests.post(endpoint, headers=headers, json=payload, timeout=timeout)
        except requests.RequestException as err:
            # Surface network failures in the UI instead of crashing the callback.
            return f"Request to HuggingFace failed: {err}"
        if res.status_code != 200:
            return f"Got a {res.status_code} status code from HuggingFace."
        return _format_paraphrases(res.json(), source_text)

    return paraphrase


def _build_interface(endpoint, name):
    """Create one paraphrasing sub-interface backed by *endpoint*.

    Both models take the same inputs: the source text and a sampling
    temperature. *name* is used as the interface title and as the label of
    its output textbox.
    """
    return gr.Interface(
        fn=get_fn(endpoint),
        title=name,
        inputs=[
            inputs.Textbox(label="Source text"),
            inputs.Number(
                default=0.0, label="Temperature (0 -> disable sampling and use beam search)"
            ),
        ],
        outputs=outputs.Textbox(label=name),
    )


interface_1 = _build_interface(ENDPOINTS[0], "quora-paws")
interface_2 = _build_interface(ENDPOINTS[1], "paws-msrp-opinosis")

# Example rows: (source text, temperature). The 0 row exercises beam search;
# the 1.2 row exercises sampling.
_examples = [
    ["I bought a ticket from London to New York.", 0],
    ["Weh Seun spends 14 hours a week doing housework.", 1.2],
]

# Show both models side by side on the same inputs, then start the server.
demo = gr.Parallel(
    interface_1,
    interface_2,
    title="T5 Sentence Paraphraser",
    description=(
        "Compare generated paraphrases from two models "
        "(`ceshine/t5-paraphrase-quora-paws` and "
        "`ceshine/t5-paraphrase-paws-msrp-opinosis`)."
    ),
    examples=_examples,
)
demo.launch(enable_queue=True)