import inspect
import os
import re
from typing import List, Type

import gradio as gr
import requests  # NOTE(review): not used in this file; kept in case something relies on it.
from gradio import routes
from gradio_client import Client as GrClient
from googletrans import Translator

# URL (or Space id) of the upstream Gradio app that performs the actual work;
# every handler below forwards its request to it via `gradio_client.predict`.
_client_url = os.environ.get('GrClient_url')
if not _client_url:
    # Fail fast with an actionable message instead of an opaque error from
    # GrClient(None).
    raise RuntimeError(
        "Environment variable 'GrClient_url' must be set to the upstream "
        "Gradio app URL or Space id."
    )
gradio_client = GrClient(_client_url)
translator = Translator()

# Punctuation stripped from the translated text before it is sent to the MBTI
# endpoint. Written as a raw string: the previous non-raw literal relied on
# invalid escape sequences (\?, \(, \[, ...) that Python 3.12+ warns about.
# The matched character set is unchanged.
_PUNCTUATION_RE = re.compile(r'[-=+,#/?:^.@*"※~ㆍ!』‘|()\[\]`\'…》”“’·]')


# Monkey patch: replace gradio.routes.get_types with a version that parses
# component class docstrings.
# NOTE(review): targets Gradio 3.0-era internals; verify against the
# installed gradio version.
def get_types(cls_set: List[Type], component: str):
    """Extract a description and a type name from each component's docstring.

    For ``component == "input"`` the second docstring line is parsed; for any
    other value the last line is parsed. From that line, the text after the
    last ':' is taken as the description and the text inside the parentheses
    that close at the first ')' is taken as the type name.

    Args:
        cls_set: Component classes whose docstrings are parsed.
        component: "input" to read the second docstring line; anything else
            reads the last line.

    Returns:
        A tuple ``(docset, types)`` of two parallel lists of strings.
    """
    line_index = 1 if component == "input" else -1
    docset = []
    types = []
    for cls in cls_set:
        doc_line = inspect.getdoc(cls).split("\n")[line_index]
        docset.append(doc_line.split(":")[-1])
        types.append(doc_line.split(")")[0].split("(")[-1])
    return docset, types


routes.get_types = get_types


# App code
def mbti(x):
    """Translate Korean text to English, strip punctuation, and query the MBTI endpoint.

    Args:
        x: Korean input text from the textbox.

    Returns:
        Whatever the upstream app returns for ``fn_index=2``.
    """
    translated = translator.translate(x, src='ko', dest='en')
    # translate() returns a Translated object; the English string is `.text`.
    # NOTE(review): assumes the synchronous googletrans API (e.g. 4.0.0rc1);
    # newer googletrans releases make translate() a coroutine.
    str_trans = _PUNCTUATION_RE.sub('', translated.text)
    result = gradio_client.predict(
        str_trans,  # str representing input in 'User input' Textbox component
        fn_index=2,
    )
    return result


def chat(x):
    """Forward *x* to the upstream chat endpoint (``fn_index=0``) with fixed sampling settings."""
    result = gradio_client.predict(
        x,    # str representing input in 'User input' Textbox component
        0.9,  # float, representing input in 'Top-p (nucleus sampling)' Slider component
        50,   # int, representing input in 'Top-k (nucleus sampling)' Slider component
        0.7,  # float, representing input in 'Temperature' Slider component
        25,   # int, representing input in 'Max New Tokens' Slider component
        1.2,  # float, representing input in 'repetition_penalty' Slider component
        fn_index=0,
    )
    return result


def yn(x):
    """Forward *x* to the upstream endpoint at ``fn_index=1`` and return its result."""
    result = gradio_client.predict(
        x,  # str representing input in 'User input' Textbox component
        fn_index=1,
    )
    return result


with gr.Blocks() as blk:
    gr.Markdown("# Gradio Blocks (3.0) with REST API")
    t = gr.Textbox()
    c = gr.Button("mbti")
    b = gr.Button("chat")
    a = gr.Button("yn")
    o = gr.Textbox()
    # Every button reads the shared input box `t` and writes into `o`.
    c.click(mbti, inputs=[t], outputs=[o])
    b.click(chat, inputs=[t], outputs=[o])
    a.click(yn, inputs=[t], outputs=[o])
    gr.Markdown("""
## API

Can select which function to use by passing in `fn_index`:

```python
import requests

requests.post(
    url="https://hf.space/embed/versae/gradio-blocks-rest-api/+/api/predict/", json={"data": ["Jessie"], "fn_index": 0}
).json()

requests.post(
    url="https://hf.space/embed/versae/gradio-blocks-rest-api/+/api/predict/", json={"data": ["Jessie"], "fn_index": 1}
).json()
```

Or using cURL

```
$ curl -X POST https://hf.space/embed/versae/gradio-blocks-rest-api/+/api/predict/ -H 'Content-Type: application/json' -d '{"data": ["Jessie"], "fn_index": 0}'
$ curl -X POST https://hf.space/embed/versae/gradio-blocks-rest-api/+/api/predict/ -H 'Content-Type: application/json' -d '{"data": ["Jessie"], "fn_index": 1}'
```""")

# Copy Interface-only attributes onto the Blocks app.
# NOTE(review): presumably required by the Gradio 3.0 API/REST route, which
# expected an Interface; verify this hack is still needed on newer gradio.
ifa = gr.Interface(lambda: None, inputs=[t], outputs=[o])
blk.input_components = ifa.input_components
blk.output_components = ifa.output_components
blk.examples = None
blk.predict_durations = []

bapp = blk.launch()