siddharth060104 commited on
Commit
4ab2815
·
verified ·
1 Parent(s): df5a3b7

Upload 3 files

Browse files
Files changed (3) hide show
  1. main.py +268 -0
  2. segmentation.pt +3 -0
  3. standard.pt +3 -0
main.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import cv2
3
+ from stockfish import Stockfish
4
+ import os
5
+ import numpy as np
6
+ import streamlit as st
7
+
8
+ # Constants
9
+ FEN_MAPPING = {
10
+ "black-pawn": "p", "black-rook": "r", "black-knight": "n", "black-bishop": "b", "black-queen": "q", "black-king": "k",
11
+ "white-pawn": "P", "white-rook": "R", "white-knight": "N", "white-bishop": "B", "white-queen": "Q", "white-king": "K"
12
+ }
13
+ GRID_BORDER = 10 # Border size in pixels
14
+ GRID_SIZE = 204 # Effective grid size (10px to 214px)
15
+ BLOCK_SIZE = GRID_SIZE // 8 # Each block is ~25px
16
+ X_LABELS = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] # Labels for x-axis (a to h)
17
+ Y_LABELS = [8, 7, 6, 5, 4, 3, 2, 1] # Reversed labels for y-axis (8 to 1)
18
+
19
+ # Functions
20
+ def get_grid_coordinate(pixel_x, pixel_y):
21
+ """
22
+ Function to determine the grid coordinate of a pixel, considering a 10px border and
23
+ the grid where bottom-left is (a, 1) and top-left is (h, 8).
24
+ """
25
+ # Grid settings
26
+ border = 10 # 10px border
27
+ grid_size = 204 # Effective grid size (10px to 214px)
28
+ block_size = grid_size // 8 # Each block is ~25px
29
+
30
+ x_labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] # Labels for x-axis (a to h)
31
+ y_labels = [8, 7, 6, 5, 4, 3, 2, 1] # Reversed labels for y-axis (8 to 1)
32
+
33
+ # Adjust pixel_x and pixel_y by subtracting the border (grid starts at pixel 10)
34
+ adjusted_x = pixel_x - border
35
+ adjusted_y = pixel_y - border
36
+
37
+ # Check bounds
38
+ if adjusted_x < 0 or adjusted_y < 0 or adjusted_x >= grid_size or adjusted_y >= grid_size:
39
+ return "Pixel outside grid bounds"
40
+
41
+ # Determine the grid column and row
42
+ x_index = adjusted_x // block_size
43
+ y_index = adjusted_y // block_size
44
+
45
+ if x_index < 0 or x_index >= len(x_labels) or y_index < 0 or y_index >= len(y_labels):
46
+ return "Pixel outside grid bounds"
47
+
48
+ # Convert indices to grid coordinates
49
+ x_index = adjusted_x // block_size # Determine the column index (0-7)
50
+ y_index = adjusted_y // block_size # Determine the row index (0-7)
51
+
52
+ # Convert row index to the correct label, with '8' at the bottom
53
+ y_labeld = y_labels[y_index] # Correct index directly maps to '8' to '1'
54
+ x_label = x_labels[x_index]
55
+ y_label = 8 - y_labeld + 1
56
+
57
+ return f"{x_label}{y_label}"
58
+
59
+ def predict_next_move(fen, stockfish):
60
+ """
61
+ Predict the next move using Stockfish.
62
+ """
63
+ if stockfish.is_fen_valid(fen):
64
+ stockfish.set_fen_position(fen)
65
+ else:
66
+ return "Invalid FEN notation!"
67
+
68
+ best_move = stockfish.get_best_move()
69
+ ans = transform_string(best_move)
70
+ return f"The predicted next move is: {ans}" if best_move else "No valid move found (checkmate/stalemate)."
71
+
72
+
73
+
74
+
75
+ def process_image(image_path):
76
+ # Ensure output directory exists
77
+ if not os.path.exists('output'):
78
+ os.makedirs('output')
79
+
80
+ # Load the segmentation model
81
+ segmentation_model = YOLO("segmentation.pt")
82
+
83
+ # Run inference to get segmentation results
84
+ results = segmentation_model.predict(
85
+ source=image_path,
86
+ conf=0.8 # Confidence threshold
87
+ )
88
+
89
+ # Initialize variables for the segmented mask and bounding box
90
+ segmentation_mask = None
91
+ bbox = None
92
+
93
+ for result in results:
94
+ if result.boxes.conf[0] >= 0.8: # Filter results by confidence
95
+ segmentation_mask = result.masks.data.cpu().numpy().astype(np.uint8)[0]
96
+ bbox = result.boxes.xyxy[0].cpu().numpy() # Get the bounding box coordinates
97
+ break
98
+
99
+ if segmentation_mask is None:
100
+ print("No segmentation mask with confidence above 0.8 found.")
101
+ return None
102
+
103
+ # Load the image
104
+ image = cv2.imread(image_path)
105
+
106
+ # Resize segmentation mask to match the input image dimensions
107
+ segmentation_mask_resized = cv2.resize(segmentation_mask, (image.shape[1], image.shape[0]))
108
+
109
+ # Extract bounding box coordinates
110
+ if bbox is not None:
111
+ x1, y1, x2, y2 = bbox
112
+ # Crop the segmented region based on the bounding box
113
+ cropped_segment = image[int(y1):int(y2), int(x1):int(x2)]
114
+
115
+ # Save the cropped segmented image
116
+ cropped_image_path = 'output/cropped_segment.jpg'
117
+ cv2.imwrite(cropped_image_path, cropped_segment)
118
+ print(f"Cropped segmented image saved to {cropped_image_path}")
119
+
120
+ # Return the cropped image
121
+ return cropped_segment
122
+
123
+ def transform_string(input_str):
124
+ # Remove extra spaces and convert to lowercase
125
+ input_str = input_str.strip().lower()
126
+
127
+ # Check if input is valid
128
+ if len(input_str) != 4 or not input_str[0].isalpha() or not input_str[1].isdigit() or \
129
+ not input_str[2].isalpha() or not input_str[3].isdigit():
130
+ return "Invalid input"
131
+
132
+ # Define mappings
133
+ letter_mapping = {
134
+ 'a': 'h', 'b': 'g', 'c': 'f', 'd': 'e',
135
+ 'e': 'd', 'f': 'c', 'g': 'b', 'h': 'a'
136
+ }
137
+ number_mapping = {
138
+ '1': '8', '2': '7', '3': '6', '4': '5',
139
+ '5': '4', '6': '3', '7': '2', '8': '1'
140
+ }
141
+
142
+ # Transform string
143
+ result = ""
144
+ for i, char in enumerate(input_str):
145
+ if i % 2 == 0: # Letters
146
+ result += letter_mapping.get(char, "Invalid")
147
+ else: # Numbers
148
+ result += number_mapping.get(char, "Invalid")
149
+
150
+ # Check for invalid transformations
151
+ if "Invalid" in result:
152
+ return "Invalid input"
153
+
154
+ return result
155
+
156
+ # Example usage
157
+ # Output: d6h2
158
+ # Example usage:
159
+
160
+
161
+
162
+ # Streamlit app
163
+ def main():
164
+ st.title("Chessboard Position Detection and Move Prediction")
165
+
166
+ # User uploads an image or captures it from their camera
167
+ image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
168
+
169
+ if image_file is not None:
170
+ # Save the image to a temporary file
171
+ temp_dir = "temp_images"
172
+ os.makedirs(temp_dir, exist_ok=True)
173
+ temp_file_path = os.path.join(temp_dir, "uploaded_image.jpg")
174
+ with open(temp_file_path, "wb") as f:
175
+ f.write(image_file.getbuffer())
176
+
177
+ # Process the image using its file path
178
+ processed_image = process_image(temp_file_path)
179
+
180
+ if processed_image is not None:
181
+ # Resize the image to 224x224
182
+ processed_image = cv2.resize(processed_image, (224, 224))
183
+ height, width, _ = processed_image.shape
184
+
185
+ # Initialize the YOLO model
186
+ model = YOLO("standard.pt") # Replace with your trained model weights file
187
+
188
+ # Run detection
189
+ results = model.predict(source=processed_image, save=False, save_txt=False, conf=0.6)
190
+
191
+ # Initialize the board for FEN (empty rows represented by "8")
192
+ board = [["8"] * 8 for _ in range(8)]
193
+
194
+ # Extract predictions and map to FEN board
195
+ for result in results[0].boxes:
196
+ x1, y1, x2, y2 = result.xyxy[0].tolist()
197
+ class_id = int(result.cls[0])
198
+ class_name = model.names[class_id]
199
+
200
+ # Convert class_name to FEN notation
201
+ fen_piece = FEN_MAPPING.get(class_name, None)
202
+ if not fen_piece:
203
+ continue
204
+
205
+ # Calculate the center of the bounding box
206
+ center_x = (x1 + x2) / 2
207
+ center_y = (y1 + y2) / 2
208
+
209
+ # Convert to integer pixel coordinates
210
+ pixel_x = int(center_x)
211
+ pixel_y = int(height - center_y) # Flip Y-axis for generic coordinate system
212
+
213
+ # Get grid coordinate
214
+ grid_position = get_grid_coordinate(pixel_x, pixel_y)
215
+
216
+ if grid_position != "Pixel outside grid bounds":
217
+ file = ord(grid_position[0]) - ord('a') # Column index (0-7)
218
+ rank = int(grid_position[1]) - 1 # Row index (0-7)
219
+
220
+ # Place the piece on the board
221
+ board[7 - rank][file] = fen_piece # Flip rank index for FEN
222
+
223
+ # Generate the FEN string
224
+ fen_rows = []
225
+ for row in board:
226
+ fen_row = ""
227
+ empty_count = 0
228
+ for cell in row:
229
+ if cell == "8":
230
+ empty_count += 1
231
+ else:
232
+ if empty_count > 0:
233
+ fen_row += str(empty_count)
234
+ empty_count = 0
235
+ fen_row += cell
236
+ if empty_count > 0:
237
+ fen_row += str(empty_count)
238
+ fen_rows.append(fen_row)
239
+
240
+ position_fen = "/".join(fen_rows)
241
+
242
+ # Ask the user for the next move side
243
+ move_side = st.selectbox("Select the side to move:", ["w (White)", "b (Black)"])
244
+ move_side = "w" if move_side.startswith("w") else "b"
245
+
246
+ # Append the full FEN string continuation
247
+ fen_notation = f"{position_fen} {move_side} - - 0 0"
248
+
249
+ st.subheader("Generated FEN Notation:")
250
+ st.code(fen_notation)
251
+
252
+ # Initialize the Stockfish engine
253
+ stockfish = Stockfish(
254
+ path=r"D:\Projects\ChessVision\StockFish\stockfish\stockfish-windows-x86-64-avx2.exe", # Replace with your Stockfish path
255
+ depth=15,
256
+ parameters={"Threads": 2, "Minimum Thinking Time": 30}
257
+ )
258
+
259
+ # Predict the next move
260
+ next_move = predict_next_move(fen_notation, stockfish)
261
+ st.subheader("Stockfish Recommended Move:")
262
+ st.write(next_move)
263
+
264
+ else:
265
+ st.error("Failed to process the image. Please try again.")
266
+
267
+ if __name__ == "__main__":
268
+ main()
segmentation.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:912bbbde63f435106d57c7416c11a49eb3e9cb93dfe71cb6f9bfaafc1a4e3683
3
+ size 6781485
standard.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2c19a7f75312af21e9e514f008a05da5ff5624590cc5a8997c977a16d2ac459
3
+ size 114375506