import os

# Use the Hugging Face mirror endpoint (must be set before any hub download happens).
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

def normalize_bbox(bboxes, img_width, img_height):
    """Scale pixel-space xyxy boxes to the [0, 1] range."""
    normalized_bboxes = []
    for box in bboxes:
        x_min, y_min, x_max, y_max = box
        normalized_bboxes.append([
            x_min / img_width,
            y_min / img_height,
            x_max / img_width,
            y_max / img_height,
        ])
    return normalized_bboxes
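
# Quick sanity check (illustrative values only): a (136, 252, 280, 455) box on a
# 600x600 canvas maps to approximately [0.227, 0.420, 0.467, 0.758].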

def create_reco_prompt(
    caption: str = '',
    phrases=(),
    boxes=(),
    normalize_boxes=True,
    image_resolution=512,
    num_bins=1000,
):
    """
    Create a ReCo prompt.

    caption: global caption
    phrases: list of regional captions
    boxes: list of regional coordinates (unnormalized xyxy)
    normalize_boxes: divide box coordinates by image_resolution before quantizing
    """
    SOS_token = '<|startoftext|>'
    EOS_token = '<|endoftext|>'

    box_captions_with_coords = []
    box_captions_with_coords += [caption]
    box_captions_with_coords += [EOS_token]

    for phrase, box in zip(phrases, boxes):
        if normalize_boxes:
            box = [float(x) / image_resolution for x in box]

        # quantize the normalized coordinates into num_bins bins
        quant_x0 = int(round(box[0] * (num_bins - 1)))
        quant_y0 = int(round(box[1] * (num_bins - 1)))
        quant_x1 = int(round(box[2] * (num_bins - 1)))
        quant_y1 = int(round(box[3] * (num_bins - 1)))

        # ReCo format: four bin tokens, then the regional caption wrapped in SOS/EOS
        box_captions_with_coords += [
            f"<bin{str(quant_x0).zfill(3)}>",
            f"<bin{str(quant_y0).zfill(3)}>",
            f"<bin{str(quant_x1).zfill(3)}>",
            f"<bin{str(quant_y1).zfill(3)}>",
            SOS_token,
            phrase,
            EOS_token,
        ]

    text = " ".join(box_captions_with_coords)
    return text
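
# The resulting prompt has the following layout (sketched here, one bin-token group per box):
#   "<caption> <|endoftext|>
#    <binXXX> <binYYY> <binXXX> <binYYY> <|startoftext|> <phrase 1> <|endoftext|>
#    <binXXX> <binYYY> <binXXX> <binYYY> <|startoftext|> <phrase 2> <|endoftext|> ..."
# where each <binNNN> token encodes a quantized x0/y0/x1/y1 coordinate
# (000-999 with the default num_bins=1000).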

def inference_image(pipe, prompt, grounding_instruction, state):
    """Build a ReCo prompt from the caption, per-region phrases, and boxes, then generate."""
    print(prompt)
    print(grounding_instruction)
    # Accept either a UI state dict with a 'boxes' entry or a plain list of boxes.
    bbox = state['boxes'] if isinstance(state, dict) else state
    print(bbox)
    # Boxes are drawn on a 600x600 canvas, so scale them to [0, 1] here.
    bbox = normalize_bbox(bbox, 600, 600)
    print(bbox)
    objects = [obj.strip() for obj in grounding_instruction.split(';') if obj.strip()]
    print(objects)
    # Boxes are already normalized above, so skip the division by image_resolution.
    prompt_reco = create_reco_prompt(prompt, objects, bbox, normalize_boxes=False)
    print(prompt_reco)
    image = pipe(prompt_reco, guidance_scale=4).images[0]
    return image



if __name__ == "__main__":
    # Prefer an already-downloaded local snapshot of the ReCo checkpoint; otherwise
    # download "j-min/reco_sd14_coco" from the Hub (via the mirror endpoint set above).
    local_snapshot = '/home/bcy/cache/.cache/huggingface/hub/models--j-min--reco_sd14_coco/snapshots/11a062da5a0a84501047cb19e113f520eb610415'
    model_id = local_snapshot if os.path.isdir(local_snapshot) else "j-min/reco_sd14_coco"
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16
    )
    pipe = pipe.to("cuda")
    # caption = "A box contains six donuts with varying types of glazes and toppings."
    # phrases = ["chocolate donut.", "dark vanilla donut.", "donut with sprinkles.", "donut with powdered sugar.", "pink donut.", "brown donut."]
    # boxes = [[263.68, 294.912, 380.544, 392.832], [121.344, 265.216, 267.392, 401.92], [391.168, 294.912, 506.368, 381.952], [120.064, 143.872, 268.8, 270.336], [264.192, 132.928, 393.216, 263.68], [386.048, 148.48, 490.688, 259.584]]
    # prompt = create_reco_prompt(caption, phrases, boxes)
    # print(prompt)
    # generated_image = pipe(
    # prompt,
    # guidance_scale=4).images[0]
    # generated_image.save("output1.jpg")
    prompt = "a dog and a cat;"
    grounding_instruction = "cut dog; big cat;"
    bbox = [(136, 252, 280, 455), (284, 205, 480, 500)]
    
    # Run generation and keep the result (the output filename here is arbitrary).
    generated_image = inference_image(pipe, prompt, grounding_instruction, bbox)
    generated_image.save("output2.jpg")