"""Gradio front-end for an asynchronous Stable Diffusion txt2img job API.

The user pastes a JSON request body; the app submits it to
``{SD_API_BASE}/txt2img/run/``, polls ``{SD_API_BASE}/txt2img/status/{id}``
until the job completes, and shows the returned base64 images in a gallery.
"""
import base64
import io
import json
import os
import time

import gradio as gr
import requests
from PIL import Image

# Base URL and bearer token of the job API. Read at import time, so a missing
# variable fails fast with KeyError before the UI is started.
sd_api_base = os.environ["SD_API_BASE"]
sd_api_key = os.environ["SD_API_KEY"]

# Per-request HTTP timeout in seconds, so a stalled server cannot hang a
# Gradio worker indefinitely.
REQUEST_TIMEOUT = 30

# Job statuses treated as terminal failures while polling.
# NOTE(review): the original code only shows 'COMPLETED'; these names assume a
# RunPod-style serverless job API — confirm against the backend's docs.
FAILED_STATUSES = frozenset({"FAILED", "CANCELLED", "TIMED_OUT"})


def _auth_headers():
    """Return the JSON + bearer-token headers shared by every API call."""
    return {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {sd_api_key}',
    }


def send_post_request(input_json_string):
    """Parse ``input_json_string`` and submit it to the txt2img ``run`` endpoint.

    Args:
        input_json_string: Raw text from the UI, sent as the request body
            after JSON decoding.

    Returns:
        The decoded JSON body of the HTTP 200 response (the caller reads the
        job ``id`` from it).

    Raises:
        ValueError: If the input is not valid JSON. (Previously an error
            string was returned, which made the caller fail with an unrelated
            TypeError when indexing ``['id']``.)
        RuntimeError: If the server answers with a non-200 status.
    """
    try:
        data = json.loads(input_json_string)
    except json.JSONDecodeError as e:
        raise ValueError(f"输入的字符串不是有效的JSON格式: {e}") from e

    url = f"{sd_api_base}/txt2img/run/"
    response = requests.post(
        url, headers=_auth_headers(), json=data, timeout=REQUEST_TIMEOUT
    )
    if response.status_code != 200:
        raise RuntimeError(f"Error in POST request: {response.text}")
    return response.json()


def poll_status(id, poll_interval=1.0, max_wait=600.0):
    """Poll the job status endpoint until the job with ``id`` completes.

    Args:
        id: Job identifier returned by the ``run`` endpoint. (The name shadows
            the builtin but is kept for keyword-argument compatibility.)
        poll_interval: Seconds to sleep between status requests.
        max_wait: Give up after roughly this many seconds; ``None`` waits
            forever (the original behavior).

    Returns:
        The decoded JSON status body once its ``status`` is ``'COMPLETED'``.

    Raises:
        RuntimeError: On a non-200 response, or when the job reports one of
            ``FAILED_STATUSES`` (the original looped forever in that case).
        TimeoutError: If ``max_wait`` elapses without completion.
    """
    url = f"{sd_api_base}/txt2img/status/{id}"
    headers = _auth_headers()
    deadline = None if max_wait is None else time.monotonic() + max_wait

    while True:
        response = requests.get(url, headers=headers, timeout=REQUEST_TIMEOUT)
        if response.status_code != 200:
            raise RuntimeError(f"Error in GET request: {response.text}")

        result = response.json()
        status = result.get('status')
        if status == 'COMPLETED':
            return result
        if status in FAILED_STATUSES:
            raise RuntimeError(f"Job {id} ended with status {status}: {result}")
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Job {id} did not complete within {max_wait} seconds "
                f"(last status: {status})"
            )
        time.sleep(poll_interval)


def display_images(output_json):
    """Decode the base64 images of a completed job into PIL images.

    Reads ``output_json['output']['images']`` and expects each entry to be a
    base64-encoded image file.

    Returns:
        A list of ``PIL.Image.Image`` objects, in the same order as the input.
    """
    images = []
    for base64_data in output_json['output']['images']:
        image = Image.open(io.BytesIO(base64.b64decode(base64_data)))
        # Image.open is lazy; decode now so corrupt data fails here with a
        # clear traceback rather than later inside Gradio's serialization.
        image.load()
        images.append(image)
    return images


def gradio_interface(input_json):
    """Run one txt2img job end-to-end and return its images for the gallery.

    Expected failures (bad JSON, HTTP errors, failed or timed-out jobs) are
    re-raised as ``gr.Error`` so the message is shown to the user in the UI.
    """
    try:
        post_response = send_post_request(input_json)
        print(post_response)
        status_response = poll_status(post_response['id'])
        return display_images(status_response)
    except (ValueError, RuntimeError, TimeoutError) as e:
        raise gr.Error(str(e)) from e


# Build the Gradio UI: one JSON text box in, an image gallery out.
# NOTE: to add examples for a single Textbox input, pass a list of rows, e.g.
#   examples=[['{"prompt": "a dog"}']]
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=2, placeholder="Type something here..."),
    outputs="gallery",
)

# Start the Gradio app (module-level, as in the original script).
iface.launch()