File size: 2,456 Bytes
89ded21
 
 
 
 
 
 
 
 
 
 
 
 
 
ce11ffc
89ded21
ce11ffc
89ded21
ce11ffc
 
 
89ded21
f20ab91
ce11ffc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89ded21
 
 
 
 
 
 
 
 
 
 
ce11ffc
 
 
 
 
 
 
 
 
 
89ded21
ce11ffc
 
89ded21
ce11ffc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright 2024-present, David Berenstein, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import io
import os
import random
import time

import requests
from PIL import Image

from dataset_viber import AnnotatorInterFace

# Hugging Face access token. It is read eagerly, so importing this module
# without HF_TOKEN set in the environment raises KeyError.
HF_TOKEN = os.environ["HF_TOKEN"]
# Authorization header attached to every dataset-server and model request below.
HEADERS = dict(Authorization=f"Bearer {HF_TOKEN}")
DATASET_SERVER_URL = "https://datasets-server.huggingface.co"
# URL-encoded dataset id ("%2F" is "/") with the config and split query
# parameters already appended, so it is spliced verbatim after "dataset=".
DATASET_NAME = "poloclub%2Fdiffusiondb&config=2m_random_1k&split=train"
MODEL_URL = (
    "https://api-inference.huggingface.co/models/"
    "black-forest-labs/FLUX.1-schnell"
)


def retrieve_sample(idx, timeout=30):
    """Fetch a single row of the configured dataset split from the dataset server.

    Args:
        idx: Zero-based row offset passed as the ``offset`` query parameter.
        timeout: Seconds to wait for the dataset server before giving up.

    Returns:
        A ``(img_url, prompt)`` tuple taken from the row's ``image.src`` and
        ``prompt`` fields.

    Raises:
        requests.HTTPError: If the dataset server answers with an error status
            (presumably the case for an out-of-range offset; TODO confirm).
        IndexError: If the server responds successfully but returns no row.
    """
    api_url = f"{DATASET_SERVER_URL}/rows?dataset={DATASET_NAME}&offset={idx}&length=1"
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Surface server errors directly instead of a KeyError on the error payload.
    response.raise_for_status()
    rows = response.json().get("rows") or []
    if not rows:
        raise IndexError(f"dataset server returned no row at offset {idx}")
    row = rows[0]["row"]
    return row["image"]["src"], row["prompt"]


@functools.lru_cache(maxsize=8)
def get_rows(timeout=30):
    """Return the row count reported by the dataset server's ``/size`` endpoint.

    The value is read from ``size.config.num_rows``, i.e. the config-level
    count. Presumably this equals the train split for this single-split
    config; TODO confirm.

    The result is cached because ``next_input`` calls this for every sample,
    and the size of a fixed dataset config is assumed not to change while the
    app runs. A failed request raises, and exceptions are not cached.

    Args:
        timeout: Seconds to wait for the dataset server before giving up.

    Returns:
        The number of rows as reported by the server.

    Raises:
        requests.HTTPError: If the dataset server answers with an error status.
    """
    api_url = f"{DATASET_SERVER_URL}/size?dataset={DATASET_NAME}"
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    response.raise_for_status()
    return response.json()["size"]["config"]["num_rows"]


def generate_response(prompt, max_retries=30, retry_delay=10, timeout=120):
    """Generate an image for *prompt* via the hosted FLUX.1-schnell model.

    Any non-200 status is treated as transient (for example, the hosted model
    may still be loading). The request is retried after ``retry_delay``
    seconds, up to ``max_retries`` extra times. The retry loop is bounded, so
    a permanent failure such as a rejected token ends in an exception instead
    of retrying forever.

    Args:
        prompt: Text prompt sent as the ``inputs`` field of the JSON payload.
        max_retries: Number of retries after the first attempt.
        retry_delay: Seconds to sleep between attempts.
        timeout: Per-request timeout in seconds passed to ``requests.post``.

    Returns:
        A ``PIL.Image.Image`` decoded from the response body.

    Raises:
        requests.HTTPError: If no attempt returned HTTP 200.
    """
    payload = {"inputs": prompt}
    response = None
    for attempt in range(max_retries + 1):
        response = requests.post(
            MODEL_URL, headers=HEADERS, json=payload, timeout=timeout
        )
        if response.status_code == 200:
            return Image.open(io.BytesIO(response.content))
        # Skip the pointless sleep after the final failed attempt.
        if attempt < max_retries:
            time.sleep(retry_delay)
    raise requests.HTTPError(
        f"image generation failed after {max_retries + 1} attempts "
        f"(last status {response.status_code})",
        response=response,
    )


def next_input(_prompt, _completion_a, _completion_b):
    """Produce the next annotation sample: a random dataset row plus a new generation.

    The three arguments are presumably the interface's current field values
    as passed by ``AnnotatorInterFace``. They are accepted for that callback
    signature and ignored here.

    Returns:
        ``(prompt, img_url, generated_image)``: the dataset prompt, the URL of
        the dataset's image for that prompt, and a freshly generated image.

    Raises:
        ValueError: If the dataset reports zero rows (from ``random.randrange``).
    """
    # randrange(n) yields offsets in [0, n - 1]. The previous
    # randint(0, n) - 1 could produce the invalid offset -1.
    random_idx = random.randrange(get_rows())
    img_url, prompt = retrieve_sample(random_idx)
    generated_image = generate_response(prompt)
    return (prompt, img_url, generated_image)


if __name__ == "__main__":
    # Build the image-generation preference annotator, using next_input to
    # supply each new sample, then start it.
    annotator = AnnotatorInterFace.for_image_generation_preference(
        fn_next_input=next_input,
        interactive=False,
    )
    annotator.launch()