import numpy as np
import random, os
from PIL import Image



def merge_images_horizontally(img1, img2, gap=20):
    '''
    Place two images side by side on a new RGBA canvas.

    img1 is pasted at the top-left corner and img2 to its right, separated by
    ``gap`` pixels. The canvas is as tall as the taller input; areas not
    covered by either image (the gap, and space below the shorter image) keep
    Image.new's default all-zero fill.

    img1, img2: PIL Images.
    gap: horizontal spacing in pixels between the two images (default 20).

    Returns a new RGBA PIL Image of size
    (img1.width + gap + img2.width, max(img1.height, img2.height)).
    '''
    w1, h1 = img1.size
    w2, h2 = img2.size

    merged = Image.new('RGBA', (w1 + gap + w2, max(h1, h2)))
    merged.paste(img1, (0, 0))
    # Same ``gap`` as used for the canvas width, so the two can't drift apart.
    merged.paste(img2, (w1 + gap, 0))
    return merged
   


def outpainting_generator_rectangle(image, box_width_ratio=0.35, mask_random_start=125):
    '''
    Build an outpainting training pair from a random full-height vertical strip.

    The input is resized to 512x512. A strip ``box_width_ratio`` of the image
    width wide is cropped at a random horizontal offset, stretched to 256x512,
    and pasted at column ``mask_random_start`` of a black 512x512 canvas.

    image: PIL Image to sample from.
    box_width_ratio: width of the cropped strip as a fraction of the resized
        image width; must satisfy 0 < box_width_ratio <= 1.
    mask_random_start: x offset at which the strip is pasted onto the canvas.
        Despite the name it is used as-is, not randomised. Values above 256
        make the 256-px strip extend past the canvas's right edge.

    Returns (large_image, mask): the RGB canvas containing the strip, and an
    RGB mask that is black where the strip was pasted and white elsewhere.

    Raises ValueError if box_width_ratio is outside (0, 1].
    '''
    if not 0 < box_width_ratio <= 1:
        raise ValueError(
            f"box_width_ratio must be in (0, 1], got {box_width_ratio}"
        )

    image = image.resize((512, 512))
    width, height = image.size

    box_width = width * box_width_ratio
    # randint's upper bound is exclusive: +1 keeps the flush-right position
    # reachable and makes box_width_ratio == 1 yield x == 0 instead of raising.
    x = np.random.randint(0, int(width - box_width) + 1)
    small_image = image.crop((x, 0, x + box_width, height))

    canvas_size = (512, 512)
    strip_size = (256, 512)
    small_image = small_image.resize(strip_size)

    large_image = Image.new('RGB', canvas_size, "black")
    mask_0 = Image.new('RGB', strip_size, "black")
    mask_1 = Image.new('RGB', canvas_size, "white")

    paste_at = (mask_random_start, 0)
    large_image.paste(small_image, paste_at)
    # Black marks the known (pasted) region; white is left for outpainting.
    mask_1.paste(mask_0, paste_at)

    return large_image, mask_1


def outpainting_generator(image, sqr, save_debug=True):
    '''
    Build an outpainting training pair from a random crop of ``image``.

    The input is resized to 512x512 and a crop is taken at a random position:
      * sqr=True:  a square whose side is 15-25% of the resized height, taken
        anywhere in the image, then resized to a random 300-350 px square and
        pasted at a random (x, y) on the canvas.
      * sqr=False: a full-height strip 35-40% of the resized width, resized to
        256x512 and pasted at a random x along the top of the canvas.
    The canvas is a black 512x512 RGB image.

    image: PIL Image to sample from.
    sqr: True for a square crop, False for a full-height rectangular strip.
    save_debug: when True (the default, preserving the historical behaviour),
        write the canvas and mask to 'mask_color.png' and 'mask.png' in the
        current working directory. Pass False in data pipelines to avoid a
        disk write on every call and files clobbered by parallel workers.

    Returns (large_image, mask): the RGB canvas containing the crop, and an RGB
    mask that is black where the crop was pasted and white elsewhere.
    '''
    image = image.resize((512, 512))
    width, height = image.size

    if sqr:
        side = height * random.uniform(0.15, 0.25)
        box_width = box_height = side
    else:
        box_width = width * random.uniform(0.35, 0.4)
        box_height = height

    # np.random.randint excludes its upper bound; +1 keeps the last valid
    # offset (crop flush with the right/bottom edge) reachable.
    x = np.random.randint(0, int(width - box_width) + 1)
    y = np.random.randint(0, int(height - box_height) + 1) if sqr else 0
    small_image = image.crop((x, y, x + box_width, y + box_height))

    canvas_w, canvas_h = 512, 512
    if sqr:
        paste_side = random.randint(300, 350)
        small_w, small_h = paste_side, paste_side
    else:
        small_w, small_h = 256, 512
    small_image = small_image.resize((small_w, small_h))

    large_image = Image.new('RGB', (canvas_w, canvas_h), "black")
    mask_0 = Image.new('RGB', (small_w, small_h), "black")
    mask_1 = Image.new('RGB', (canvas_w, canvas_h), "white")

    # Same exclusive-bound fix: allow the crop to touch the canvas edge.
    paste_x = np.random.randint(0, canvas_w - small_w + 1)
    paste_y = np.random.randint(0, canvas_h - small_h + 1) if sqr else 0

    large_image.paste(small_image, (paste_x, paste_y))
    # Black marks the known (pasted) region; white is left for outpainting.
    mask_1.paste(mask_0, (paste_x, paste_y))

    if save_debug:
        large_image.save('mask_color.png')
        mask_1.save('mask.png')

    return large_image, mask_1