siddharth060104 committed on
Commit
f296950
·
verified ·
1 Parent(s): 7386c79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -194
app.py CHANGED
@@ -67,112 +67,21 @@ def predict_next_move(fen, stockfish):
67
  return "Invalid FEN notation!"
68
 
69
  best_move = stockfish.get_best_move()
70
- ans = transform_string(best_move)
71
- return f"The predicted next move is: {ans}" if best_move else "No valid move found (checkmate/stalemate)."
72
 
73
- def process_image(image_path):
74
- # Ensure output directory exists
75
- if not os.path.exists('output'):
76
- os.makedirs('output')
77
-
78
- # Load the segmentation model
79
- segmentation_model = YOLO("segmentation.pt")
80
-
81
- # Run inference to get segmentation results
82
- results = segmentation_model.predict(
83
- source=image_path,
84
- conf=0.8 # Confidence threshold
85
- )
86
-
87
- # Initialize variables for the segmented mask and bounding box
88
- segmentation_mask = None
89
- bbox = None
90
-
91
- for result in results:
92
- if result.boxes.conf[0] >= 0.8: # Filter results by confidence
93
- segmentation_mask = result.masks.data.cpu().numpy().astype(np.uint8)[0]
94
- bbox = result.boxes.xyxy[0].cpu().numpy() # Get the bounding box coordinates
95
- break
96
-
97
- if segmentation_mask is None:
98
- print("No segmentation mask with confidence above 0.8 found.")
99
- return None
100
-
101
- # Load the image
102
- image = cv2.imread(image_path)
103
-
104
- # Convert the image to RGB format
105
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
106
-
107
- # Resize segmentation mask to match the input image dimensions
108
- segmentation_mask_resized = cv2.resize(segmentation_mask, (image.shape[1], image.shape[0]))
109
-
110
- # Extract bounding box coordinates
111
- if bbox is not None:
112
- x1, y1, x2, y2 = bbox
113
- # Crop the segmented region based on the bounding box
114
- cropped_segment = image[int(y1):int(y2), int(x1):int(x2)]
115
-
116
- # Convert the cropped segment to RGB
117
- cropped_segment_rgb = cv2.cvtColor(cropped_segment, cv2.COLOR_BGR2RGB)
118
-
119
- # Save the cropped segmented image
120
- cropped_image_path = 'output/cropped_segment.jpg'
121
- cv2.imwrite(cropped_image_path, cropped_segment)
122
- print(f"Cropped segmented image saved to {cropped_image_path}")
123
-
124
- # Display the image in Streamlit
125
- st.image(cropped_segment_rgb, caption="Uploaded Image (Cropped)", use_column_width=True)
126
-
127
- # Return the cropped RGB image
128
- return cropped_segment_rgb
129
-
130
-
131
- def transform_string(input_str):
132
- # Remove extra spaces and convert to lowercase
133
- input_str = input_str.strip().lower()
134
-
135
- # Check if input is valid
136
- if len(input_str) != 4 or not input_str[0].isalpha() or not input_str[1].isdigit() or \
137
- not input_str[2].isalpha() or not input_str[3].isdigit():
138
- return "Invalid input"
139
-
140
- # Define mappings
141
- letter_mapping = {
142
- 'a': 'h', 'b': 'g', 'c': 'f', 'd': 'e',
143
- 'e': 'd', 'f': 'c', 'g': 'b', 'h': 'a'
144
- }
145
- number_mapping = {
146
- '1': '8', '2': '7', '3': '6', '4': '5',
147
- '5': '4', '6': '3', '7': '2', '8': '1'
148
- }
149
-
150
- # Transform string
151
- result = ""
152
- for i, char in enumerate(input_str):
153
- if i % 2 == 0: # Letters
154
- result += letter_mapping.get(char, "Invalid")
155
- else: # Numbers
156
- result += number_mapping.get(char, "Invalid")
157
-
158
- # Check for invalid transformations
159
- if "Invalid" in result:
160
- return "Invalid input"
161
-
162
- return result
163
-
164
-
165
-
166
- # Streamlit app
167
- def main():
168
- st.title("Chessboard Position Detection and Move Prediction")
169
-
170
- os.chmod("/home/user/app/stockfish-ubuntu-x86-64-sse41-popcnt", 0o755)
171
 
