import json
import os
import shutil

import gradio as gr
from huggingface_hub import Repository
from text_generation import Client

from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css

# Deployment configuration taken from the environment; each is None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")
API_URL = os.getenv("API_URL")


# Special tokens that `generate` wraps around a fill-in-the-middle request.
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX = "<fim_prefix>", "<fim_middle>", "<fim_suffix>"

# Marker a user types into the prompt to request an infill at that position.
FIM_INDICATOR = "<FILL_HERE>"

# Markdown summary of the special-token prompt formats.
# Note: "\n" in the first example is a real escape sequence (this is not a raw
# string), so it renders as a line break rather than a literal backslash-n.
FORMATS = """## Model formats

### Prefixes
Any combination of the three:

```
<reponame>REPONAME<filename>FILENAME<gh_stars>STARS\nCode<eos>
```
Stars can be: 0, 1-10, 10-100, 100-1000, 1000+

### Commits

```
<commit_before>code<commit_msg>text<commit_after>code<|endoftext|>
```

### Jupyter structure

```
<start_jupyter><jupyter_text>text<jupyter_code>code<jupyter_output>output<jupyter_text>
```

### Fill-in-the-middle

```
code before<FILL_HERE>code after
```
"""

# UI theme: Gradio's Monochrome base with indigo/blue/slate hues, small corner
# radii, and an "Open Sans" Google font falling back to system sans-serif fonts.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)

# Streaming text-generation client for the endpoint configured in API_URL.
# NOTE(review): HF_TOKEN is not passed; an Authorization header
# (`Bearer {HF_TOKEN}`) was previously commented out here. Confirm whether the
# endpoint needs it.
client = Client(API_URL)

def generate(prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    """Stream a completion of ``prompt`` from the text-generation endpoint.

    If ``prompt`` contains ``FIM_INDICATOR``, the text before the marker is the
    prefix and the text after it is the suffix. The request is then sent in
    fill-in-the-middle form:
    ``FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE``.
    The displayed text grows from the prefix, and the suffix is re-attached
    once generation finishes.

    Args:
        prompt: Prompt text entered by the user.
        temperature: Sampling temperature; values below 1e-2 are raised to 1e-2.
        max_new_tokens: Maximum number of tokens to generate (coerced to int).
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated text after every streamed token. In fill-in-the-middle
        mode, one final value with the suffix appended is yielded as well.

    Returns:
        The final text, as the generator's ``StopIteration`` value.

    Raises:
        ValueError: If ``prompt`` contains ``FIM_INDICATOR`` more than once.
    """
    # The temperature slider allows 0.0. Clamp to a small positive floor,
    # presumably because sampling (do_sample=True) rejects non-positive values.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        # Fixed seed so repeated runs with the same settings are expected to sample alike.
        seed=42,
    )

    fim_mode = FIM_INDICATOR in prompt
    if fim_mode:
        # Split on exactly the marker tested for above. More than one marker
        # leaves the infill position ambiguous.
        parts = prompt.split(FIM_INDICATOR)
        if len(parts) != 2:
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
        prefix, suffix = parts
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
        # Show the user's own code around the infill, not the special FIM tokens.
        output = prefix
    else:
        output = prompt

    for response in client.generate_stream(prompt, **generate_kwargs):
        output += response.token.text
        yield output

    if fim_mode:
        # Re-attach the code after the fill point. It must be *yielded*: the UI
        # and process_example only consume yielded values, never the return value.
        output += suffix
        yield output
    return output


# Starter prompts offered through gr.Examples beneath the prompt box.
examples = [
    "def hello_world():",
    "def fibonacci(n):",
    "class TransformerDecoder(nn.Module):",
    "class ComplexNumbers:",
]


def process_example(args):
    """Run ``generate`` on an example prompt and return its final output.

    Used as the ``fn`` of ``gr.Examples``. It drains the streaming generator
    and keeps only the last yielded value.

    Args:
        args: The example prompt string.

    Returns:
        The last value yielded by ``generate``. If ``generate`` yields nothing
        (no tokens streamed), ``args`` is returned unchanged.
    """
    # Default to the prompt itself: with no streamed tokens, that is the text
    # generate starts from. Without a default, an empty stream would leave the
    # loop variable unbound and raise UnboundLocalError.
    output = args
    for output in generate(args):
        pass
    return output

# Page CSS: hide elements with the `generating` class (presumably Gradio's
# in-progress overlay), followed by the share-button styles from share_btn.
css = "".join((".generating {visibility: hidden}", share_btn_css))

# Playground UI. The left column holds the prompt box, streamed code output,
# share button and example prompts. The right column holds the sampling
# sliders. Both the button and Enter in the prompt box call `generate`.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column():
        gr.Markdown(
            """\
# BigCode - Playground

_Note:_ this is an internal playground - please do not share. The deployment can also change, and thus the space may not work as we continue development.\
"""
        )
        with gr.Row():
            with gr.Column(scale=3):
                instruction = gr.Textbox(placeholder="Enter your prompt here", label="Prompt", elem_id="q-input")
                submit = gr.Button("Generate", variant="primary")
                with gr.Box():
                    output = gr.Code(elem_id="q-output")
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(community_icon_html, visible=True)
                    loading_icon = gr.HTML(loading_icon_html, visible=True)
                    share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)
                # Examples are not pre-computed (cache_examples=False); process_example
                # runs generate to completion when an example is executed.
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )

            # Sampling parameters; slider order must match generate's signature.
            with gr.Column(scale=1):
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.2,
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=256,
                    minimum=0,
                    maximum=4096,
                    step=4,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.90,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty",
                    value=1.2,
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    interactive=True,
                    info="Penalize repeated tokens",
                )

    # One shared input list keeps both triggers in the positional order of
    # generate(prompt, temperature, max_new_tokens, top_p, repetition_penalty).
    generate_inputs = [instruction, temperature, max_new_tokens, top_p, repetition_penalty]
    submit.click(generate, inputs=generate_inputs, outputs=[output])
    instruction.submit(generate, inputs=generate_inputs, outputs=[output])
    # No Python handler: the share button only runs `share_js` in the browser.
    share_button.click(None, [], [], _js=share_js)

# Enable the request queue with 16 concurrent workers, then start the server.
# Streaming generator handlers like `generate` rely on the queue in Gradio 3.x.
demo.queue(concurrency_count=16).launch(debug=True)