File size: 5,249 Bytes
ac614bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Image-stitching web app whose UI is built with Gradio so that it can be deployed on Hugging Face (e.g. as a Space)."""
import gradio as gr
import cv2
import numpy as np
import imutils
from PIL import Image

# Disable OpenCV's OpenCL (T-API) acceleration for all cv2 calls in this app.
# NOTE(review): presumably done to avoid OpenCL driver/device issues on the
# hosting machine — confirm the original motivation.
cv2.ocl.setUseOpenCL(False)

def image_sharpening(image):
    """Return a sharpened copy of ``image``.

    The image is convolved with a 3x3 kernel whose centre weight is 9 and whose
    eight neighbours are -1. The weights sum to 1, so flat regions keep their
    intensity while intensity changes are amplified. ``ddepth=-1`` keeps the
    output at the same depth as the input.
    """
    weights = -np.ones((3, 3), dtype=int)
    weights[1, 1] = 9
    return cv2.filter2D(image, -1, weights)

def remove_black_region(result):
    """Crop ``result`` to the bounding box of its largest non-black region.

    After ``warpPerspective`` the stitching canvas is larger than the content,
    and the uncovered area is left black (value 0). This trims that area away.

    Args:
        result: 3-channel BGR image.

    Returns:
        A slice of ``result`` covering the bounding rectangle of the largest
        external contour of non-black pixels. If the image contains no
        non-black pixels, ``result`` is returned unchanged, since there is
        nothing to crop to.
    """
    gray = cv2.cvtColor(result, cv2.COLOR_BGR2GRAY)
    # Threshold at 0: every pixel brighter than pure black becomes foreground.
    mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY)[1]
    # grab_contours hides the differing findContours return shapes across
    # OpenCV versions.
    contours = imutils.grab_contours(
        cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    )
    if not contours:
        # All-black image: max() below would raise ValueError on an empty sequence.
        return result
    largest = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest)
    return result[y:y + h, x:x + w]

def detectAndDescribe(image, method='orb'):
    """Detect keypoints in ``image`` and compute their descriptors.

    Args:
        image: Image to analyse. Within this app the caller passes a grayscale
            image.
        method: Feature extractor name: ``'sift'``, ``'brisk'`` or ``'orb'``.

    Returns:
        ``(kps, features)`` exactly as returned by OpenCV's
        ``detectAndCompute``. ``features`` may be ``None`` when no keypoints
        were found.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == 'sift':
        descriptor = cv2.SIFT_create()
    elif method == 'brisk':
        descriptor = cv2.BRISK_create()
    elif method == 'orb':
        descriptor = cv2.ORB_create()
    else:
        # Fail loudly on unknown names instead of hitting an unbound `descriptor`.
        raise ValueError(
            f"Unsupported feature extractor {method!r}; "
            "expected 'sift', 'brisk' or 'orb'."
        )
    kps, features = descriptor.detectAndCompute(image, None)
    return kps, features

def createMatcher(method, crossCheck):
    """Build a brute-force descriptor matcher suited to ``method``.

    Float descriptors (``'sift'``, ``'surf'``) are compared with the L2 norm.
    Every other method, such as the binary ``'orb'`` and ``'brisk'``
    descriptors, uses Hamming distance. ``crossCheck`` is forwarded to
    ``cv2.BFMatcher`` unchanged.
    """
    norm = cv2.NORM_L2 if method in ('sift', 'surf') else cv2.NORM_HAMMING
    return cv2.BFMatcher(norm, crossCheck=crossCheck)