172
-
173
 
174
 
175
 
 
 
 
 
 
 
 
 
176
  # User uploads an image or captures it from their camera
177
  image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
178
 
@@ -184,96 +93,84 @@ def main():
184
  with open(temp_file_path, "wb") as f:
185
  f.write(image_file.getbuffer())
186
 
187
- # Process the image using its file path
188
- processed_image = process_image(temp_file_path)
189
-
190
- if processed_image is not None:
191
- # Resize the image to 224x224
192
- processed_image = cv2.resize(processed_image, (224, 224))
193
- height, width, _ = processed_image.shape
194
-
195
- # Initialize the YOLO model
196
- model = YOLO("fine_tuned_on_all_data.pt") # Replace with your trained model weights file
197
-
198
- # Run detection
199
- results = model.predict(source=processed_image, save=False, save_txt=False, conf=0.5)
200
-
201
- # Initialize the board for FEN (empty rows represented by "8")
202
- board = [["8"] * 8 for _ in range(8)]
203
-
204
- # Extract predictions and map to FEN board
205
- for result in results[0].boxes:
206
- x1, y1, x2, y2 = result.xyxy[0].tolist()
207
- class_id = int(result.cls[0])
208
- class_name = model.names[class_id]
209
-
210
- # Convert class_name to FEN notation
211
- fen_piece = FEN_MAPPING.get(class_name, None)
212
- if not fen_piece:
213
- continue
214
-
215
- # Calculate the center of the bounding box
216
- center_x = (x1 + x2) / 2
217
- center_y = (y1 + y2) / 2
218
-
219
- # Convert to integer pixel coordinates
220
- pixel_x = int(center_x)
221
- pixel_y = int(height - center_y) # Flip Y-axis for generic coordinate system
222
-
223
- # Get grid coordinate
224
- grid_position = get_grid_coordinate(pixel_x, pixel_y)
225
-
226
- if grid_position != "Pixel outside grid bounds":
227
- file = ord(grid_position[0]) - ord('a') # Column index (0-7)
228
- rank = int(grid_position[1]) - 1 # Row index (0-7)
229
-
230
- # Place the piece on the board
231
- board[7 - rank][file] = fen_piece # Flip rank index for FEN
232
-
233
- # Generate the FEN string
234
- fen_rows = []
235
- for row in board:
236
- fen_row = ""
237
- empty_count = 0
238
- for cell in row:
239
- if cell == "8":
240
- empty_count += 1
241
- else:
242
- if empty_count > 0:
243
- fen_row += str(empty_count)
244
- empty_count = 0
245
- fen_row += cell
246
- if empty_count > 0:
247
- fen_row += str(empty_count)
248
- fen_rows.append(fen_row)
249
-
250
- position_fen = "/".join(fen_rows)
251
-
252
- # Ask the user for the next move side
253
- move_side = st.selectbox("Select the side to move:", ["w (White)", "b (Black)"])
254
- move_side = "w" if move_side.startswith("w") else "b"
255
-
256
- # Append the full FEN string continuation
257
- fen_notation = f"{position_fen} {move_side} - - 0 0"
258
-
259
- st.subheader("Generated FEN Notation:")
260
- st.code(fen_notation)
261
-
262
- # Initialize the Stockfish engine
263
- stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
264
- stockfish = Stockfish(
265
- path=stockfish_path,
266
- depth=15,
267
- parameters={"Threads": 2, "Minimum Thinking Time": 30}
268
- )
269
-
270
- # Predict the next move
271
- next_move = predict_next_move(fen_notation, stockfish)
272
- st.subheader("Stockfish Recommended Move:")
273
- st.write(next_move)
274
-
275
- else:
276
- st.error("Failed to process the image. Please try again.")
277
-
278
- if __name__ == "__main__":
279
- main()
 
67
  return "Invalid FEN notation!"
68
 
