HaileyStorm commited on
Commit
5122a87
1 Parent(s): 03a4fe1

Update filter_lichess_multi.py

Browse files
Files changed (1) hide show
  1. filter_lichess_multi.py +31 -26
filter_lichess_multi.py CHANGED
@@ -3,37 +3,34 @@ import chess.pgn
3
  import csv
4
  import os
5
  import threading
 
6
 
7
  start_at = 0
8
  total_games = 92055571
9
  num_threads = 8
10
 
11
- def process_pgn_chunk(input_file, output_file, start_index, end_index):
12
- with open(input_file, 'r') as pgn_file, open(output_file, 'a', newline='') as csv_file:
13
  csv_writer = csv.writer(csv_file)
14
 
15
- file_size = os.stat(pgn_file.fileno()).st_size
16
- pgn_file.seek(int(file_size * (start_index / total_games)))
17
 
18
  games_seen = 0
19
  games_added = 0
20
- while True:
21
- game = chess.pgn.read_game(pgn_file)
22
- if game is None or games_seen >= end_index - start_index:
23
- break
24
  games_seen += 1
25
 
26
  # Filter games based on the specified criteria
27
  if (
28
- game.headers['Result'] == '1-0' and
29
- 'Rated' in game.headers['Event'] and
30
- 1500 < int(game.headers['WhiteElo']) < 2400 and
31
- 1400 < int(game.headers['BlackElo']) < 2800
32
  ):
33
  board = chess.Board()
34
  moves = []
35
  move_number = 1
36
- for move in game.mainline_moves():
37
  if board.turn == chess.WHITE:
38
  moves.append(f"{move_number}.")
39
  move_number += 1
@@ -46,28 +43,36 @@ def process_pgn_chunk(input_file, output_file, start_index, end_index):
46
  csv_writer.writerow([transcript.rstrip()])
47
  games_added += 1
48
  if games_added % 100 == 0:
49
- print(f"Thread {threading.current_thread().name} - Added {games_added} of {games_seen} games.") # {(games_seen+start_index)/float(total_games):.2%} complete.")
 
 
50
 
51
  def process_pgn_file(input_file, output_file):
52
  with open(output_file, 'w', newline='') as csv_file:
53
  csv_writer = csv.writer(csv_file)
54
  csv_writer.writerow(['transcript'])
55
 
56
- chunk_size = (total_games - start_at) // num_threads
57
- threads = []
58
- for i in range(num_threads):
59
- start_index = start_at + i * chunk_size
60
- end_index = start_at + (i + 1) * chunk_size
61
- if i == num_threads - 1:
62
- end_index = total_games
63
- thread = threading.Thread(target=process_pgn_chunk, args=(input_file, f"{output_file[:-4]}_{i}.csv", start_index, end_index))
64
- threads.append(thread)
65
- thread.start()
 
66
 
67
- for thread in threads:
68
- thread.join()
 
 
69
 
 
 
70
 
 
71
  input_file = './chess-mamba-vs-xformer/lichess_db_standard_rated_2022-07.pgn'
72
  output_file = './chess-mamba-vs-xformer/lichess_transcripts_phase2_stable.csv'
73
  process_pgn_file(input_file, output_file)
 
3
  import csv
4
  import os
5
  import threading
6
+ import mmap
7
 
8
  start_at = 0
9
  total_games = 92055571
10
  num_threads = 8
11
 
12
+ def process_pgn_chunk(pgn_data, output_file, start_index, end_index):
13
+ with open(output_file, 'a', newline='') as csv_file:
14
  csv_writer = csv.writer(csv_file)
15
 
16
+ pgn = chess.pgn.read_game(chess.pgn.StringIO(pgn_data.decode('utf-8')))
 
17
 
18
  games_seen = 0
19
  games_added = 0
20
+ while pgn is not None and games_seen < end_index - start_index:
 
 
 
21
  games_seen += 1
22
 
23
  # Filter games based on the specified criteria
24
  if (
25
+ pgn.headers['Result'] == '1-0' and
26
+ 'Rated' in pgn.headers['Event'] and
27
+ 1500 < int(pgn.headers['WhiteElo']) < 2400 and
28
+ 1400 < int(pgn.headers['BlackElo']) < 2800
29
  ):
30
  board = chess.Board()
31
  moves = []
32
  move_number = 1
33
+ for move in pgn.mainline_moves():
34
  if board.turn == chess.WHITE:
35
  moves.append(f"{move_number}.")
36
  move_number += 1
 
43
  csv_writer.writerow([transcript.rstrip()])
44
  games_added += 1
45
  if games_added % 100 == 0:
46
+ print(f"Thread {threading.current_thread().name} - Added {games_added} of {games_seen} games. {(games_seen+start_index)/float(total_games):.2%} complete.")
47
+
48
+ pgn = chess.pgn.read_game(chess.pgn.StringIO(pgn_data.decode('utf-8')))
49
 
50
  def process_pgn_file(input_file, output_file):
51
  with open(output_file, 'w', newline='') as csv_file:
52
  csv_writer = csv.writer(csv_file)
53
  csv_writer.writerow(['transcript'])
54
 
55
+ file_size = os.path.getsize(input_file)
56
+ chunk_size = (file_size - start_at) // num_threads
57
+
58
+ with open(input_file, 'rb') as pgn_file:
59
+ with mmap.mmap(pgn_file.fileno(), 0, access=mmap.ACCESS_READ) as pgn_mmap:
60
+ threads = []
61
+ for i in range(num_threads):
62
+ start_index = start_at + i * chunk_size
63
+ end_index = start_at + (i + 1) * chunk_size
64
+ if i == num_threads - 1:
65
+ end_index = file_size
66
 
67
+ pgn_chunk = pgn_mmap[start_index:end_index]
68
+ thread = threading.Thread(target=process_pgn_chunk, args=(pgn_chunk, f"{output_file[:-4]}_{i}.csv", start_index, end_index))
69
+ threads.append(thread)
70
+ thread.start()
71
 
72
+ for thread in threads:
73
+ thread.join()
74
 
75
+ # Usage example
76
  input_file = './chess-mamba-vs-xformer/lichess_db_standard_rated_2022-07.pgn'
77
  output_file = './chess-mamba-vs-xformer/lichess_transcripts_phase2_stable.csv'
78
  process_pgn_file(input_file, output_file)