def matchKeyPointsKNN(featuresA, featuresB, ratio, method):
    """Match descriptors with 2-nearest-neighbour search and Lowe's ratio test.

    Args:
        featuresA: Query descriptors, or ``None`` if that image had no keypoints.
        featuresB: Train descriptors, or ``None``.
        ratio: A best match is kept only if its distance is below ``ratio``
            times the distance of the second-best candidate.
        method: Feature extractor name; it selects the matcher's distance norm.

    Returns:
        List of accepted ``DMatch`` objects. The list is empty when either side
        has no descriptors.
    """
    if featuresA is None or featuresB is None:
        # detectAndCompute can yield None descriptors for a featureless image,
        # and knnMatch would raise on it.
        return []
    bf = createMatcher(method, crossCheck=False)
    rawMatches = bf.knnMatch(featuresA, featuresB, 2)
    # knnMatch may return fewer than 2 neighbours for a query (e.g. when
    # featuresB holds a single descriptor). Such entries cannot be
    # ratio-tested, and unpacking them as `m, n` would raise ValueError.
    return [
        pair[0]
        for pair in rawMatches
        if len(pair) == 2 and pair[0].distance < pair[1].distance * ratio
    ]

def getHomography(kpsA, kpsB, featuresA, featuresB, matches, reprojThresh=4.0):
    """Estimate the homography that maps keypoints of A onto keypoints of B.

    Args:
        kpsA: Keypoints of image A (objects with an ``(x, y)`` ``.pt``).
        kpsB: Keypoints of image B.
        featuresA: Unused; kept for call compatibility.
        featuresB: Unused; kept for call compatibility.
        matches: ``DMatch`` objects whose ``queryIdx`` indexes ``kpsA`` and
            whose ``trainIdx`` indexes ``kpsB``.
        reprojThresh: RANSAC reprojection-error threshold, in pixels.

    Returns:
        ``(matches, H, status)`` on success. Returns ``None`` when there are too
        few matches (at least 5 are required) or when RANSAC cannot estimate a
        homography.
    """
    if len(matches) <= 4:
        return None
    # Only the matched keypoints are needed, so convert just those.
    ptsA = np.float32([kpsA[m.queryIdx].pt for m in matches])
    ptsB = np.float32([kpsB[m.trainIdx].pt for m in matches])
    H, status = cv2.findHomography(ptsA, ptsB, cv2.RANSAC, reprojThresh)
    if H is None:
        # findHomography yields None when no model can be estimated. Returning
        # it would make the caller's warpPerspective fail.
        return None
    return matches, H, status

def stitch_two_images(queryImg, trainImg, feature_extractor):
    """Warp ``trainImg`` into ``queryImg``'s frame and combine them on one canvas.

    Args:
        queryImg: Reference BGR image. It is pasted, unwarped, at the canvas
            origin.
        trainImg: BGR image that is aligned to ``queryImg`` through a RANSAC
            homography.
        feature_extractor: ``'orb'``, ``'sift'`` or ``'brisk'``.

    Returns:
        The combined BGR image, cropped to its non-black content. Returns
        ``None`` when no alignment could be found: an image has no features,
        there are too few matches, or homography estimation failed.

    Note:
        The canvas starts at (0, 0). Any part of ``trainImg`` that the
        homography maps to negative coordinates is therefore clipped.
    """
    queryImg_gray = cv2.cvtColor(queryImg, cv2.COLOR_BGR2GRAY)
    trainImg_gray = cv2.cvtColor(trainImg, cv2.COLOR_BGR2GRAY)
    # A = train image, B = query image, so H maps train coords into the query frame.
    kpsA, featuresA = detectAndDescribe(trainImg_gray, method=feature_extractor)
    kpsB, featuresB = detectAndDescribe(queryImg_gray, method=feature_extractor)
    if featuresA is None or featuresB is None:
        # One of the images has no detectable keypoints; nothing to match.
        return None
    matches = matchKeyPointsKNN(featuresA, featuresB, ratio=0.75, method=feature_extractor)
    M = getHomography(kpsA, kpsB, featuresA, featuresB, matches, reprojThresh=5)
    if M is None:
        return None
    _, H, _ = M
    if H is None:
        # RANSAC could not estimate a homography; warpPerspective would fail.
        return None
    # The canvas is big enough for either image to extend right of or below the other.
    width = trainImg.shape[1] + queryImg.shape[1]
    height = trainImg.shape[0] + queryImg.shape[0]
    result = cv2.warpPerspective(trainImg, H, (width, height))
    # The query image simply overwrites the overlap region (no blending).
    result[0:queryImg.shape[0], 0:queryImg.shape[1]] = queryImg
    return remove_black_region(result)

