File size: 6,492 Bytes
b51a719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import argparse
import json
import random
import time
import requests
import base64
from io import BytesIO

def get_image_as_base64(url, *, timeout=30):
    """Download the resource at *url* and return its body base64-encoded.

    Args:
        url: Absolute URL to fetch with an HTTP GET.
        timeout: Seconds to wait for the server (connect and read) before
            giving up. ``None`` waits indefinitely. Without a timeout
            ``requests`` can block forever on an unresponsive server.

    Returns:
        The response body as a base64 ``str``, or ``None`` if the request
        failed (network error, timeout, or HTTP error status). The failure
        is printed.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as ex:
        print(f'Failed to retrieve image: {ex}')
        return None
    # response.content is already bytes, so it can be encoded directly.
    return base64.b64encode(response.content).decode('utf-8')
    
def queue_prompt(url, prompt, *, timeout=30):
    """Submit a workflow to the server's ``/prompt`` endpoint.

    Args:
        url: Base server URL, e.g. ``"http://host:port"``.
        prompt: Workflow dict. It is sent as ``{"prompt": prompt}`` in a
            UTF-8 JSON body.
        timeout: Seconds to wait for the server before giving up. ``None``
            waits indefinitely.

    Returns:
        The decoded JSON response, or ``None`` on failure. Failures are
        network errors, timeouts, HTTP error statuses, or a body that is not
        valid JSON; the failure is printed.
    """
    prompt_url = f"{url}/prompt"
    data = json.dumps({"prompt": prompt}).encode('utf-8')
    try:
        r = requests.post(prompt_url, data=data, timeout=timeout)
        r.raise_for_status()
        return r.json()
    except (requests.exceptions.RequestException, ValueError) as ex:
        # Older requests versions raise a plain ValueError from r.json() on a
        # non-JSON body. Newer versions raise a subclass of both.
        print(f'POST {prompt_url} failed: {ex}')
        return None

def get_queue(url, *, timeout=30):
    """Fetch the server's current queue from ``/queue``.

    Args:
        url: Base server URL, e.g. ``"http://host:port"``.
        timeout: Seconds to wait for the server before giving up. ``None``
            waits indefinitely.

    Returns:
        The decoded JSON response, or ``None`` on failure. Failures are
        network errors, timeouts, HTTP error statuses, or a body that is not
        valid JSON; the failure is printed.
    """
    queue_url = f"{url}/queue"
    try:
        r = requests.get(queue_url, timeout=timeout)
        r.raise_for_status()
        return r.json()
    except (requests.exceptions.RequestException, ValueError) as ex:
        # ValueError covers r.json() on a non-JSON body in older requests.
        print(f'GET {queue_url} failed: {ex}')
        return None


def get_history(url, prompt_id, *, timeout=30):
    """Fetch the execution history for *prompt_id* from ``/history/<id>``.

    Args:
        url: Base server URL, e.g. ``"http://host:port"``.
        prompt_id: Identifier returned when the prompt was queued.
        timeout: Seconds to wait for the server before giving up. ``None``
            waits indefinitely.

    Returns:
        The decoded JSON response, or ``None`` on failure. Failures are
        network errors, timeouts, HTTP error statuses, or a body that is not
        valid JSON; the failure is printed.
    """
    history_url = f"{url}/history/{prompt_id}"
    try:
        r = requests.get(history_url, timeout=timeout)
        r.raise_for_status()
        return r.json()
    except (requests.exceptions.RequestException, ValueError) as ex:
        # ValueError covers r.json() on a non-JSON body in older requests.
        print(f'GET {history_url} failed: {ex}')
        return None


def _apply_overrides(prompt_text, prompt, steps, seed, cfg, width, height,
                     lora_name, lora_scale):
    """Write caller-supplied settings into the workflow dict, in place.

    The node ids are hard-coded and must match the workflow file:
    "6" text, "17" steps, "25" noise_seed, "26" guidance,
    "27" width/height, "30" lora_name/strength_model.

    Settings passed as ``None`` keep the file's value. The exception is the
    seed, which is always overwritten: with *seed* if given, otherwise with
    a fresh random value.
    """
    overrides = (
        ("6", "text", prompt),
        ("17", "steps", steps),
        ("26", "guidance", cfg),
        ("27", "width", width),
        ("27", "height", height),
        ("30", "lora_name", lora_name),
        ("30", "strength_model", lora_scale),
    )
    for node_id, field, value in overrides:
        if value is not None:
            prompt_text[node_id]["inputs"][field] = value
    if seed is None:
        seed = random.randint(0, 1000000000000000)
    prompt_text["25"]["inputs"]["noise_seed"] = seed


def _print_settings(prompt_text):
    """Print the effective workflow settings, one ``Label: value`` per line."""
    for label, node_id, field in (
        ("Prompt", "6", "text"),
        ("Steps", "17", "steps"),
        ("Seed", "25", "noise_seed"),
        ("CFG", "26", "guidance"),
        ("Width", "27", "width"),
        ("Height", "27", "height"),
        ("LoRA Name", "30", "lora_name"),
        ("LoRA Scale", "30", "strength_model"),
    ):
        print(f'{label}: {prompt_text[node_id]["inputs"][field]}')


def _wait_for_prompt(url, prompt_id, poll_interval):
    """Block until *prompt_id* is neither pending nor running on the server.

    Polls ``get_queue`` every *poll_interval* seconds and prints the prompt's
    queue position or running status. A failed poll is simply retried.
    NOTE(review): this has no upper bound, so it loops forever if the server
    stays unreachable.
    """
    while True:
        time.sleep(poll_interval)
        queue_response = get_queue(url)
        if queue_response is None:
            continue

        queue_pending = queue_response.get('queue_pending', [])
        queue_running = queue_response.get('queue_running', [])
        counts = (f'Queue running: {len(queue_running)}, '
                  f'Queue pending: {len(queue_pending)}')

        # Queue entries are indexed so that element [1] is the prompt id.
        pending_position = next(
            (pos for pos, item in enumerate(queue_pending, start=1)
             if item[1] == prompt_id),
            None,
        )
        is_running = any(item[1] == prompt_id for item in queue_running)

        if pending_position is not None:
            print(f'{counts}, Workflow is in position {pending_position} in the queue.')
        if is_running:
            print(f'{counts}, Workflow is currently running.')
        if pending_position is None and not is_running:
            return


def _first_output_filename(history_response, prompt_id):
    """Return the first image filename recorded for output node "9".

    Falls back to 'unknown.png' when any of these is missing: the prompt's
    history entry, its outputs, node "9", or node "9"'s image list. An empty
    or ``None`` image list also falls back, instead of raising
    IndexError/TypeError.
    """
    outputs = history_response.get(prompt_id, {}).get('outputs', {})
    images = outputs.get('9', {}).get('images') or [{}]
    return images[0].get('filename', 'unknown.png')


def main(ip, port, filepath, prompt=None, steps=None, seed=None, cfg=None,
         width=None, height=None, lora_name=None, lora_scale=None, *,
         poll_interval=5):
    """Queue a workflow on ``http://ip:port``, wait for it, then fetch its image.

    Args:
        ip, port: Server address.
        filepath: Path to the API-format workflow JSON file (read as UTF-8).
        prompt, steps, seed, cfg, width, height, lora_name, lora_scale:
            Optional overrides written into hard-coded workflow nodes.
            ``None`` keeps the file's value, except that an unset seed is
            randomized.
        poll_interval: Seconds to sleep between queue polls.

    Returns:
        ``{"output_url": str, "base64_image": str | None}`` once the prompt
        has left the queue. Returns ``None`` if the prompt could not be
        queued or its history could not be fetched.
    """
    url = f"http://{ip}:{port}"

    with open(filepath, 'r', encoding='utf-8') as file:
        prompt_text = json.load(file)

    _apply_overrides(prompt_text, prompt, steps, seed, cfg, width, height,
                     lora_name, lora_scale)
    _print_settings(prompt_text)

    response1 = queue_prompt(url, prompt_text)
    # Guard against a response body without a prompt id instead of raising KeyError.
    prompt_id = response1.get('prompt_id') if isinstance(response1, dict) else None
    if prompt_id is None:
        print("Failed to queue the prompt.")
        return None

    print(f'Prompt ID: {prompt_id}')
    print('-' * 20)

    _wait_for_prompt(url, prompt_id, poll_interval)

    history_response = get_history(url, prompt_id)
    if history_response is None:
        print("Failed to retrieve history.")
        return None

    filename = _first_output_filename(history_response, prompt_id)
    # NOTE(review): stock ComfyUI serves results via
    # /view?filename=...&type=output. This URL assumes the deployment also
    # exposes the output directory at /output/; verify.
    output_url = f"{url}/output/{filename}"
    print(f"Output URL: {output_url}")

    base64_image = get_image_as_base64(output_url)
    if base64_image:
        print("Base64 encoded image:")
        print(base64_image)
    else:
        print("Failed to retrieve base64 encoded image.")

    return {
        "output_url": output_url,
        "base64_image": base64_image
    }


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description='Add a prompt to the queue and wait for the output.')

    # (flag, type, required, help) for every command-line option.
    option_specs = (
        ('--ip', str, True, 'The public IP address of the pod'),
        ('--port', int, True, 'The external port of the pod'),
        ('--filepath', str, True, 'The path to the JSON file containing the workflow in api format'),
        ('--prompt', str, False, 'The prompt to use for the workflow'),
        ('--steps', int, False, 'Number of steps for the sampler'),
        ('--seed', int, False, 'Seed for the noise generator'),
        ('--cfg', float, False, 'Classifier-free guidance scale'),
        ('--width', int, False, 'Width of the output image'),
        ('--height', int, False, 'Height of the output image'),
        ('--lora_name', str, False, 'Name of the LoRA to use'),
        ('--lora_scale', float, False, 'Scale of the LoRA effect'),
    )
    for flag, value_type, is_required, help_text in option_specs:
        cli.add_argument(flag, type=value_type, required=is_required, help=help_text)

    opts = cli.parse_args()
    outcome = main(
        opts.ip, opts.port, opts.filepath,
        prompt=opts.prompt, steps=opts.steps, seed=opts.seed, cfg=opts.cfg,
        width=opts.width, height=opts.height,
        lora_name=opts.lora_name, lora_scale=opts.lora_scale,
    )

    # Persist the encoded image next to the script when one was retrieved.
    encoded = outcome["base64_image"] if outcome else None
    if encoded:
        with open("output_image.txt", "w") as out_file:
            out_file.write(encoded)
        print("Base64 image saved to output_image.txt")