#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Enhanced Gradio UI for the Salesforce/codet5-large model using the Hugging Face Inference API.
Adheres to best practices, PEP8, flake8, and the Zen of Python.
"""

import gradio as gr

MODEL_ID = "Salesforce/codet5-large"


def prepare_payload(prompt: str, max_tokens: int) -> dict:
    """
    Build the request payload for the raw Hugging Face Inference API.

    Args:
        prompt (str): The input code containing the `<extra_id_0>` mask token.
        max_tokens (int): Maximum number of tokens to generate.

    Returns:
        dict: JSON-serializable payload in the ``{"inputs": ...,
            "parameters": ...}`` shape the HTTP endpoint expects.
    """
    return {"inputs": prompt, "parameters": {"max_length": max_tokens}}


def extract_generated_text(api_response) -> str:
    """
    Extract the generated text from a model response.

    Args:
        api_response: The raw response. The HTTP Inference API returns a
            list of dicts such as ``[{"generated_text": "..."}]``, while a
            loaded Gradio interface typically returns a plain string.

    Returns:
        str: The generated text, or the string form of the response.
    """
    if isinstance(api_response, list) and api_response:
        api_response = api_response[0]
    if isinstance(api_response, dict):
        return api_response.get("generated_text", str(api_response))
    return str(api_response)
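
# Illustrative behavior for the response shapes described above:
#   extract_generated_text([{"generated_text": "world"}])  -> "world"
#   extract_generated_text("world")                        -> "world"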


def main():
    with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as demo:
        with gr.Sidebar():
            gr.Markdown("## 🤖 Inference Provider")
            gr.Markdown(
                f"This Space showcases the `{MODEL_ID}` model, served via "
                "the Hugging Face Inference API.\n\n"
                "Sign in with your Hugging Face account to access the model."
            )
            login_button = gr.LoginButton("🔐 Sign in")
            gr.Markdown("---")
            gr.Markdown(f"**Model:** `{MODEL_ID}`")
            gr.Markdown("[📄 View Model Card](https://huggingface.co/Salesforce/codet5-large)")

        gr.Markdown("# 🧠 CodeT5 Inference UI")
        gr.Markdown("Enter your Python code snippet with `<extra_id_0>` as the mask token.")

        with gr.Row():
            with gr.Column(scale=1):
                code_input = gr.Code(
                    label="Input Code",
                    language="python",
                    value="def greet(user): print(f'hello <extra_id_0>!')",
                    lines=10,
                )
                max_tokens = gr.Slider(
                    minimum=8, maximum=128, value=32, step=8, label="Max Tokens"
                )
                submit_btn = gr.Button("🚀 Run Inference")
            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Inference Output",
                    lines=10,
                    interactive=False,
                    placeholder="Model output will appear here...",
                )

        # Load the model from the Hugging Face Inference API. The object
        # returned by gr.load can also be called like a plain function
        # (documented Gradio behavior), which is how the button below
        # runs predictions.
        model_iface = gr.load(
            f"models/{MODEL_ID}",
            accept_token=login_button,
            provider="hf-inference",
        )

        def run_inference(prompt: str, max_tokens: int) -> str:
            """Run the model on the prompt and clean up the response.

            A loaded ``models/...`` interface accepts only the raw prompt,
            so ``max_tokens`` is not forwarded here; see ``prepare_payload``
            for the shape used when calling the HTTP endpoint directly.
            """
            return extract_generated_text(model_iface(prompt))

        # Single click event: run the model, then display the result.
        submit_btn.click(
            fn=run_inference,
            inputs=[code_input, max_tokens],
            outputs=output_text,
            api_name="run_inference",
        )

    demo.launch()


if __name__ == "__main__":
    main()
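

# Local usage (assuming this file is saved as app.py):
#   python app.py    # start the server and open the printed local URL
#   gradio app.py    # same, with auto-reload while developing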