siddharth060104 commited on
Commit
7d262ec
·
verified ·
1 Parent(s): b1570bd

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. app.py +265 -0
  3. segmentation.pt +3 -0
  4. standard.pt +3 -0
  5. stockfish-windows-x86-64-avx2.exe +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ stockfish-windows-x86-64-avx2.exe filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import cv2
3
+ from stockfish import Stockfish
4
+ import os
5
+ import numpy as np
6
+ import streamlit as st
7
+
8
+ # Constants
9
+ FEN_MAPPING = {
10
+ "black-pawn": "p", "black-rook": "r", "black-knight": "n", "black-bishop": "b", "black-queen": "q", "black-king": "k",
11
+ "white-pawn": "P", "white-rook": "R", "white-knight": "N", "white-bishop": "B", "white-queen": "Q", "white-king": "K"
12
+ }
13
+ GRID_BORDER = 10 # Border size in pixels
14
+ GRID_SIZE = 204 # Effective grid size (10px to 214px)
15
+ BLOCK_SIZE = GRID_SIZE // 8 # Each block is ~25px
16
+ X_LABELS = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] # Labels for x-axis (a to h)
17
+ Y_LABELS = [8, 7, 6, 5, 4, 3, 2, 1] # Reversed labels for y-axis (8 to 1)
18
+
19
+ # Functions
20
+ def get_grid_coordinate(pixel_x, pixel_y):
21
+ """
22
+ Function to determine the grid coordinate of a pixel, considering a 10px border and
23
+ the grid where bottom-left is (a, 1) and top-left is (h, 8).
24
+ """
25
+ # Grid settings
26
+ border = 10 # 10px border
27
+ grid_size = 204 # Effective grid size (10px to 214px)
28
+ block_size = grid_size // 8 # Each block is ~25px
29
+
30
+ x_labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] # Labels for x-axis (a to h)
31
+ y_labels = [8, 7, 6, 5, 4, 3, 2, 1] # Reversed labels for y-axis (8 to 1)
32
+
33
+ # Adjust pixel_x and pixel_y by subtracting the border (grid starts at pixel 10)
34
+ adjusted_x = pixel_x - border
35
+ adjusted_y = pixel_y - border
36
+
37
+ # Check bounds
38
+ if adjusted_x < 0 or adjusted_y < 0 or adjusted_x >= grid_size or adjusted_y >= grid_size:
39
+ return "Pixel outside grid bounds"
40
+
41
+ # Determine the grid column and row
42
+ x_index = adjusted_x // block_size
43
+ y_index = adjusted_y // block_size
44
+
45
+ if x_index < 0 or x_index >= len(x_labels) or y_index < 0 or y_index >= len(y_labels):
46
+ return "Pixel outside grid bounds"
47
+
48
+ # Convert indices to grid coordinates
49
+ x_index = adjusted_x // block_size # Determine the column index (0-7)
50
+ y_index = adjusted_y // block_size # Determine the row index (0-7)
51
+
52
+ # Convert row index to the correct label, with '8' at the bottom
53
+ y_labeld = y_labels[y_index] # Correct index directly maps to '8' to '1'
54
+ x_label = x_labels[x_index]
55
+ y_label = 8 - y_labeld + 1
56
+
57
+ return f"{x_label}{y_label}"
58
+
59
+ def predict_next_move(fen, stockfish):
60
+ """
61
+ Predict the next move using Stockfish.
62
+ """
63
+ if stockfish.is_fen_valid(fen):
64
+ stockfish.set_fen_position(fen)
65
+ else:
66
+ return "Invalid FEN notation!"
67
+
68
+ best_move = stockfish.get_best_move()
69
+ ans = transform_string(best_move)
70
+ return f"The predicted next move is: {ans}" if best_move else "No valid move found (checkmate/stalemate)."
71
+
72
+
73
+
74
+
75
+ def process_image(image_path):
76
+ # Ensure output directory exists
77
+ if not os.path.exists('output'):
78
+ os.makedirs('output')
79
+
80
+ # Load the segmentation model
81
+ segmentation_model = YOLO("segmentation.pt")
82
+
83
+ # Run inference to get segmentation results
84
+ results = segmentation_model.predict(
85
+ source=image_path,
86
+ conf=0.8 # Confidence threshold
87
+ )
88
+
89
+ # Initialize variables for the segmented mask and bounding box
90
+ segmentation_mask = None
91
+ bbox = None
92
+
93
+ for result in results:
94
+ if result.boxes.conf[0] >= 0.8: # Filter results by confidence
95
+ segmentation_mask = result.masks.data.cpu().numpy().astype(np.uint8)[0]
96
+ bbox = result.boxes.xyxy[0].cpu().numpy() # Get the bounding box coordinates
97
+ break
98
+
99
+ if segmentation_mask is None:
100
+ print("No segmentation mask with confidence above 0.8 found.")
101
+ return None
102
+
103
+ # Load the image
104
+ image = cv2.imread(image_path)
105
+
106
+ # Resize segmentation mask to match the input image dimensions
107
+ segmentation_mask_resized = cv2.resize(segmentation_mask, (image.shape[1], image.shape[0]))
108
+
109
+ # Extract bounding box coordinates
110
+ if bbox is not None:
111
+ x1, y1, x2, y2 = bbox
112
+ # Crop the segmented region based on the bounding box
113
+ cropped_segment = image[int(y1):int(y2), int(x1):int(x2)]
114
+
115
+ # Save the cropped segmented image
116
+ cropped_image_path = 'output/cropped_segment.jpg'
117
+ cv2.imwrite(cropped_image_path, cropped_segment)
118
+ print(f"Cropped segmented image saved to {cropped_image_path}")
119
+
120
+ st.image(cropped_segment, caption="Uploaded Image", use_column_width=True)
121
+ # Return the cropped image
122
+ return cropped_segment
123
+
124
+ def transform_string(input_str):
125
+ # Remove extra spaces and convert to lowercase
126
+ input_str = input_str.strip().lower()
127
+
128
+ # Check if input is valid
129
+ if len(input_str) != 4 or not input_str[0].isalpha() or not input_str[1].isdigit() or \
130
+ not input_str[2].isalpha() or not input_str[3].isdigit():
131
+ return "Invalid input"
132
+
133
+ # Define mappings
134
+ letter_mapping = {
135
+ 'a': 'h', 'b': 'g', 'c': 'f', 'd': 'e',
136
+ 'e': 'd', 'f': 'c', 'g': 'b', 'h': 'a'
137
+ }
138
+ number_mapping = {
139
+ '1': '8', '2': '7', '3': '6', '4': '5',
140
+ '5': '4', '6': '3', '7': '2', '8': '1'
141
+ }
142
+
143
+ # Transform string
144
+ result = ""
145
+ for i, char in enumerate(input_str):
146
+ if i % 2 == 0: # Letters
147
+ result += letter_mapping.get(char, "Invalid")
148
+ else: # Numbers
149
+ result += number_mapping.get(char, "Invalid")
150
+
151
+ # Check for invalid transformations
152
+ if "Invalid" in result:
153
+ return "Invalid input"
154
+
155
+ return result
156
+
157
+
158
+
159
+ # Streamlit app
160
+ def main():
161
+ st.title("Chessboard Position Detection and Move Prediction")
162
+
163
+ # User uploads an image or captures it from their camera
164
+ image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
165
+
166
+ if image_file is not None:
167
+ # Save the image to a temporary file
168
+ temp_dir = "temp_images"
169
+ os.makedirs(temp_dir, exist_ok=True)
170
+ temp_file_path = os.path.join(temp_dir, "uploaded_image.jpg")
171
+ with open(temp_file_path, "wb") as f:
172
+ f.write(image_file.getbuffer())
173
+
174
+ # Process the image using its file path
175
+ processed_image = process_image(temp_file_path)
176
+
177
+ if processed_image is not None:
178
+ # Resize the image to 224x224
179
+ processed_image = cv2.resize(processed_image, (224, 224))
180
+ height, width, _ = processed_image.shape
181
+
182
+ # Initialize the YOLO model
183
+ model = YOLO("standard.pt") # Replace with your trained model weights file
184
+
185
+ # Run detection
186
+ results = model.predict(source=processed_image, save=False, save_txt=False, conf=0.6)
187
+
188
+ # Initialize the board for FEN (empty rows represented by "8")
189
+ board = [["8"] * 8 for _ in range(8)]
190
+
191
+ # Extract predictions and map to FEN board
192
+ for result in results[0].boxes:
193
+ x1, y1, x2, y2 = result.xyxy[0].tolist()
194
+ class_id = int(result.cls[0])
195
+ class_name = model.names[class_id]
196
+
197
+ # Convert class_name to FEN notation
198
+ fen_piece = FEN_MAPPING.get(class_name, None)
199
+ if not fen_piece:
200
+ continue
201
+
202
+ # Calculate the center of the bounding box
203
+ center_x = (x1 + x2) / 2
204
+ center_y = (y1 + y2) / 2
205
+
206
+ # Convert to integer pixel coordinates
207
+ pixel_x = int(center_x)
208
+ pixel_y = int(height - center_y) # Flip Y-axis for generic coordinate system
209
+
210
+ # Get grid coordinate
211
+ grid_position = get_grid_coordinate(pixel_x, pixel_y)
212
+
213
+ if grid_position != "Pixel outside grid bounds":
214
+ file = ord(grid_position[0]) - ord('a') # Column index (0-7)
215
+ rank = int(grid_position[1]) - 1 # Row index (0-7)
216
+
217
+ # Place the piece on the board
218
+ board[7 - rank][file] = fen_piece # Flip rank index for FEN
219
+
220
+ # Generate the FEN string
221
+ fen_rows = []
222
+ for row in board:
223
+ fen_row = ""
224
+ empty_count = 0
225
+ for cell in row:
226
+ if cell == "8":
227
+ empty_count += 1
228
+ else:
229
+ if empty_count > 0:
230
+ fen_row += str(empty_count)
231
+ empty_count = 0
232
+ fen_row += cell
233
+ if empty_count > 0:
234
+ fen_row += str(empty_count)
235
+ fen_rows.append(fen_row)
236
+
237
+ position_fen = "/".join(fen_rows)
238
+
239
+ # Ask the user for the next move side
240
+ move_side = st.selectbox("Select the side to move:", ["w (White)", "b (Black)"])
241
+ move_side = "w" if move_side.startswith("w") else "b"
242
+
243
+ # Append the full FEN string continuation
244
+ fen_notation = f"{position_fen} {move_side} - - 0 0"
245
+
246
+ st.subheader("Generated FEN Notation:")
247
+ st.code(fen_notation)
248
+
249
+ # Initialize the Stockfish engine
250
+ stockfish = Stockfish(
251
+ path=r"stockfish-windows-x86-64-avx2.exe", # Replace with your Stockfish path"
252
+ depth=15,
253
+ parameters={"Threads": 2, "Minimum Thinking Time": 30}
254
+ )
255
+
256
+ # Predict the next move
257
+ next_move = predict_next_move(fen_notation, stockfish)
258
+ st.subheader("Stockfish Recommended Move:")
259
+ st.write(next_move)
260
+
261
+ else:
262
+ st.error("Failed to process the image. Please try again.")
263
+
264
+ if __name__ == "__main__":
265
+ main()
segmentation.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:912bbbde63f435106d57c7416c11a49eb3e9cb93dfe71cb6f9bfaafc1a4e3683
3
+ size 6781485
standard.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2c19a7f75312af21e9e514f008a05da5ff5624590cc5a8997c977a16d2ac459
3
+ size 114375506
stockfish-windows-x86-64-avx2.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b24c22f7894fa13ab27e32a29763055d0867dfb123d8763579dea5b7a91f419
3
+ size 79811584