import json

import gradio as gr
from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks.argillalabeller import ArgillaLabeller

# Shared LLM client used by the labelling task. NOTE(review): generation is
# capped at 4000 new tokens (1000 * 4) per call.
llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    generation_kwargs={"max_new_tokens": 1000 * 4},
)
# Single module-level task instance reused by every Gradio request.
task = ArgillaLabeller(llm=llm)
# Presumably initializes prompts/resources before process() is called —
# confirm against distilabel's Task.load() documentation.
task.load()


def load_examples():
    """Load the Gradio example rows from ``examples.json`` in the working directory.

    Returns:
        The parsed JSON content (a list of example input rows for gr.Interface).

    Raises:
        FileNotFoundError: if ``examples.json`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Explicit encoding so the JSON file decodes identically on every platform
    # (the default encoding is locale-dependent).
    with open("examples.json", "r", encoding="utf-8") as f:
        return json.load(f)


# Example rows for the Gradio interface, loaded once at import time.
examples = load_examples()


def process_fields(fields):
    """Normalize *fields* into a list of dicts.

    Accepts a JSON string, a single dict, or a list whose elements are
    dicts or JSON strings, and always returns a list of dicts.
    """
    # A JSON string may encode either a single field or a list of fields.
    if isinstance(fields, str):
        fields = json.loads(fields)
    # A lone field dict is wrapped so the return type is uniform.
    if isinstance(fields, dict):
        fields = [fields]
    normalized = []
    for entry in fields:
        if isinstance(entry, dict):
            normalized.append(entry)
        else:
            normalized.append(json.loads(entry))
    return normalized


def process_records_gradio(records, example_records=None, fields=None, question=None):
    """Run the ArgillaLabeller task over *records* and return suggestions as JSON.

    The parameter order matches the order of the Gradio input components
    (records, example records, fields, question): gr.Interface passes component
    values to the function positionally, so the previous order
    ``(records, fields, question, example_records)`` routed the "Example
    Records" box into ``fields`` and the "Fields" box into ``question``.
    Keyword names are unchanged, so ``client.predict(records=..., ...)`` calls
    keep working.

    Args:
        records: list of Argilla record dicts, or a JSON string encoding one.
        example_records: optional few-shot example records (list or JSON string).
        fields: optional field definitions (list/dict or JSON string).
        question: optional question definition (dict or JSON string).

    Returns:
        A JSON string of the form ``{"results": [<suggestion>, ...]}``.

    Raises:
        gr.Error: wraps any failure so Gradio surfaces it to the user.
    """
    try:
        # The gr.Code inputs deliver strings; accept parsed objects too.
        if isinstance(records, str) and records:
            records = json.loads(records)
        if isinstance(example_records, str) and example_records:
            example_records = json.loads(example_records)
        if isinstance(fields, str) and fields:
            fields = json.loads(fields)
        if isinstance(question, str) and question:
            question = json.loads(question)

        if not fields and not question:
            raise Exception("Error: Either fields or question must be provided")

        runtime_parameters = {"fields": fields, "question": question}
        if example_records:
            runtime_parameters["example_records"] = example_records

        task.set_runtime_parameters(runtime_parameters)

        results = []
        # NOTE(review): assumes task.process() is a generator yielding one
        # batch (a list with one entry) per next() call — confirm against
        # distilabel's Task.process semantics.
        output = task.process(inputs=[{"record": record} for record in records])
        for _ in range(len(records)):
            entry = next(output)[0]
            if entry["suggestions"]:
                results.append(entry["suggestions"])

        return json.dumps({"results": results}, indent=2)
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}")


# Usage example shown on the Gradio page. The payload now defines
# "example_records" so the client.predict() call below no longer raises
# KeyError on payload["example_records"].
description = """
An example workflow for JSON payload.

```python
import json
import os
from gradio_client import Client

import argilla as rg

# Initialize Argilla client
client = rg.Argilla(
    api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
)

# Load the dataset
dataset = client.datasets(name="my_dataset", workspace="my_workspace")

# Prepare example data
example_field = dataset.settings.fields["my_input_field"].serialize()
example_question = dataset.settings.questions["my_question_to_predict"].serialize()

payload = {
    "records": [next(dataset.records()).to_dict()],
    "fields": [example_field],
    "question": example_question,
    "example_records": [],  # optional few-shot labelled records
}

# Use gradio client to process the data
client = Client("davidberenstein1957/distilabel-argilla-labeller")

result = client.predict(
    records=json.dumps(payload["records"]),
    example_records=json.dumps(payload["example_records"]),
    fields=json.dumps(payload["fields"]),
    question=json.dumps(payload["question"]),
    api_name="/predict"
)

```
"""

# NOTE(review): gr.Interface passes component values to fn *positionally* in
# the order of this inputs list (records, example records, fields, question) —
# verify that process_records_gradio's parameter order agrees.
interface = gr.Interface(
    fn=process_records_gradio,
    inputs=[
        gr.Code(label="Records (JSON)", language="json", lines=5),
        gr.Code(label="Example Records (JSON, optional)", language="json", lines=5),
        gr.Code(label="Fields (JSON, optional)", language="json"),
        gr.Code(label="Question (JSON, optional)", language="json"),
    ],
    examples=examples,  # rows loaded from examples.json at import time
    cache_examples=True,  # caches outputs for the example rows
    outputs=gr.Code(label="Suggestions", language="json", lines=10),
    title="Distilabel - ArgillaLabeller - Record Processing Interface",
    description=description,
)

if __name__ == "__main__":
    interface.launch()