69
  best_move = stockfish.get_best_move()
70
+ return f"The predicted next move is: {best_move}" if best_move else "No valid move found (checkmate/stalemate)."
71
+
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
 
74
 
75
 
76
 
77
+
78
+
79
+ def main():
80
+ st.title("Chessboard Position Detection and Move Prediction")
81
+
82
+ # Set permissions for the Stockfish engine binary
83
+ os.chmod("/home/user/app/stockfish-ubuntu-x86-64-sse41-popcnt", 0o755)
84
+
85
  # User uploads an image or captures it from their camera
86
  image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
87
 
 
93
  with open(temp_file_path, "wb") as f:
94
  f.write(image_file.getbuffer())
95
 
96
+ # Load the YOLO models
97
+ model = YOLO("fine_tuned_on_all_data.pt") # Replace with your trained model weights file
98
+ seg_model = YOLO("segmentation.pt")
99
+
100
+ # Load and process the image
101
+ img = cv2.imread(temp_file_path)
102
+ r = seg_model.predict(source=temp_file_path)
103
+ xyxy = r[0].boxes.xyxy
104
+ x_min, y_min, x_max, y_max = map(int, xyxy[0])
105
+ new_img = img[y_min:y_max, x_min:x_max]
106
+
107
+ # Resize the image to 224x224
108
+ image = cv2.resize(new_img, (224, 224))
109
+ height, width, _ = image.shape
110
+
111
+ # Get user input for perspective
112
+ p = st.radio("Select perspective:", ["b (Black)", "w (White)"])
113
+ p = p[0].lower()
114
+
115
+ # Initialize the board for FEN (empty rows represented by "8")
116
+ board = [["8"] * 8 for _ in range(8)]
117
+
118
+ # Run detection
119
+ results = model.predict(source=image, save=False, save_txt=False, conf=0.6)
120
+
121
+ # Extract predictions and map to FEN board
122
+ for result in results[0].boxes:
123
+ x1, y1, x2, y2 = result.xyxy[0].tolist()
124
+ class_id = int(result.cls[0])
125
+ class_name = model.names[class_id]
126
+
127
+ fen_piece = FEN_MAPPING.get(class_name, None)
128
+ if not fen_piece:
129
+ continue
130
+
131
+ center_x = (x1 + x2) / 2
132
+ center_y = (y1 + y2) / 2
133
+ pixel_x = int(center_x)
134
+ pixel_y = int(height - center_y)
135
+
136
+ grid_position = get_grid_coordinate(pixel_x, pixel_y, p)
137
+ if grid_position != "Pixel outside grid bounds":
138
+ file = ord(grid_position[0]) - ord('a')
139
+ rank = int(grid_position[1]) - 1
140
+ board[rank][file] = fen_piece
141
+
142
+ # Generate the FEN string
143
+ fen_rows = []
144
+ for row in board:
145
+ fen_row = ""
146
+ empty_count = 0
147
+ for cell in row:
148
+ if cell == "8":
149
+ empty_count += 1
150
+ else:
151
+ if empty_count > 0:
152
+ fen_row += str(empty_count)
153
+ empty_count = 0
154
+ fen_row += cell
155
+ if empty_count > 0:
156
+ fen_row += str(empty_count)
157
+ fen_rows.append(fen_row)
158
+
159
+ position_fen = "/".join(fen_rows)
160
+ move_side = st.radio("Select the side to move:", ["w (White)", "b (Black)"])[0].lower()
161
+ fen_notation = f"{position_fen} {move_side} - - 0 0"
162
+ st.subheader("Generated FEN Notation:")
163
+ st.code(fen_notation)
164
+
165
+ # Initialize the Stockfish engine
166
+ stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
167
+ stockfish = Stockfish(
168
+ path=stockfish_path,
169
+ depth=15,
170
+ parameters={"Threads": 2, "Minimum Thinking Time": 30}
171
+ )
172
+
173
+ # Predict the next move
174
+ next_move = predict_next_move(fen_notation, stockfish)
175
+ st.subheader("Stockfish Recommended Move:")
176
+ st.write(next_move)