Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -67,112 +67,21 @@ def predict_next_move(fen, stockfish):
|
|
67 |
return "Invalid FEN notation!"
|
68 |
|
69 |
best_move = stockfish.get_best_move()
|
70 |
-
|
71 |
-
|
72 |
|
73 |
-
def process_image(image_path):
|
74 |
-
# Ensure output directory exists
|
75 |
-
if not os.path.exists('output'):
|
76 |
-
os.makedirs('output')
|
77 |
-
|
78 |
-
# Load the segmentation model
|
79 |
-
segmentation_model = YOLO("segmentation.pt")
|
80 |
-
|
81 |
-
# Run inference to get segmentation results
|
82 |
-
results = segmentation_model.predict(
|
83 |
-
source=image_path,
|
84 |
-
conf=0.8 # Confidence threshold
|
85 |
-
)
|
86 |
-
|
87 |
-
# Initialize variables for the segmented mask and bounding box
|
88 |
-
segmentation_mask = None
|
89 |
-
bbox = None
|
90 |
-
|
91 |
-
for result in results:
|
92 |
-
if result.boxes.conf[0] >= 0.8: # Filter results by confidence
|
93 |
-
segmentation_mask = result.masks.data.cpu().numpy().astype(np.uint8)[0]
|
94 |
-
bbox = result.boxes.xyxy[0].cpu().numpy() # Get the bounding box coordinates
|
95 |
-
break
|
96 |
-
|
97 |
-
if segmentation_mask is None:
|
98 |
-
print("No segmentation mask with confidence above 0.8 found.")
|
99 |
-
return None
|
100 |
-
|
101 |
-
# Load the image
|
102 |
-
image = cv2.imread(image_path)
|
103 |
-
|
104 |
-
# Convert the image to RGB format
|
105 |
-
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
106 |
-
|
107 |
-
# Resize segmentation mask to match the input image dimensions
|
108 |
-
segmentation_mask_resized = cv2.resize(segmentation_mask, (image.shape[1], image.shape[0]))
|
109 |
-
|
110 |
-
# Extract bounding box coordinates
|
111 |
-
if bbox is not None:
|
112 |
-
x1, y1, x2, y2 = bbox
|
113 |
-
# Crop the segmented region based on the bounding box
|
114 |
-
cropped_segment = image[int(y1):int(y2), int(x1):int(x2)]
|
115 |
-
|
116 |
-
# Convert the cropped segment to RGB
|
117 |
-
cropped_segment_rgb = cv2.cvtColor(cropped_segment, cv2.COLOR_BGR2RGB)
|
118 |
-
|
119 |
-
# Save the cropped segmented image
|
120 |
-
cropped_image_path = 'output/cropped_segment.jpg'
|
121 |
-
cv2.imwrite(cropped_image_path, cropped_segment)
|
122 |
-
print(f"Cropped segmented image saved to {cropped_image_path}")
|
123 |
-
|
124 |
-
# Display the image in Streamlit
|
125 |
-
st.image(cropped_segment_rgb, caption="Uploaded Image (Cropped)", use_column_width=True)
|
126 |
-
|
127 |
-
# Return the cropped RGB image
|
128 |
-
return cropped_segment_rgb
|
129 |
-
|
130 |
-
|
131 |
-
def transform_string(input_str):
|
132 |
-
# Remove extra spaces and convert to lowercase
|
133 |
-
input_str = input_str.strip().lower()
|
134 |
-
|
135 |
-
# Check if input is valid
|
136 |
-
if len(input_str) != 4 or not input_str[0].isalpha() or not input_str[1].isdigit() or \
|
137 |
-
not input_str[2].isalpha() or not input_str[3].isdigit():
|
138 |
-
return "Invalid input"
|
139 |
-
|
140 |
-
# Define mappings
|
141 |
-
letter_mapping = {
|
142 |
-
'a': 'h', 'b': 'g', 'c': 'f', 'd': 'e',
|
143 |
-
'e': 'd', 'f': 'c', 'g': 'b', 'h': 'a'
|
144 |
-
}
|
145 |
-
number_mapping = {
|
146 |
-
'1': '8', '2': '7', '3': '6', '4': '5',
|
147 |
-
'5': '4', '6': '3', '7': '2', '8': '1'
|
148 |
-
}
|
149 |
-
|
150 |
-
# Transform string
|
151 |
-
result = ""
|
152 |
-
for i, char in enumerate(input_str):
|
153 |
-
if i % 2 == 0: # Letters
|
154 |
-
result += letter_mapping.get(char, "Invalid")
|
155 |
-
else: # Numbers
|
156 |
-
result += number_mapping.get(char, "Invalid")
|
157 |
-
|
158 |
-
# Check for invalid transformations
|
159 |
-
if "Invalid" in result:
|
160 |
-
return "Invalid input"
|
161 |
-
|
162 |
-
return result
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
# Streamlit app
|
167 |
-
def main():
|
168 |
-
st.title("Chessboard Position Detection and Move Prediction")
|
169 |
-
|
170 |
-
os.chmod("/home/user/app/stockfish-ubuntu-x86-64-sse41-popcnt", 0o755)
|
171 |
|
172 |
-
|
173 |
|
174 |
|
175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
# User uploads an image or captures it from their camera
|
177 |
image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
|
178 |
|
@@ -184,96 +93,84 @@ def main():
|
|
184 |
with open(temp_file_path, "wb") as f:
|
185 |
f.write(image_file.getbuffer())
|
186 |
|
187 |
-
#
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
)
|
269 |
-
|
270 |
-
# Predict the next move
|
271 |
-
next_move = predict_next_move(fen_notation, stockfish)
|
272 |
-
st.subheader("Stockfish Recommended Move:")
|
273 |
-
st.write(next_move)
|
274 |
-
|
275 |
-
else:
|
276 |
-
st.error("Failed to process the image. Please try again.")
|
277 |
-
|
278 |
-
if __name__ == "__main__":
|
279 |
-
main()
|
|
|
67 |
return "Invalid FEN notation!"
|
68 |
|
69 |
best_move = stockfish.get_best_move()
|
70 |
+
return f"The predicted next move is: {best_move}" if best_move else "No valid move found (checkmate/stalemate)."
|
71 |
+
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
|
|
74 |
|
75 |
|
76 |
|
77 |
+
|
78 |
+
|
79 |
+
def main():
|
80 |
+
st.title("Chessboard Position Detection and Move Prediction")
|
81 |
+
|
82 |
+
# Set permissions for the Stockfish engine binary
|
83 |
+
os.chmod("/home/user/app/stockfish-ubuntu-x86-64-sse41-popcnt", 0o755)
|
84 |
+
|
85 |
# User uploads an image or captures it from their camera
|
86 |
image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
|
87 |
|
|
|
93 |
with open(temp_file_path, "wb") as f:
|
94 |
f.write(image_file.getbuffer())
|
95 |
|
96 |
+
# Load the YOLO models
|
97 |
+
model = YOLO("fine_tuned_on_all_data.pt") # Replace with your trained model weights file
|
98 |
+
seg_model = YOLO("segmentation.pt")
|
99 |
+
|
100 |
+
# Load and process the image
|
101 |
+
img = cv2.imread(temp_file_path)
|
102 |
+
r = seg_model.predict(source=temp_file_path)
|
103 |
+
xyxy = r[0].boxes.xyxy
|
104 |
+
x_min, y_min, x_max, y_max = map(int, xyxy[0])
|
105 |
+
new_img = img[y_min:y_max, x_min:x_max]
|
106 |
+
|
107 |
+
# Resize the image to 224x224
|
108 |
+
image = cv2.resize(new_img, (224, 224))
|
109 |
+
height, width, _ = image.shape
|
110 |
+
|
111 |
+
# Get user input for perspective
|
112 |
+
p = st.radio("Select perspective:", ["b (Black)", "w (White)"])
|
113 |
+
p = p[0].lower()
|
114 |
+
|
115 |
+
# Initialize the board for FEN (empty rows represented by "8")
|
116 |
+
board = [["8"] * 8 for _ in range(8)]
|
117 |
+
|
118 |
+
# Run detection
|
119 |
+
results = model.predict(source=image, save=False, save_txt=False, conf=0.6)
|
120 |
+
|
121 |
+
# Extract predictions and map to FEN board
|
122 |
+
for result in results[0].boxes:
|
123 |
+
x1, y1, x2, y2 = result.xyxy[0].tolist()
|
124 |
+
class_id = int(result.cls[0])
|
125 |
+
class_name = model.names[class_id]
|
126 |
+
|
127 |
+
fen_piece = FEN_MAPPING.get(class_name, None)
|
128 |
+
if not fen_piece:
|
129 |
+
continue
|
130 |
+
|
131 |
+
center_x = (x1 + x2) / 2
|
132 |
+
center_y = (y1 + y2) / 2
|
133 |
+
pixel_x = int(center_x)
|
134 |
+
pixel_y = int(height - center_y)
|
135 |
+
|
136 |
+
grid_position = get_grid_coordinate(pixel_x, pixel_y, p)
|
137 |
+
if grid_position != "Pixel outside grid bounds":
|
138 |
+
file = ord(grid_position[0]) - ord('a')
|
139 |
+
rank = int(grid_position[1]) - 1
|
140 |
+
board[rank][file] = fen_piece
|
141 |
+
|
142 |
+
# Generate the FEN string
|
143 |
+
fen_rows = []
|
144 |
+
for row in board:
|
145 |
+
fen_row = ""
|
146 |
+
empty_count = 0
|
147 |
+
for cell in row:
|
148 |
+
if cell == "8":
|
149 |
+
empty_count += 1
|
150 |
+
else:
|
151 |
+
if empty_count > 0:
|
152 |
+
fen_row += str(empty_count)
|
153 |
+
empty_count = 0
|
154 |
+
fen_row += cell
|
155 |
+
if empty_count > 0:
|
156 |
+
fen_row += str(empty_count)
|
157 |
+
fen_rows.append(fen_row)
|
158 |
+
|
159 |
+
position_fen = "/".join(fen_rows)
|
160 |
+
move_side = st.radio("Select the side to move:", ["w (White)", "b (Black)"])[0].lower()
|
161 |
+
fen_notation = f"{position_fen} {move_side} - - 0 0"
|
162 |
+
st.subheader("Generated FEN Notation:")
|
163 |
+
st.code(fen_notation)
|
164 |
+
|
165 |
+
# Initialize the Stockfish engine
|
166 |
+
stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
|
167 |
+
stockfish = Stockfish(
|
168 |
+
path=stockfish_path,
|
169 |
+
depth=15,
|
170 |
+
parameters={"Threads": 2, "Minimum Thinking Time": 30}
|
171 |
+
)
|
172 |
+
|
173 |
+
# Predict the next move
|
174 |
+
next_move = predict_next_move(fen_notation, stockfish)
|
175 |
+
st.subheader("Stockfish Recommended Move:")
|
176 |
+
st.write(next_move)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|