def calculate_target_brightness(images):
    """Return the average of the per-image mean intensities.

    Each image's mean is taken over all of its pixels and channels, computed in
    float32. Every image contributes equally to the result, regardless of its
    size.
    """
    per_image_means = np.array([img.astype(np.float32).mean() for img in images])
    return per_image_means.mean()

def global_brightness_adjustment(images, target_brightness):
    """Shift each image's intensities so that its mean equals ``target_brightness``.

    The same additive offset is applied to every pixel and channel of an image.
    Results are clipped to [0, 255] and returned as new ``uint8`` arrays, in
    the same order as ``images``.
    """
    def _shift(img):
        as_float = img.astype(np.float32)
        offset = target_brightness - np.mean(as_float)
        return np.clip(as_float + offset, 0, 255).astype(np.uint8)

    return [_shift(img) for img in images]

def stitch_images(uploaded_files, feature_extractor):
    """Stitch the uploaded images, in upload order, into a single image.

    Each new image is aligned to the result built so far.

    Args:
        uploaded_files: File paths from the ``gr.Files`` input. May be ``None``
            or empty when nothing has been uploaded.
        feature_extractor: ``'orb'``, ``'sift'`` or ``'brisk'``.

    Returns:
        The stitched image as an RGB numpy array, or ``None`` if no files were
        given. A single file is returned after brightness adjustment only.

    Raises:
        gr.Error: If an image cannot be aligned with the result built so far.
            Gradio shows the message to the user.
    """
    if not uploaded_files:
        return None
    images = [_load_bgr_image(path) for path in uploaded_files]
    # Give every input the same mean intensity before stitching.
    target_brightness = calculate_target_brightness(images)
    adjusted_images = global_brightness_adjustment(images, target_brightness)
    stitched_image = adjusted_images[0]
    # `position` is the 1-based upload index of the image being added.
    for position, trainImg in enumerate(adjusted_images[1:], start=2):
        stitched_image = stitch_two_images(stitched_image, trainImg, feature_extractor)
        if stitched_image is None:
            raise gr.Error(
                f"Could not align image {position} with the panorama built so far. "
                "Try another feature extractor or images with more overlap."
            )
    return cv2.cvtColor(stitched_image, cv2.COLOR_BGR2RGB)


def _load_bgr_image(path):
    """Read ``path`` with Pillow and return it as a 3-channel BGR numpy array."""
    with Image.open(path) as img:
        # convert("RGB") normalises grayscale, palette and other modes to 3
        # channels, which COLOR_RGB2BGR expects. The `with` closes the file handle.
        return cv2.cvtColor(np.array(img.convert("RGB")), cv2.COLOR_RGB2BGR)
  
# Gradio interface: upload several images, choose a feature extractor, and view
# the stitched result.
with gr.Blocks() as demo:
    gr.Markdown("## Image Stitching App with Feature Extractor Selection")
    image_input = gr.Files(label="Upload Images", type="filepath")
    extractor_input = gr.Dropdown(choices=["orb", "sift", "brisk"], label="Feature Extractor", value="orb")
    image_output = gr.Image(type="numpy", label="Stitched Image")
    process_button = gr.Button("Process Image")
    # stitch_images receives the uploaded file paths and the dropdown value.
    process_button.click(stitch_images, inputs=[image_input, extractor_input], outputs=image_output)

# Launch the Gradio app only when this file is executed as a script, so that
# importing the module (e.g. from tests) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()