diff --git a/fairseq/examples/__pycache__/__init__.cpython-310.pyc b/fairseq/examples/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfbd835ebd1ae020bc3572debcde47e6dd12aae5 Binary files /dev/null and b/fairseq/examples/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/examples/wav2vec/unsupervised/config/timit_matched/test.uid b/fairseq/examples/wav2vec/unsupervised/config/timit_matched/test.uid new file mode 100644 index 0000000000000000000000000000000000000000..401008246a1bc2cbf309d9d0aa56710f0ff643bc --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/config/timit_matched/test.uid @@ -0,0 +1,192 @@ +FDHC0_SI1559 +FDHC0_SI2189 +FDHC0_SI929 +FDHC0_SX119 +FDHC0_SX209 +FDHC0_SX29 +FDHC0_SX299 +FDHC0_SX389 +FELC0_SI1386 +FELC0_SI2016 +FELC0_SI756 +FELC0_SX126 +FELC0_SX216 +FELC0_SX306 +FELC0_SX36 +FELC0_SX396 +FJLM0_SI1043 +FJLM0_SI1673 +FJLM0_SI2303 +FJLM0_SX143 +FJLM0_SX233 +FJLM0_SX323 +FJLM0_SX413 +FJLM0_SX53 +FMGD0_SI1564 +FMGD0_SI2194 +FMGD0_SI934 +FMGD0_SX124 +FMGD0_SX214 +FMGD0_SX304 +FMGD0_SX34 +FMGD0_SX394 +FMLD0_SI2185 +FMLD0_SI822 +FMLD0_SI925 +FMLD0_SX115 +FMLD0_SX205 +FMLD0_SX25 +FMLD0_SX295 +FMLD0_SX385 +FNLP0_SI1308 +FNLP0_SI1938 +FNLP0_SI678 +FNLP0_SX138 +FNLP0_SX228 +FNLP0_SX318 +FNLP0_SX408 +FNLP0_SX48 +FPAS0_SI1272 +FPAS0_SI2204 +FPAS0_SI944 +FPAS0_SX134 +FPAS0_SX224 +FPAS0_SX314 +FPAS0_SX404 +FPAS0_SX44 +FPKT0_SI1538 +FPKT0_SI2168 +FPKT0_SI908 +FPKT0_SX188 +FPKT0_SX278 +FPKT0_SX368 +FPKT0_SX8 +FPKT0_SX98 +MBPM0_SI1577 +MBPM0_SI1584 +MBPM0_SI947 +MBPM0_SX137 +MBPM0_SX227 +MBPM0_SX317 +MBPM0_SX407 +MBPM0_SX47 +MCMJ0_SI1094 +MCMJ0_SI464 +MCMJ0_SI602 +MCMJ0_SX104 +MCMJ0_SX14 +MCMJ0_SX194 +MCMJ0_SX284 +MCMJ0_SX374 +MDAB0_SI1039 +MDAB0_SI1669 +MDAB0_SI2299 +MDAB0_SX139 +MDAB0_SX229 +MDAB0_SX319 +MDAB0_SX409 +MDAB0_SX49 +MGRT0_SI1450 +MGRT0_SI2080 +MGRT0_SI820 +MGRT0_SX10 +MGRT0_SX100 +MGRT0_SX190 +MGRT0_SX280 +MGRT0_SX370 +MJDH0_SI1354 +MJDH0_SI1984 +MJDH0_SI724 +MJDH0_SX184 +MJDH0_SX274 +MJDH0_SX364 +MJDH0_SX4 +MJDH0_SX94 +MJLN0_SI1449 +MJLN0_SI2079 +MJLN0_SI819 +MJLN0_SX189 +MJLN0_SX279 +MJLN0_SX369 +MJLN0_SX9 +MJLN0_SX99 +MJMP0_SI1535 +MJMP0_SI1791 +MJMP0_SI905 +MJMP0_SX185 +MJMP0_SX275 +MJMP0_SX365 +MJMP0_SX5 +MJMP0_SX95 +MKLT0_SI1213 +MKLT0_SI1843 +MKLT0_SI583 +MKLT0_SX133 +MKLT0_SX223 +MKLT0_SX313 +MKLT0_SX403 +MKLT0_SX43 +MLLL0_SI1363 +MLLL0_SI1993 +MLLL0_SI733 +MLLL0_SX103 +MLLL0_SX13 +MLLL0_SX193 +MLLL0_SX283 +MLLL0_SX373 +MLNT0_SI1574 +MLNT0_SI1902 +MLNT0_SI642 +MLNT0_SX102 +MLNT0_SX12 +MLNT0_SX192 +MLNT0_SX282 +MLNT0_SX372 +MNJM0_SI1580 +MNJM0_SI2210 +MNJM0_SI950 +MNJM0_SX140 +MNJM0_SX230 +MNJM0_SX320 +MNJM0_SX410 +MNJM0_SX50 +MPAM0_SI1189 +MPAM0_SI1819 +MPAM0_SI1961 +MPAM0_SX109 +MPAM0_SX19 +MPAM0_SX199 +MPAM0_SX289 +MPAM0_SX379 +MTAS1_SI1473 +MTAS1_SI2098 +MTAS1_SI838 +MTAS1_SX118 +MTAS1_SX208 +MTAS1_SX28 +MTAS1_SX298 +MTAS1_SX388 +MTLS0_SI1370 +MTLS0_SI2000 +MTLS0_SI740 +MTLS0_SX110 +MTLS0_SX20 +MTLS0_SX200 +MTLS0_SX290 +MTLS0_SX380 +MWBT0_SI1553 +MWBT0_SI2183 +MWBT0_SI923 +MWBT0_SX113 +MWBT0_SX203 +MWBT0_SX23 +MWBT0_SX293 +MWBT0_SX383 +MWEW0_SI1361 +MWEW0_SI1991 +MWEW0_SI731 +MWEW0_SX101 +MWEW0_SX11 +MWEW0_SX191 +MWEW0_SX281 +MWEW0_SX371 diff --git a/fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/train_text.uid b/fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/train_text.uid new file mode 100644 index 0000000000000000000000000000000000000000..0e0c2517c9415ce76d5863781f621402cd15b911 --- /dev/null +++ 
b/fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/train_text.uid @@ -0,0 +1,1000 @@ +FAEM0_SI762 +FAEM0_SX42 +FAJW0_SA1 +FAJW0_SX3 +FAJW0_SX93 +FALK0_SX186 +FALK0_SX6 +FALR0_SI1325 +FBAS0_SA1 +FBAS0_SX217 +FBCG1_SA1 +FBCG1_SX172 +FBCG1_SX442 +FBCH0_SX236 +FBCH0_SX416 +FBLV0_SA1 +FBLV0_SI1058 +FBLV0_SX338 +FBLV0_SX68 +FBMH0_SA1 +FBMJ0_SI815 +FCAG0_SA1 +FCAG0_SX153 +FCAG0_SX243 +FCAJ0_SI1479 +FCAJ0_SX309 +FCDR1_SX106 +FCDR1_SX196 +FCEG0_SA2 +FCJF0_SA1 +FCJF0_SX127 +FCJS0_SI1607 +FCJS0_SI2237 +FCJS0_SX257 +FCKE0_SA2 +FCKE0_SX121 +FCLT0_SI2068 +FCLT0_SX448 +FCLT0_SX88 +FCMG0_SA2 +FCMG0_SI1872 +FCMG0_SX72 +FCMM0_SA1 +FCMM0_SA2 +FCMM0_SX183 +FCRZ0_SI2053 +FCRZ0_SX433 +FCYL0_SA1 +FCYL0_SX37 +FDAS1_SI2091 +FDAS1_SX201 +FDAS1_SX381 +FDAW0_SI1406 +FDFB0_SA1 +FDFB0_SA2 +FDFB0_SI2010 +FDFB0_SX58 +FDJH0_SX305 +FDML0_SA2 +FDML0_SX159 +FDML0_SX249 +FDML0_SX429 +FDMY0_SA2 +FDMY0_SX27 +FDNC0_SX198 +FDNC0_SX288 +FDTD0_SX211 +FDXW0_SA1 +FDXW0_SX251 +FDXW0_SX341 +FDXW0_SX71 +FEAC0_SX165 +FEAC0_SX75 +FEAR0_SI622 +FECD0_SX68 +FEEH0_SA1 +FEEH0_SI1742 +FEEH0_SI471 +FEEH0_SX122 +FEME0_SA1 +FEME0_SX155 +FEME0_SX65 +FETB0_SA1 +FETB0_SI1148 +FETB0_SX158 +FEXM0_SI1101 +FGCS0_SX136 +FGCS0_SX226 +FGCS0_SX316 +FGCS0_SX406 +FGDP0_SA1 +FGMB0_SI1775 +FGMB0_SX245 +FHLM0_SX390 +FHXS0_SA2 +FHXS0_SX445 +FJDM2_SA1 +FJDM2_SX232 +FJDM2_SX52 +FJHK0_SX302 +FJKL0_SX212 +FJKL0_SX392 +FJLG0_SI2306 +FJLR0_SA1 +FJRP1_SI2062 +FJRP1_SX82 +FJSK0_SA1 +FJSP0_SX264 +FJSP0_SX354 +FJSP0_SX444 +FJWB1_SA1 +FJWB1_SX345 +FJWB1_SX435 +FJXM0_SA1 +FJXM0_SI581 +FJXM0_SX401 +FJXP0_SA1 +FJXP0_SI1122 +FJXP0_SX132 +FKAA0_SX128 +FKAA0_SX398 +FKDE0_SA1 +FKDE0_SX151 +FKDE0_SX241 +FKDE0_SX421 +FKDE0_SX61 +FKDW0_SX397 +FKFB0_SA2 +FKFB0_SX348 +FKFB0_SX78 +FKKH0_SA1 +FKKH0_SA2 +FKKH0_SX120 +FKKH0_SX390 +FKLC0_SX355 +FKLC1_SI2308 +FKLC1_SX238 +FKLC1_SX328 +FKLC1_SX418 +FKLH0_SA2 +FKLH0_SX177 +FKSR0_SA1 +FKSR0_SA2 +FKSR0_SI1747 +FKSR0_SI487 +FKSR0_SX217 +FLAC0_SX451 +FLAG0_SA2 +FLAG0_SX114 +FLAG0_SX204 +FLAG0_SX24 +FLAG0_SX384 +FLEH0_SI1681 +FLEH0_SI2311 +FLEH0_SX331 +FLET0_SA1 +FLHD0_SI1827 +FLHD0_SX354 +FLJA0_SA1 +FLJA0_SI2338 +FLJD0_SI886 +FLJD0_SX76 +FLJG0_SA2 +FLKM0_SA2 +FLKM0_SI686 +FLKM0_SX260 +FLKM0_SX80 +FLMA0_SA1 +FLMA0_SI613 +FLMA0_SX433 +FLMA0_SX73 +FLMC0_SX22 +FLMK0_SI1035 +FLMK0_SX315 +FLMK0_SX405 +FLOD0_SI1917 +FLOD0_SX117 +FLOD0_SX171 +FLOD0_SX297 +FLTM0_SA1 +FLTM0_SI1070 +FLTM0_SI2330 +FMAH1_SA2 +FMAH1_SX159 +FMBG0_SA2 +FMBG0_SI2264 +FMEM0_SI747 +FMEM0_SX387 +FMJB0_SI547 +FMJB0_SX97 +FMJF0_SA2 +FMJU0_SX309 +FMJU0_SX399 +FMKC0_SI1702 +FMKC0_SX442 +FMKC0_SX82 +FMKF0_SX186 +FMPG0_SA2 +FNKL0_SI1522 +FNTB0_SI1203 +FNTB0_SI573 +FNTB0_SX303 +FPAB1_SI1471 +FPAB1_SX211 +FPAC0_SA2 +FPAD0_SA2 +FPAD0_SX356 +FPAD0_SX86 +FPAF0_SA2 +FPAF0_SX154 +FPAZ0_SA1 +FPAZ0_SA2 +FPAZ0_SX243 +FPJF0_SA1 +FPJF0_SX146 +FPJF0_SX56 +FPLS0_SI1590 +FPLS0_SX330 +FPMY0_SA1 +FPMY0_SX343 +FREH0_SA1 +FREH0_SA2 +FREH0_SX415 +FRJB0_SX347 +FRLL0_SX434 +FSAG0_SA1 +FSAG0_SX243 +FSAH0_SA1 +FSAH0_SA2 +FSAH0_SX164 +FSAH0_SX434 +FSBK0_SA2 +FSBK0_SI1069 +FSBK0_SX169 +FSCN0_SA2 +FSCN0_SI626 +FSCN0_SX266 +FSCN0_SX446 +FSCN0_SX86 +FSDC0_SA2 +FSDC0_SX142 +FSDC0_SX322 +FSDC0_SX52 +FSDJ0_SI485 +FSDJ0_SX215 +FSDJ0_SX305 +FSDJ0_SX395 +FSGF0_SX117 +FSJG0_SX130 +FSJK1_SA2 +FSJK1_SX125 +FSJK1_SX35 +FSJS0_SX181 +FSJW0_SI1963 +FSJW0_SX433 +FSKC0_SI1416 +FSKC0_SI786 +FSKC0_SX246 +FSKL0_SI1529 +FSKL0_SX449 +FSKP0_SA2 +FSLS0_SX156 +FSLS0_SX426 +FSMA0_SA2 +FSMA0_SX181 +FSMM0_SX144 +FSMM0_SX234 +FSMS1_SX244 +FSMS1_SX347 +FSPM0_SA2 +FSPM0_SX161 +FSPM0_SX71 +FSRH0_SI1931 +FSRH0_SI671 +FSRH0_SX221 
+FSRH0_SX401 +FTAJ0_SI699 +FTAJ0_SX159 +FTAJ0_SX249 +FTAJ0_SX429 +FTBR0_SX21 +FTBW0_SA1 +FTMG0_SI1532 +FTMG0_SI2162 +FTMG0_SX452 +FVFB0_SA2 +FVFB0_SX132 +FVFB0_SX42 +FVKB0_SA1 +FVMH0_SA2 +FVMH0_SX116 +FVMH0_SX26 +MABC0_SI1620 +MABC0_SI2041 +MABC0_SI781 +MADC0_SX107 +MADC0_SX377 +MADD0_SA2 +MADD0_SI1295 +MADD0_SX178 +MADD0_SX268 +MADD0_SX88 +MAEB0_SX450 +MAEO0_SA1 +MAFM0_SI939 +MAFM0_SX129 +MAFM0_SX309 +MAJP0_SA2 +MAKB0_SI1646 +MAKB0_SX26 +MAKB0_SX386 +MAKR0_SX362 +MAKR0_SX92 +MAPV0_SX213 +MARC0_SA2 +MARC0_SX108 +MARC0_SX18 +MARC0_SX198 +MARW0_SI1906 +MBAR0_SA1 +MBAR0_SX419 +MBAR0_SX59 +MBBR0_SI2315 +MBBR0_SX65 +MBCG0_SA1 +MBCG0_SI486 +MBEF0_SI1281 +MBEF0_SI1911 +MBEF0_SI651 +MBEF0_SX21 +MBEF0_SX381 +MBGT0_SA2 +MBGT0_SX261 +MBGT0_SX351 +MBGT0_SX441 +MBJV0_SA1 +MBJV0_SI617 +MBJV0_SX347 +MBMA0_SI592 +MBMA0_SX232 +MBMA0_SX52 +MBMA1_SI2214 +MBMA1_SX54 +MBML0_SA2 +MBML0_SI1169 +MBML0_SX89 +MBOM0_SA2 +MBOM0_SI2274 +MBOM0_SX294 +MBSB0_SA1 +MBSB0_SX3 +MBTH0_SA2 +MBTH0_SX122 +MBTH0_SX32 +MCAE0_SX277 +MCAL0_SA2 +MCAL0_SI1768 +MCDC0_SA1 +MCDC0_SX212 +MCDD0_SA2 +MCDD0_SI883 +MCDD0_SX253 +MCDD0_SX433 +MCDR0_SI1154 +MCEF0_SX235 +MCEF0_SX415 +MCEW0_SA2 +MCHL0_SX87 +MCLK0_SX310 +MCLM0_SA1 +MCLM0_SI2086 +MCLM0_SI826 +MCPM0_SA1 +MCPM0_SX114 +MCPM0_SX294 +MCPM0_SX384 +MCSS0_SI750 +MCTH0_SA1 +MCTH0_SX39 +MCXM0_SX91 +MDAC0_SA1 +MDAC0_SX181 +MDAC0_SX361 +MDAS0_SX6 +MDBB1_SX106 +MDBB1_SX16 +MDBB1_SX376 +MDBP0_SX168 +MDCD0_SI1415 +MDCD0_SX245 +MDCD0_SX425 +MDCM0_SX40 +MDCM0_SX400 +MDDC0_SI2049 +MDDC0_SI789 +MDDC0_SX159 +MDDC0_SX69 +MDED0_SA1 +MDED0_SA2 +MDEF0_SX123 +MDEF0_SX303 +MDHL0_SI1439 +MDHL0_SX269 +MDHL0_SX449 +MDHS0_SA1 +MDHS0_SA2 +MDHS0_SI1530 +MDHS0_SI2160 +MDJM0_SX105 +MDJM0_SX15 +MDKS0_SX436 +MDLB0_SA2 +MDLC0_SX405 +MDLC1_SA2 +MDLC1_SI2065 +MDLC1_SI2144 +MDLC1_SX445 +MDLC2_SI2244 +MDLC2_SX354 +MDLH0_SA2 +MDLM0_SI1234 +MDLM0_SI1864 +MDLM0_SX154 +MDLM0_SX424 +MDLR0_SA1 +MDLR0_SA2 +MDLR0_SI1863 +MDLR0_SI603 +MDLR0_SX153 +MDLR1_SA1 +MDLR1_SA2 +MDMA0_SI1430 +MDMA0_SX260 +MDMA0_SX80 +MDMT0_SA1 +MDMT0_SA2 +MDMT0_SI1832 +MDMT0_SX122 +MDMT0_SX32 +MDNS0_SA2 +MDNS0_SI2271 +MDNS0_SX201 +MDNS0_SX21 +MDPB0_SX416 +MDPK0_SI1053 +MDPK0_SX333 +MDPK0_SX423 +MDPS0_SI719 +MDPS0_SX359 +MDRD0_SA1 +MDRD0_SX32 +MDSJ0_SI2092 +MDSS0_SA2 +MDSS0_SX441 +MDSS1_SA1 +MDSS1_SI1327 +MDSS1_SI697 +MDSS1_SX157 +MDSS1_SX67 +MDTB0_SI1200 +MDTB0_SI1830 +MDTB0_SX120 +MDWD0_SA2 +MDWD0_SX270 +MDWD0_SX90 +MDWH0_SX215 +MDWH0_SX305 +MDWM0_SA1 +MDWM0_SA2 +MDWM0_SX16 +MDWM0_SX286 +MEAL0_SA2 +MEAL0_SI2177 +MEAL0_SX107 +MEAL0_SX347 +MEDR0_SA1 +MEDR0_SA2 +MEDR0_SI1374 +MEFG0_SA1 +MEGJ0_SA2 +MEGJ0_SX257 +MEGJ0_SX3 +MEJL0_SA1 +MEJL0_SX152 +MEJL0_SX242 +MEJS0_SI610 +MEJS0_SX160 +MEJS0_SX340 +MESG0_SX432 +MESJ0_SX187 +MESJ0_SX97 +MEWM0_SI718 +MEWM0_SX178 +MEWM0_SX88 +MFER0_SI862 +MFER0_SX142 +MFRM0_SX345 +MFRM0_SX435 +MFWK0_SI1879 +MFWK0_SX169 +MFXS0_SX54 +MFXV0_SA2 +MFXV0_SX105 +MGAF0_SA1 +MGAF0_SX22 +MGAF0_SX382 +MGAG0_SA2 +MGAK0_SX226 +MGAK0_SX46 +MGAR0_SX132 +MGAW0_SI535 +MGAW0_SX175 +MGES0_SA1 +MGES0_SI2111 +MGES0_SI851 +MGJC0_SA2 +MGJC0_SX75 +MGRL0_SI2127 +MGRL0_SI867 +MGRL0_SX147 +MGRP0_SA2 +MGSH0_SA2 +MGSH0_SI1806 +MGSH0_SX127 +MGSH0_SX276 +MGSH0_SX6 +MGSL0_SA1 +MGSL0_SI534 +MGSL0_SX264 +MGXP0_SX187 +MGXP0_SX7 +MHBS0_SX315 +MHBS0_SX45 +MHIT0_SA1 +MHJB0_SA1 +MHJB0_SI1017 +MHMG0_SX195 +MHMR0_SA1 +MHMR0_SI489 +MHRM0_SA1 +MHRM0_SI958 +MHRM0_SX148 +MHRM0_SX58 +MHXL0_SI1772 +MHXL0_SX242 +MILB0_SA2 +MJAC0_SX307 +MJAC0_SX71 +MJAE0_SX174 +MJAI0_SA1 +MJAI0_SA2 +MJBG0_SX62 +MJDA0_SI1031 +MJDA0_SX311 +MJDE0_SI463 +MJDG0_SA2 +MJDG0_SI1042 +MJDG0_SI1705 
+MJDM0_SA1 +MJDM0_SI974 +MJEB0_SI656 +MJEB0_SX296 +MJEB1_SA2 +MJEB1_SX207 +MJEB1_SX387 +MJEE0_SA1 +MJEE0_SX247 +MJEE0_SX337 +MJFH0_SA2 +MJFH0_SI1107 +MJFR0_SX75 +MJHI0_SA1 +MJHI0_SX158 +MJJB0_SA1 +MJJB0_SX239 +MJJJ0_SX443 +MJJM0_SA2 +MJJM0_SI827 +MJJM0_SX107 +MJKR0_SA1 +MJKR0_SI571 +MJLB0_SX176 +MJLG1_SX292 +MJLS0_SX106 +MJMA0_SA1 +MJMA0_SA2 +MJMD0_SA2 +MJMD0_SX308 +MJMD0_SX38 +MJMM0_SX85 +MJPG0_SI1191 +MJPG0_SX111 +MJPG0_SX201 +MJPG0_SX21 +MJPM0_SA2 +MJPM0_SX378 +MJPM1_SI2280 +MJPM1_SX401 +MJRA0_SA1 +MJRA0_SA2 +MJRA0_SI1236 +MJRA0_SI1866 +MJRA0_SX426 +MJRG0_SI1366 +MJRG0_SI1996 +MJRG0_SX376 +MJRH0_SX225 +MJRH1_SA1 +MJRH1_SI514 +MJRH1_SX154 +MJRH1_SX244 +MJRH1_SX424 +MJRK0_SA1 +MJRK0_SA2 +MJRK0_SI1662 +MJRK0_SX160 +MJRK0_SX250 +MJRK0_SX430 +MJRP0_SA1 +MJRP0_SA2 +MJRP0_SX225 +MJSR0_SA1 +MJSR0_SI1424 +MJSR0_SX344 +MJWG0_SA1 +MJWG0_SX265 +MJWS0_SI513 +MJWS0_SX153 +MJWS0_SX63 +MJWT0_SA1 +MJWT0_SX121 +MJWT0_SX211 +MJWT0_SX301 +MJWT0_SX31 +MJWT0_SX391 +MJXA0_SX427 +MJXL0_SI542 +MKAG0_SA1 +MKAG0_SX259 +MKAJ0_SA2 +MKAJ0_SX154 +MKAM0_SA1 +MKAM0_SX146 +MKAM0_SX326 +MKAM0_SX56 +MKDB0_SA1 +MKDB0_SA2 +MKDB0_SX152 +MKDD0_SA2 +MKES0_SA1 +MKES0_SI1253 +MKES0_SI1883 +MKES0_SX173 +MKJO0_SI1517 +MKJO0_SI887 +MKJO0_SX437 +MKLN0_SI968 +MKLN0_SX248 +MKLR0_SA2 +MKLR0_SI1689 +MKLS0_SA1 +MKLS0_SX357 +MKLS0_SX87 +MKLS1_SA1 +MKLS1_SA2 +MKLS1_SX375 +MKLW0_SA1 +MKRG0_SX411 +MKXL0_SA2 +MKXL0_SX15 +MKXL0_SX375 +MLBC0_SA1 +MLBC0_SI1869 +MLBC0_SX249 +MLEL0_SA1 +MLEL0_SA2 +MLEL0_SI1246 +MLEL0_SX256 +MLEL0_SX436 +MLJC0_SX145 +MLJC0_SX415 +MLJH0_SX64 +MLNS0_SI2037 +MMAA0_SA1 +MMAA0_SA2 +MMAA0_SX35 +MMAB1_SI1494 +MMAB1_SX234 +MMAG0_SA2 +MMAG0_SI1126 +MMAG0_SX316 +MMAM0_SI2227 +MMAM0_SX157 +MMAM0_SX427 +MMAR0_SX256 +MMBS0_SI1781 +MMCC0_SA2 +MMDB0_SX177 +MMDG0_SA1 +MMDG0_SA2 +MMDG0_SI520 +MMDG0_SX160 +MMDG0_SX250 +MMDM0_SI1941 +MMDM0_SI681 +MMDM0_SX141 +MMDM1_SA2 +MMDM1_SI2043 +MMDM1_SX423 +MMDM1_SX63 +MMDS0_SA1 +MMEA0_SA1 +MMEA0_SX128 +MMEA0_SX398 +MMEB0_SA2 +MMEB0_SX187 +MMEB0_SX367 +MMGC0_SA2 +MMGC0_SX135 +MMGC0_SX225 +MMGG0_SX269 +MMGK0_SX332 +MMGK0_SX62 +MMJB1_SA2 +MMRP0_SA2 +MMRP0_SX144 +MMSM0_SX116 +MMSM0_SX206 +MMVP0_SA1 +MMVP0_SA2 +MMWB0_SI989 +MMWB0_SX89 +MMWS0_SA2 +MMWS0_SX168 +MMWS0_SX348 +MMWS0_SX438 +MMWS1_SI1701 +MMXS0_SI2136 +MMXS0_SX246 +MMXS0_SX426 +MNET0_SI816 +MNET0_SX6 +MNTW0_SA2 +MNTW0_SX168 +MNTW0_SX78 +MPAR0_SI2206 +MPAR0_SI946 +MPAR0_SX136 +MPAR0_SX316 +MPEB0_SI1034 +MPEB0_SI1860 +MPEB0_SX240 +MPEB0_SX330 +MPFU0_SI628 +MPFU0_SX448 +MPGH0_SX114 +MPGH0_SX24 +MPGR0_SX240 +MPGR0_SX330 +MPGR1_SX149 +MPPC0_SA1 +MPRD0_SA1 +MPRD0_SX261 +MPRD0_SX351 +MPRD0_SX441 +MPRD0_SX81 +MPRK0_SI1727 +MPRK0_SX107 +MPRK0_SX377 +MPRT0_SA1 +MPRT0_SX310 +MPSW0_SI1067 +MPSW0_SX167 +MPSW0_SX437 +MRAB1_SX128 +MRAB1_SX308 +MRAI0_SA1 +MRAI0_SA2 +MRAI0_SX72 +MRAM0_SA1 +MRAM0_SA2 +MRAM0_SX15 +MRBC0_SI1859 +MRBC0_SX329 +MRBC0_SX419 +MRCG0_SI798 +MRCG0_SX168 +MRCW0_SA1 +MRCW0_SX291 +MRDD0_SI1680 +MRDD0_SX150 +MRDD0_SX277 +MRDD0_SX60 +MRDM0_SI1595 +MRDM0_SX65 +MRDS0_SA1 +MREE0_SX24 +MREH1_SX249 +MREH1_SX69 +MREM0_SA2 +MREW1_SI870 +MRFK0_SX446 +MRFL0_SA1 +MRFL0_SX256 +MRFL0_SX436 +MRFL0_SX76 +MRGM0_SA2 +MRGM0_SX262 +MRGS0_SA2 +MRGS0_SX186 +MRHL0_SI885 +MRHL0_SX345 +MRHL0_SX435 +MRJB1_SA1 +MRJB1_SA2 +MRJB1_SX210 +MRJB1_SX30 +MRJB1_SX390 +MRJH0_SA2 +MRJH0_SX307 +MRJH0_SX79 +MRJM0_SX148 +MRJM1_SA2 +MRJM1_SI1298 +MRJM1_SI1928 +MRJM1_SX128 +MRJT0_SA2 +MRJT0_SI1498 +MRJT0_SX328 +MRJT0_SX418 +MRKM0_SA2 +MRKM0_SX367 +MRLD0_SA2 +MRLD0_SI2224 +MRLD0_SX154 +MRLD0_SX424 +MRLJ0_SA1 +MRLJ0_SX250 +MRLJ0_SX340 +MRLJ1_SA1 +MRLJ1_SA2 +MRLJ1_SX321 
+MRLK0_SI843 +MRLK0_SX123 +MRLK0_SX213 +MRMB0_SA2 +MRMB0_SI1581 +MRMB0_SX411 +MRMG0_SA1 +MRMG0_SI1080 +MRMG0_SX450 +MRMH0_SI1349 +MRMH0_SI2281 +MRMH0_SX121 +MRML0_SA2 +MRML0_SX341 +MRPC1_SI2112 +MRRE0_SA2 +MRRE0_SX164 +MRRE0_SX344 +MRRE0_SX74 +MRSO0_SX129 +MRSO0_SX39 +MRSP0_SX259 +MRTC0_SX378 +MRVG0_SI1140 +MRVG0_SX240 +MRWA0_SI973 +MRWA0_SX163 +MRWA0_SX73 +MRWS0_SI1732 +MRWS0_SI472 +MRWS0_SX22 +MRWS0_SX382 +MRXB0_SA2 +MRXB0_SX415 +MSAH1_SI1679 +MSAS0_SX116 +MSAS0_SX206 +MSAS0_SX386 +MSAT0_SA1 +MSAT1_SX263 +MSAT1_SX443 +MSAT1_SX83 +MSDB0_SX197 +MSDB0_SX287 +MSDB0_SX377 +MSDH0_SI2240 +MSDH0_SX440 +MSDH0_SX80 +MSDS0_SA1 +MSEM1_SI1440 +MSEM1_SX180 +MSEM1_SX270 +MSES0_SI1589 +MSES0_SX239 +MSES0_SX419 +MSFH0_SX316 +MSFV0_SI1892 +MSFV0_SX362 +MSFV0_SX92 +MSMR0_SX415 +MSMS0_SA1 +MSMS0_SX173 +MSMS0_SX83 +MSRG0_SA1 +MSRG0_SI1221 +MSTF0_SI766 +MSTF0_SX316 +MSTF0_SX46 +MSVS0_SA2 +MSVS0_SX308 +MTAS0_SX215 +MTAS0_SX35 +MTAS0_SX395 +MTAT0_SX390 +MTAT1_SX59 +MTBC0_SI1803 +MTCS0_SA2 +MTCS0_SI2265 +MTCS0_SX82 +MTDP0_SA2 +MTER0_SA2 +MTER0_SI1787 +MTJG0_SA1 +MTJG0_SI2157 +MTJG0_SX260 +MTJM0_SI1856 +MTJM0_SX146 +MTJU0_SX130 +MTJU0_SX400 +MTKD0_SX107 +MTKD0_SX287 +MTKP0_SI1023 +MTLB0_SA1 +MTLB0_SX234 +MTLC0_SA1 +MTML0_SI2325 +MTML0_SX165 +MTMN0_SA2 +MTMN0_SI1064 +MTMN0_SI2324 +MTMN0_SX434 +MTMT0_SA2 +MTMT0_SI1748 +MTPF0_SX65 +MTPG0_SI1383 +MTPG0_SI753 +MTPG0_SX303 +MTPP0_SX338 +MTPR0_SX340 +MTQC0_SI480 +MTQC0_SX91 +MTRR0_SX198 +MTRR0_SX288 +MTRT0_SA2 +MTRT0_SX254 +MTRT0_SX57 +MTWH1_SX72 +MTXS0_SA1 +MTXS0_SA2 +MVJH0_SI926 +MVJH0_SX206 +MVJH0_SX296 +MVLO0_SA1 +MVRW0_SA2 +MVRW0_SX135 +MVRW0_SX225 +MWAC0_SA2 +MWAC0_SX341 +MWAC0_SX431 +MWAD0_SX432 +MWAD0_SX72 +MWAR0_SA1 +MWAR0_SI1675 +MWCH0_SI1895 +MWCH0_SI2252 +MWCH0_SX182 +MWCH0_SX452 +MWDK0_SA1 +MWDK0_SA2 +MWDK0_SI2017 +MWDK0_SI806 +MWDK0_SX176 +MWDK0_SX86 +MWEM0_SA2 +MWEM0_SI1320 +MWEM0_SI1393 +MWEM0_SX150 +MWGR0_SX346 +MWRE0_SX247 +MWRE0_SX337 +MWRE0_SX427 +MWRP0_SA1 +MWRP0_SX273 +MWRP0_SX363 +MWSB0_SX276 +MWSH0_SX256 +MWSH0_SX76 +MZMB0_SA1 diff --git a/fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/valid.uid b/fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/valid.uid new file mode 100644 index 0000000000000000000000000000000000000000..e99edfe937854a5f47a2f0384f0e067487336883 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/valid.uid @@ -0,0 +1,620 @@ +FAEM0_SI1392 +FAJW0_SI1263 +FAJW0_SI633 +FALK0_SI658 +FALR0_SX335 +FAPB0_SI1063 +FAPB0_SI2323 +FAPB0_SX433 +FBAS0_SI1472 +FBAS0_SI2066 +FBCG1_SX352 +FBCH0_SI959 +FBJL0_SI922 +FBLV0_SI1688 +FBMH0_SI1136 +FBMH0_SI970 +FBMJ0_SA1 +FBMJ0_SI1776 +FBMJ0_SI516 +FBMJ0_SX336 +FCDR1_SI1186 +FCDR1_SI1816 +FCDR1_SI556 +FCDR1_SX286 +FCKE0_SI1741 +FCKE0_SI481 +FCLT0_SI808 +FCMG0_SI1142 +FCMG0_SX432 +FCMM0_SI1957 +FCMM0_SX420 +FCYL0_SI667 +FCYL0_SX349 +FDAS1_SI1461 +FDAS1_SI831 +FDAW0_SI1271 +FDAW0_SI2036 +FDJH0_SI935 +FDKN0_SI1202 +FDKN0_SX181 +FDKN0_SX451 +FDMY0_SA1 +FDMY0_SI567 +FDMY0_SI714 +FDMY0_SX387 +FDNC0_SI1278 +FDNC0_SI1908 +FDTD0_SA1 +FDTD0_SX321 +FEAC0_SI615 +FEAR0_SX352 +FECD0_SA1 +FECD0_SI1418 +FECD0_SI788 +FEME0_SI875 +FEME0_SX335 +FEXM0_SA1 +FEXM0_SI482 +FEXM0_SX366 +FGDP0_SI988 +FGDP0_SX88 +FGMB0_SI1145 +FGMB0_SX335 +FGRW0_SA1 +FGRW0_SI1152 +FGRW0_SX162 +FGRW0_SX432 +FHLM0_SX120 +FHLM0_SX349 +FHXS0_SA1 +FHXS0_SI1075 +FHXS0_SI2302 +FHXS0_SX175 +FJDM2_SA2 +FJDM2_SX142 +FJEN0_SA1 +FJEN0_SX327 +FJEN0_SX417 +FJHK0_SI2282 +FJKL0_SI932 +FJLG0_SI1889 +FJLR0_SI1231 +FJRB0_SX402 +FJRP1_SA1 +FJRP1_SI1432 +FJRP1_SX262 +FJRP1_SX352 +FJSK0_SI1052 +FJSP0_SI1434 
+FJWB1_SI748 +FJXM0_SX311 +FJXM0_SX41 +FJXP0_SI1752 +FKAA0_SA1 +FKDE0_SI1141 +FKDE0_SI1771 +FKDW0_SI1207 +FKDW0_SI1891 +FKFB0_SI1608 +FKFB0_SX438 +FKKH0_SI1290 +FKKH0_SI1920 +FKLC0_SI985 +FKLC0_SX175 +FKLC1_SI1048 +FKLH0_SI1257 +FKSR0_SX366 +FLAC0_SI1339 +FLAG0_SI1464 +FLAG0_SI834 +FLEH0_SI1051 +FLET0_SI507 +FLJA0_SI1078 +FLJA0_SX178 +FLJD0_SI1516 +FLJG0_SI981 +FLJG0_SX171 +FLJG0_SX351 +FLKM0_SA1 +FLKM0_SI620 +FLKM0_SX350 +FLKM0_SX440 +FLMC0_SI1372 +FLMK0_SA1 +FLMK0_SI1229 +FLTM0_SX170 +FLTM0_SX350 +FLTM0_SX440 +FMAH1_SI879 +FMBG0_SI1160 +FMEM0_SA1 +FMEM0_SX333 +FMJB0_SI1177 +FMJF0_SI624 +FMJF0_SX174 +FMJF0_SX84 +FMJU0_SI1389 +FMKC0_SI1041 +FMKF0_SI1018 +FMPG0_SA1 +FMPG0_SI972 +FMPG0_SX162 +FMPG0_SX342 +FMPG0_SX432 +FNKL0_SI892 +FNTB0_SI679 +FPAB1_SA1 +FPAB1_SI2101 +FPAB1_SI841 +FPAC0_SI1921 +FPAC0_SI661 +FPAD0_SI716 +FPAD0_SX176 +FPAF0_SA1 +FPAF0_SI1054 +FPAZ0_SI2223 +FPAZ0_SI963 +FPJF0_SI1259 +FPJF0_SX352 +FPLS0_SI960 +FPMY0_SI1153 +FPMY0_SI523 +FREH0_SI1945 +FRLL0_SI805 +FSAG0_SI1323 +FSAG0_SX153 +FSAG0_SX333 +FSAG0_SX423 +FSAH0_SI614 +FSAH0_SX327 +FSAK0_SI1300 +FSBK0_SX349 +FSCN0_SA1 +FSCN0_SI705 +FSCN0_SX176 +FSDC0_SI1312 +FSDJ0_SI1115 +FSGF0_SI2187 +FSGF0_SI927 +FSJG0_SA1 +FSJG0_SA2 +FSJG0_SI940 +FSJG0_SX220 +FSJG0_SX40 +FSJG0_SX400 +FSJS0_SA1 +FSJS0_SX451 +FSJW0_SI1333 +FSKP0_SI1098 +FSMA0_SI991 +FSMA0_SX451 +FSMM0_SX324 +FSPM0_SI1241 +FSPM0_SX251 +FSRH0_SX311 +FSSB0_SI1712 +FSSB0_SX362 +FTBR0_SI1402 +FTBR0_SI921 +FTBW0_SI715 +FTBW0_SX175 +FTLG0_SI1743 +FTLG0_SI483 +FTMG0_SI902 +FVFB0_SI1510 +FVKB0_SX349 +FVMH0_SI1466 +FVMH0_SI836 +MADC0_SI1367 +MADC0_SI737 +MAEB0_SI1411 +MAEO0_SI1326 +MAJP0_SI1704 +MAJP0_SX174 +MAKB0_SA2 +MAKB0_SI1016 +MAKB0_SI2276 +MAKB0_SX116 +MAPV0_SI1293 +MAPV0_SI663 +MARW0_SX286 +MARW0_SX349 +MBBR0_SI1055 +MBBR0_SX335 +MBCG0_SI957 +MBCG0_SX327 +MBGT0_SI1841 +MBGT0_SX171 +MBMA0_SI1222 +MBMA1_SI954 +MBMA1_SX324 +MBTH0_SI2102 +MBWP0_SX349 +MCAE0_SI1447 +MCAE0_SI2077 +MCAE0_SI817 +MCAL0_SI1138 +MCDR0_SI1784 +MCDR0_SI524 +MCEF0_SI842 +MCEW0_SA1 +MCEW0_SI2072 +MCEW0_SI812 +MCEW0_SX362 +MCEW0_SX452 +MCHL0_SI1347 +MCHL0_SI1404 +MCLK0_SI2290 +MCLK0_SI650 +MCPM0_SI1824 +MCSS0_SI1380 +MCSS0_SI688 +MCTM0_SI1350 +MCTM0_SI1980 +MDAC0_SI631 +MDAS0_SI1896 +MDAS0_SI636 +MDBP0_SI528 +MDBP0_SX438 +MDCD0_SI785 +MDCD0_SX335 +MDCM0_SI1480 +MDDC0_SI1419 +MDED0_SI540 +MDEF0_SI1123 +MDEM0_SA1 +MDEM0_SI608 +MDEM0_SI800 +MDEM0_SX428 +MDHS0_SI900 +MDJM0_SI1455 +MDKS0_SX166 +MDKS0_SX346 +MDLB0_SI1306 +MDLB0_SX136 +MDLB0_SX406 +MDLC0_SI1395 +MDLC0_SI2025 +MDLC1_SI1435 +MDLH0_SX160 +MDLH0_SX430 +MDLM0_SI604 +MDLR0_SX333 +MDLR1_SI669 +MDMA0_SX170 +MDMA0_SX350 +MDMA0_SX440 +MDNS0_SI1011 +MDNS0_SI873 +MDPB0_SI1760 +MDPB0_SI866 +MDRD0_SI752 +MDSJ0_SI1462 +MDSJ0_SX438 +MDWD0_SI1260 +MDWH0_SA1 +MDWH0_SI1168 +MDWH0_SI665 +MDWM0_SI916 +MEDR0_SI2004 +MEFG0_SI491 +MEFG0_SI598 +MEGJ0_SA1 +MEGJ0_SI1337 +MEGJ0_SI707 +MEGJ0_SX167 +MEJS0_SI1240 +MESG0_SI702 +MESJ0_SI2039 +MFWK0_SX349 +MFXS0_SX324 +MFXV0_SI1005 +MFXV0_SI1342 +MGAF0_SI1282 +MGAG0_SI691 +MGAK0_SI1036 +MGAK0_SX136 +MGAR0_SX312 +MGAW0_SI1165 +MGES0_SX311 +MGJC0_SX435 +MGRL0_SX327 +MGRP0_SI1317 +MGRP0_SX327 +MGSH0_SI1176 +MGSH0_SI546 +MGSL0_SI797 +MGXP0_SI1087 +MGXP0_SI525 +MHBS0_SI945 +MHIT0_SI983 +MHMG0_SI735 +MHMR0_SI1692 +MILB0_SI903 +MJAC0_SI701 +MJAC0_SX251 +MJAE0_SX84 +MJAI0_SI682 +MJAI0_SI710 +MJDC0_SI531 +MJDE0_SA1 +MJDE0_SI1120 +MJDE0_SI490 +MJDE0_SX220 +MJDM0_SI1340 +MJDM0_SX170 +MJDM0_SX350 +MJEB0_SX170 +MJEB1_SI1467 +MJEB1_SI837 +MJFR0_SA1 +MJFR0_SX435 +MJHI0_SI1328 +MJJJ0_SI1163 +MJJM0_SI1251 +MJLB0_SI1616 +MJLS0_SI1726 
+MJMA0_SI2125 +MJMD0_SI2288 +MJMM0_SI1255 +MJMM0_SX175 +MJPG0_SI1821 +MJPM0_SI1368 +MJPM1_SX311 +MJRA0_SX336 +MJRG0_SI736 +MJRG0_SX352 +MJRH0_SI1840 +MJRH1_SI1558 +MJRK0_SI880 +MJRP0_SI1845 +MJSR0_SI2054 +MJSR0_SI794 +MJWG0_SI813 +MJWG0_SI895 +MJWG0_SX175 +MJWS0_SX333 +MJWT0_SI1291 +MJWT0_SI1381 +MJXL0_SI1172 +MKAG0_SI979 +MKAH0_SX178 +MKAM0_SI1250 +MKAM0_SI1465 +MKDD0_SI1567 +MKDD0_SI2197 +MKDD0_SI937 +MKDT0_SI814 +MKES0_SI623 +MKLS0_SI1437 +MKLS0_SI2067 +MKLS1_SI915 +MKLW0_SI1571 +MKLW0_SX311 +MKRG0_SI861 +MKXL0_SI1815 +MKXL0_SI1958 +MLBC0_SI1239 +MLEL0_SI616 +MLEL0_SX166 +MLJC0_SI1225 +MLJH0_SA1 +MLJH0_SA2 +MLJH0_SI1422 +MLJH0_SI694 +MLJH0_SX244 +MLSH0_SI1417 +MLSH0_SX247 +MMAA0_SI1588 +MMAA0_SI845 +MMAB1_SI864 +MMAB1_SX324 +MMAG0_SA1 +MMAG0_SI1756 +MMAG0_SX136 +MMAR0_SI1966 +MMAR0_SX166 +MMAR0_SX346 +MMBS0_SI521 +MMBS0_SX161 +MMCC0_SI1338 +MMDB0_SI987 +MMDG0_SI1780 +MMDM0_SI1311 +MMDM1_SX153 +MMDM1_SX333 +MMEB0_SX327 +MMGC0_SI1305 +MMGG0_SI1079 +MMGG0_SX449 +MMLM0_SI2150 +MMPM0_SX161 +MMRP0_SX324 +MMSM0_SI1106 +MMSM0_SI476 +MMVP0_SI654 +MMVP0_SX347 +MMWB0_SA1 +MMWB0_SI2249 +MMWB0_SX359 +MMWB0_SX449 +MNTW0_SI1068 +MNTW0_SI1698 +MPEB0_SI600 +MPFU0_SI1258 +MPGH0_SI675 +MPGR0_SI1410 +MPGR1_SI1499 +MPMB0_SA1 +MPMB0_SA2 +MPMB0_SI1501 +MPMB0_SI2131 +MPMB0_SI871 +MPMB0_SX151 +MPMB0_SX331 +MPMB0_SX421 +MPMB0_SX61 +MPPC0_SI1412 +MPRB0_SI1215 +MPRB0_SI575 +MPRD0_SI801 +MPRD0_SX171 +MPRK0_SA1 +MPRK0_SI1097 +MPRK0_SI467 +MPRK0_SX287 +MRAB0_SI1854 +MRAB1_SI848 +MRAI0_SI2052 +MRAI0_SI792 +MRAI0_SX432 +MRAM0_SI1951 +MRCG0_SA2 +MRCG0_SI1428 +MRCG0_SX348 +MRCG0_SX438 +MRCW0_SI741 +MRDM0_SI1044 +MRDM0_SX335 +MREE0_SI1104 +MREE0_SI1959 +MREH1_SA1 +MREH1_SI1599 +MREH1_SI969 +MREM0_SI511 +MRFK0_SI1076 +MRFL0_SI1156 +MRFL0_SI526 +MRFL0_SX166 +MRGM0_SI532 +MRGM0_SX172 +MRGM0_SX442 +MRGS0_SI1356 +MRGS0_SI726 +MRGS0_SX6 +MRJB1_SI1413 +MRJB1_SI2021 +MRJB1_SX120 +MRJH0_SI1519 +MRJH0_SI889 +MRJH0_SX169 +MRJT0_SI868 +MRJT0_SX58 +MRKM0_SI1267 +MRKM0_SI1391 +MRKM0_SI637 +MRLJ0_SI790 +MRLJ1_SI2301 +MRLK0_SI1468 +MRLR0_SI1196 +MRML0_SA1 +MRML0_SI1421 +MRML0_SX161 +MRML0_SX251 +MRMS0_SI2057 +MRRE0_SA1 +MRRE0_SI1334 +MRRE0_SI952 +MRSO0_SI1206 +MRSP0_SI1429 +MRTC0_SI1458 +MRTJ0_SA1 +MRTJ0_SI772 +MRTJ0_SX142 +MRTJ0_SX232 +MRTJ0_SX52 +MRWS0_SI1102 +MRXB0_SI2215 +MRXB0_SI955 +MSAS0_SI1376 +MSAS0_SI746 +MSDH0_SI980 +MSDH0_SX170 +MSDS0_SI1077 +MSDS0_SX267 +MSDS0_SX357 +MSEM1_SI2070 +MSEM1_SI810 +MSFH0_SA1 +MSFH0_SI1738 +MSFH0_SX136 +MSFH0_SX406 +MSFV0_SI632 +MSJK0_SI1596 +MSJK0_SX336 +MSMC0_SI509 +MSMR0_SI1150 +MSMS0_SI1433 +MSRR0_SI1761 +MSRR0_SI501 +MSTF0_SI852 +MSVS0_SI2198 +MSVS0_SI938 +MSVS0_SX398 +MTAB0_SI1572 +MTAB0_SX312 +MTAT0_SA1 +MTAT0_SI1110 +MTAT0_SI811 +MTAT1_SI779 +MTAT1_SX149 +MTAT1_SX329 +MTBC0_SI543 +MTCS0_SI712 +MTDB0_SI1401 +MTDB0_SI771 +MTDP0_SA1 +MTDP0_SI1521 +MTDP0_SX171 +MTDP0_SX351 +MTER0_SA1 +MTER0_SI1157 +MTER0_SX437 +MTJG0_SX170 +MTJS0_SA2 +MTJS0_SI1822 +MTJS0_SI562 +MTJS0_SX382 +MTJU0_SI2020 +MTKD0_SI630 +MTKP0_SI2283 +MTKP0_SI454 +MTLB0_SI1134 +MTLB0_SX324 +MTLC0_SI1313 +MTLC0_SI1477 +MTML0_SX435 +MTMN0_SI582 +MTMT0_SI488 +MTPP0_SI1508 +MTPR0_SI2230 +MTPR0_SX160 +MTPR0_SX430 +MTQC0_SA1 +MTQC0_SI1441 +MTQC0_SX181 +MTQC0_SX451 +MTRC0_SI589 +MTRR0_SI918 +MTRT0_SI1227 +MTXS0_SI1060 +MTXS0_SI2320 +MTXS0_SX160 +MTXS0_SX430 +MVJH0_SI1556 +MVLO0_SI517 +MWAC0_SI1601 +MWAC0_SX161 +MWAC0_SX251 +MWAR0_SI1045 +MWDK0_SI1436 +MWEM0_SX420 +MWRE0_SA2 +MWRE0_SI1057 +MWRE0_SX67 +MWRP0_SI1443 +MWSB0_SI996 +MWSH0_SI1426 +MWSH0_SI796 +MWSH0_SX166 diff --git a/fairseq/examples/wav2vec/unsupervised/data/__init__.py 
b/fairseq/examples/wav2vec/unsupervised/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0545627efc9a6f9bb180e351ead519a2cb6dea7 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .extracted_features_dataset import ExtractedFeaturesDataset +from .random_input_dataset import RandomInputDataset + + +__all__ = [ + "ExtractedFeaturesDataset", + "RandomInputDataset", +] diff --git a/fairseq/examples/wav2vec/unsupervised/data/extracted_features_dataset.py b/fairseq/examples/wav2vec/unsupervised/data/extracted_features_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7a58c0e5e8ed1a36745aa4e8ad34462a8d9fc1 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/data/extracted_features_dataset.py @@ -0,0 +1,167 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import os +import contextlib + +import numpy as np +import torch + +from fairseq.data import FairseqDataset, data_utils + + +logger = logging.getLogger(__name__) + + +class ExtractedFeaturesDataset(FairseqDataset): + def __init__( + self, + path, + split, + min_length=3, + max_length=None, + labels=None, + label_dict=None, + shuffle=True, + sort_by_length=True, + aux_target_postfix=None, + ): + super().__init__() + + self.min_length = min_length + self.max_length = max_length + self.shuffle = shuffle + self.sort_by_length = sort_by_length + self.label_dict = label_dict + + if labels is not None: + assert label_dict is not None + + self.sizes = [] + self.offsets = [] + self.labels = [] + self.aux_tgt = None + + path = os.path.join(path, split) + data_path = path + self.data = np.load(data_path + ".npy", mmap_mode="r") + + offset = 0 + skipped = 0 + + if not os.path.exists(path + f".{labels}"): + labels = None + + with open(data_path + ".lengths", "r") as len_f, open( + path + f".{labels}", "r" + ) if labels is not None else contextlib.ExitStack() as lbl_f: + for line in len_f: + length = int(line.rstrip()) + lbl = None if labels is None else next(lbl_f).rstrip().split() + if length >= min_length and ( + max_length is None or length <= max_length + ): + self.sizes.append(length) + self.offsets.append(offset) + if lbl is not None: + self.labels.append(lbl) + offset += length + + self.sizes = np.asarray(self.sizes) + self.offsets = np.asarray(self.offsets) + + if aux_target_postfix is not None: + if not os.path.exists(path+f".{aux_target_postfix}"): + logger.info(f"auxaliry target for {split} missing") + else: + with open(path+f".{aux_target_postfix}", "r") as t_f: + self.aux_tgt = [ + torch.LongTensor(list(map(int,seg.strip().split())))\ + for seg in t_f] + + logger.info(f"loaded {len(self.offsets)}, skipped {skipped} samples") + + def __getitem__(self, index): + offset = self.offsets[index] + end = self.sizes[index] + offset + feats = torch.from_numpy(self.data[offset:end].copy()).float() + + res = {"id": index, "features": feats} + if len(self.labels) > 0: + res["target"] = self.label_dict.encode_line( + self.labels[index], + line_tokenizer=lambda x: x, + append_eos=False, + ) + + if self.aux_tgt: + res["aux_target"] = self.aux_tgt[index] + + return res + + def __len__(self): + return 
len(self.sizes) + + def collater(self, samples): + if len(samples) == 0: + return {} + + features = [s["features"] for s in samples] + sizes = [len(s) for s in features] + + target_size = max(sizes) + + collated_features = features[0].new_zeros( + len(features), target_size, features[0].size(-1) + ) + padding_mask = torch.BoolTensor(collated_features.shape[:-1]).fill_(False) + for i, (f, size) in enumerate(zip(features, sizes)): + collated_features[i, :size] = f + padding_mask[i, size:] = True + + res = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": {"features": collated_features, "padding_mask": padding_mask}, + } + + if len(self.labels) > 0: + target = data_utils.collate_tokens( + [s["target"] for s in samples], + pad_idx=self.label_dict.pad(), + left_pad=False, + ) + res["target"] = target + + if self.aux_tgt: + idxs = torch.nn.utils.rnn.pad_sequence( + [s["aux_target"] for s in samples], + batch_first=True, + padding_value=-1, + ) + res["net_input"]["aux_target"] = idxs + + return res + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + return self.sizes[index] + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + if self.sort_by_length: + order.append(self.sizes) + return np.lexsort(order)[::-1] + else: + return order[0] diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/cmd.sh b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/cmd.sh new file mode 100644 index 0000000000000000000000000000000000000000..e74953194d41f0d93855d41b2acef08556d92477 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/cmd.sh @@ -0,0 +1,15 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export train_cmd="run.pl --mem 2G" +export decode_cmd="run.pl --mem 4G" +export mkgraph_cmd="run.pl --mem 8G" diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang.sh b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang.sh new file mode 100644 index 0000000000000000000000000000000000000000..e9a80001eb47d5af863d6aab11a59362a59cef61 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +sil_prob=0.5 +num_sil_states=3 +num_nonsil_states=1 + +. ./cmd.sh +. ./path.sh +. 
parse_options.sh + +set -eux + +dict=$1 +data_dir=$2 + +dict_dir=$data_dir/local/dict +tmplm_dir=$data_dir/local/lang_tmp +lm_dir=$data_dir/lang + +mkdir -p $dict_dir $tmplm_dir $lm_dir + +# prepare dict +echo "SIL" > $dict_dir/silence_phones.txt +echo "SIL" > $dict_dir/optional_silence.txt +awk '{print $1}' $dict > $dict_dir/nonsilence_phones.txt + +echo "SIL SIL" > $dict_dir/lexicon.txt +echo " SIL" >> $dict_dir/lexicon.txt +awk '{print $1" "$1}' $dict >> $dict_dir/lexicon.txt + +echo "SIL" > $dict_dir/extra_questions.txt +awk '{printf $1" "} END {printf "\n"}' $dict >> $dict_dir/extra_questions.txt + +# prepare lang +utils/prepare_lang.sh --sil-prob $sil_prob --position-dependent-phones false \ + --num_sil_states $num_sil_states --num_nonsil_states $num_nonsil_states \ + $dict_dir "" $tmplm_dir $lm_dir diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lm.sh b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lm.sh new file mode 100644 index 0000000000000000000000000000000000000000..c2edcefede2da3b6a991b9c8fbc78c96d46d27cb --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lm.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +langdir="" +lmdir="" + +. ./cmd.sh +. ./path.sh +. parse_options.sh + +arpa_lm=$1 +data=$2 + +if [ -z $langdir ]; then + langdir=$data/lang +fi +if [ -z $lmdir ]; then + lmdir=$data/lang_test +fi + +if [ ! -d $langdir ]; then + echo "$langdir not found. run local/prepare_lang.sh first" && exit 1 +fi + +mkdir -p $lmdir +cp -r $langdir/* $lmdir + +if [[ "$arpa_lm" == *.gz ]]; then + gunzip -c $arpa_lm | arpa2fst --disambig-symbol=#0 --read-symbol-table=$lmdir/words.txt - $lmdir/G.fst +else + arpa2fst --disambig-symbol=#0 --read-symbol-table=$lmdir/words.txt $arpa_lm $lmdir/G.fst +fi +fstisstochastic $lmdir/G.fst +utils/validate_lang.pl $lmdir || exit 1 + +echo "done preparing lm ($lmdir)" diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/score.sh b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/score.sh new file mode 100644 index 0000000000000000000000000000000000000000..cb5bbb7277bfb9f2d5440da0514bf7b16da8140d --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/score.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# 2014 Guoguo Chen +# Apache 2.0 + +[ -f ./path.sh ] && . ./path.sh + +# begin configuration section. +cmd=run.pl +stage=0 +decode_mbr=true +word_ins_penalty=0.0,0.5,1.0 +min_lmwt=7 +max_lmwt=17 +iter=final +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: local/score.sh [--cmd (run.pl|queue.pl...)] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." + echo " --decode_mbr (true/false) # maximum bayes risk decoding (confusion network)." + echo " --min_lmwt # minumum LM-weight for lattice rescoring " + echo " --max_lmwt # maximum LM-weight for lattice rescoring " + exit 1; +fi + +data=$1 +lang_or_graph=$2 +dir=$3 + +symtab=$lang_or_graph/words.txt + +for f in $symtab $dir/lat.1.gz $data/text; do + [ ! 
-f $f ] && echo "score.sh: no such file $f" && exit 1; +done + +mkdir -p $dir/scoring/log + +cat $data/text | sed 's:::g' | sed 's:::g' > $dir/scoring/test_filt.txt + +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/best_path.LMWT.$wip.log \ + lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \ + lattice-best-path --word-symbol-table=$symtab \ + ark:- ark,t:$dir/scoring/LMWT.$wip.tra || exit 1; +done + +# Note: the double level of quoting for the sed command +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.$wip.log \ + cat $dir/scoring/LMWT.$wip.tra \| \ + utils/int2sym.pl -f 2- $symtab \| sed 's:\::g' \| \ + compute-wer --text --mode=present \ + ark:$dir/scoring/test_filt.txt ark,p:- ">&" $dir/wer_LMWT_$wip || exit 1; +done + +exit 0; diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/show_wer.sh b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/show_wer.sh new file mode 100644 index 0000000000000000000000000000000000000000..9ecf1690c67f8a019009ef32d973fbd45b56c7ca --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/show_wer.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +split="dev_other" +ref_data="" +get_best_wer=true +dec_name="decode" +graph_name="graph" + +. ./cmd.sh +. ./path.sh +. parse_options.sh + +exp_root=$1 + +set -eu + +echo "==== WER w.r.t. pseudo transcript" +for x in $exp_root/*/${dec_name}_${split}*; do grep WER $x/wer_* 2>/dev/null | utils/best_wer.sh; done + + +if [ ! -z $ref_data ]; then + echo "==== WER w.r.t. real transcript (select based on pseudo WER)" + ref_txt=$ref_data/$split/text + for x in $exp_root/*/${dec_name}_${split}*; do + lang=$(dirname $x)/$graph_name + + lmwt=$( + grep WER $x/wer_* 2>/dev/null | utils/best_wer.sh | + sed 's/.*wer_\(.*\)$/\1/g' | sed 's/_/./g' + ) + tra=$x/scoring/$lmwt.tra + cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:::g' | sed 's:::g' | \ + compute-wer --text --mode=present \ + ark:$ref_txt ark,p:- 2> /dev/null | grep WER | xargs -I{} echo {} $tra + done +fi + +if [ ! -z $ref_data ] && $get_best_wer; then + echo "==== WER w.r.t. 
real transcript (select based on true WER)" + ref_txt=$ref_data/$split/text + for x in $exp_root/*/${dec_name}_${split}*; do + lang=$(dirname $x)/$graph_name + + for tra in $x/scoring/*.tra; do + cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:::g' | sed 's:::g' | \ + compute-wer --text --mode=present \ + ark:$ref_txt ark,p:- 2> /dev/null | grep WER | xargs -I{} echo {} $tra + done | sort -k2n | head -n1 + done +fi + +exit 0; diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh new file mode 100644 index 0000000000000000000000000000000000000000..913c1d8e4357c146026b86e78f0b16f921776441 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh @@ -0,0 +1,129 @@ +#!/usr/bin/env bash + +out_root=/tmp +out_name=train_${RANDOM} +num_nonsil_states=1 + +valid="dev_other" +train="train" +mono_size="-1" # 2000 +tri1_size="-1" # 5000 +tri2b_size="-1" # 10000 +tri3b_size="-1" # 10000 + +# Acoustic model parameters +numLeavesTri1=2000 +numGaussTri1=10000 +numLeavesMLLT=2500 +numGaussMLLT=15000 +numLeavesSAT=2500 +numGaussSAT=15000 + +stage=1 +max_stage=1 + +. ./cmd.sh +. ./path.sh +. parse_options.sh + +data=$1 +lang=$2 +lang_test=$3 + +exp_root=$out_root/$out_name + +# you might not want to do this for interactive shells. +set -e + + +if [ $stage -le 1 ] && [ $max_stage -ge 1 ]; then + # train a monophone system + if [ ! $mono_size -eq -1 ]; then + utils/subset_data_dir.sh $data/$train $mono_size $data/${train}_${mono_size} + mono_train=${train}_${mono_size} + else + mono_train=${train} + fi + + steps/train_mono.sh --boost-silence 1.25 --nj 20 --cmd "$train_cmd" \ + --initial-beam 40 --regular-beam 60 --retry-beam 120 \ + $data/$mono_train $lang $exp_root/mono + + utils/mkgraph.sh $lang_test $exp_root/mono $exp_root/mono/graph + steps/decode.sh --nj 20 --cmd "$decode_cmd" \ + $exp_root/mono/graph $data/$valid $exp_root/mono/decode_$valid & +fi + + +if [ $stage -le 2 ] && [ $max_stage -ge 2 ]; then + # train a first delta + delta-delta triphone system on a subset of 5000 utterances + if [ ! $tri1_size -eq -1 ]; then + utils/subset_data_dir.sh $data/$train $tri1_size $data/${train}_${tri1_size} + tri1_train=${train}_${tri1_size} + else + tri1_train=${train} + fi + + steps/align_si.sh --boost-silence 1.25 --nj 10 --cmd "$train_cmd" \ + $data/$tri1_train $lang \ + $exp_root/mono $exp_root/mono_ali_${tri1_train} + + steps_gan/train_deltas.sh --boost-silence 1.25 --cmd "$train_cmd" \ + --num_nonsil_states $num_nonsil_states $numLeavesTri1 $numGaussTri1 \ + $data/$tri1_train $lang \ + $exp_root/mono_ali_${tri1_train} $exp_root/tri1 + + utils/mkgraph.sh $lang_test $exp_root/tri1 $exp_root/tri1/graph + steps/decode.sh --nj 20 --cmd "$decode_cmd" \ + $exp_root/tri1/graph $data/$valid $exp_root/tri1/decode_$valid & +fi + +if [ $stage -le 3 ] && [ $max_stage -ge 3 ]; then + # train an LDA+MLLT system. + if [ ! 
$tri2b_size -eq -1 ]; then + utils/subset_data_dir.sh $data/$train $tri2b_size $data/${train}_${tri2b_size} + tri2b_train=${train}_${tri2b_size} + else + tri2b_train=${train} + fi + + steps/align_si.sh --nj 10 --cmd "$train_cmd" \ + $data/$tri2b_train $lang \ + $exp_root/tri1 $exp_root/tri1_ali_${tri2b_train} + + steps_gan/train_lda_mllt.sh --cmd "$train_cmd" \ + --num_nonsil_states $num_nonsil_states \ + --splice-opts "--left-context=3 --right-context=3" $numLeavesMLLT $numGaussMLLT \ + $data/$tri2b_train $lang \ + $exp_root/tri1_ali_${tri2b_train} $exp_root/tri2b + + utils/mkgraph.sh $lang_test $exp_root/tri2b $exp_root/tri2b/graph + steps/decode.sh --nj 20 --cmd "$decode_cmd" \ + $exp_root/tri2b/graph $data/$valid $exp_root/tri2b/decode_$valid & +fi + + +if [ $stage -le 4 ] && [ $max_stage -ge 4 ]; then + # Train tri3b, which is LDA+MLLT+SAT on 10k utts + if [ ! $tri3b_size -eq -1 ]; then + utils/subset_data_dir.sh $data/$train $tri3b_size $data/${train}_${tri3b_size} + tri3b_train=${train}_${tri3b_size} + else + tri3b_train=${train} + fi + + steps/align_si.sh --nj 10 --cmd "$train_cmd" --use-graphs true \ + $data/$tri3b_train $lang \ + $exp_root/tri2b $exp_root/tri2b_ali_${tri2b_train} + + steps_gan/train_sat.sh --cmd "$train_cmd" \ + --num_nonsil_states $num_nonsil_states $numLeavesSAT $numGaussSAT \ + $data/$tri3b_train $lang \ + $exp_root/tri2b_ali_${tri2b_train} $exp_root/tri3b + + utils/mkgraph.sh $lang_test $exp_root/tri3b $exp_root/tri3b/graph + steps/decode_fmllr.sh --nj 20 --cmd "$decode_cmd" \ + $exp_root/tri3b/graph $data/$valid $exp_root/tri3b/decode_$valid & +fi + +wait diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select.py b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select.py new file mode 100644 index 0000000000000000000000000000000000000000..1122c88c1964d8beead63bc8dfe21d41602b83bc --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select.py @@ -0,0 +1,135 @@ +""" +Implement unsupervised metric for decoding hyperparameter selection: + $$ alpha * LM_PPL + ViterbitUER(%) * 100 $$ +""" +import argparse +import logging +import math +import sys + +import kenlm +import editdistance +from g2p_en import G2p + +logging.root.setLevel(logging.INFO) +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("ref_tra", help="reference pseudo labels") + parser.add_argument("hyp_tra", help="decoded pseudo labels to be assess") + parser.add_argument("--kenlm_path", default="/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o5.bin", help="") + parser.add_argument("--uppercase", action="store_true", help="") + parser.add_argument("--skipwords", default="", help="") + parser.add_argument("--gt_tra", default="", help="ground truth pseudo labels for computing oracle WER") + parser.add_argument("--min_vt_uer", default=0.0, type=float) + parser.add_argument("--phonemize", action="store_true", help="phonemize word hypotheses, used when reference is phone transcript") + parser.add_argument("--phonemize_lexicon", default="", type=str, help="use a lexicon for phonemizing") + return parser + +def load_tra(tra_path): + with open(tra_path, "r") as f: + uid_to_tra = {} + for line in f: + toks = line.rstrip().split() + uid, tra = toks[0], " ".join(toks[1:]) + uid_to_tra[uid] = tra + logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}") + return 
uid_to_tra + +def load_lex(lex_path): + with open(lex_path, "r") as f: + w2p = {} + for line in f: + w, p = line.rstrip().split(None, 1) + w2p[w] = p.split() + return w2p + +def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p, g2p_dict): + d_cnt = 0 + w_cnt = 0 + w_cnt_h = 0 + for uid in hyp_uid_to_tra: + ref = ref_uid_to_tra[uid].split() + if g2p_dict is not None: + hyp = [] + for word in hyp_uid_to_tra[uid].split(): + if word in g2p_dict: + hyp = hyp + g2p_dict[word] + else: + logger.warning(f"{word} not in g2p_dict") + elif g2p is not None: + hyp = g2p(hyp_uid_to_tra[uid]) + hyp = [p for p in hyp if p != "'" and p != " "] + hyp = [p[:-1] if p[-1].isnumeric() else p for p in hyp] + else: + hyp = hyp_uid_to_tra[uid].split() + logger.debug(( + f"======================\n" + f"HYP: {' '.join(hyp)}\n" + f"REF: {' '.join(ref)}" + )) + d_cnt += editdistance.eval(ref, hyp) + w_cnt += len(ref) + w_cnt_h += len(hyp) + wer = float(d_cnt) / w_cnt + logger.debug(( + f"wer = {wer*100:.2f}%; num. of ref words = {w_cnt}; " + f"num. of hyp words = {w_cnt_h}; num. of sentences = {len(ref_uid_to_tra)}" + )) + return wer + +def compute_lm_ppl(hyp_uid_to_tra, score_fn): + lm_score = 0. + w_cnt = 0 + for hyp in hyp_uid_to_tra.values(): + cur_score = score_fn(hyp) + cur_cnt = len(hyp.split()) + 1 # plus one for + lm_score += cur_score + w_cnt += cur_cnt + logger.debug(( + f"======================\n" + f"score sum/avg = {cur_score:.2f}/{cur_score/cur_cnt:.2f}\n" + f"hyp = {hyp}" + )) + lm_ppl = math.pow(10, -lm_score / w_cnt) + logger.debug(f"lm ppl = {lm_ppl:.2f}; num. of words = {w_cnt}") + return lm_ppl + +def main(): + args = get_parser().parse_args() + logger.debug(f"Args: {args}") + + ref_uid_to_tra = load_tra(args.ref_tra) + hyp_uid_to_tra = load_tra(args.hyp_tra) + assert not bool(set(hyp_uid_to_tra.keys()) - set(ref_uid_to_tra.keys())) + + lm = kenlm.Model(args.kenlm_path) + skipwords = set(args.skipwords.split(",")) + def compute_lm_score(s): + s = " ".join(w for w in s.split() if w not in skipwords) + s = s.upper() if args.uppercase else s + return lm.score(s) + + g2p, g2p_dict = None, None + if args.phonemize: + if args.phonemize_lexicon: + g2p_dict = load_lex(args.phonemize_lexicon) + else: + g2p = G2p() + + wer = compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p, g2p_dict) + lm_ppl = compute_lm_ppl(hyp_uid_to_tra, compute_lm_score) + + gt_wer = -math.inf + if args.gt_tra: + gt_uid_to_tra = load_tra(args.gt_tra) + gt_wer = compute_wer(gt_uid_to_tra, hyp_uid_to_tra, None, None) + + score = math.log(lm_ppl) * max(wer, args.min_vt_uer) + logging.info(f"{args.hyp_tra}: score={score:.4f}; wer={wer*100:.2f}%; lm_ppl={lm_ppl:.4f}; gt_wer={gt_wer*100:.2f}%") + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh new file mode 100644 index 0000000000000000000000000000000000000000..b34c5b6e0688914a53515162f817a93617b609e5 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +split="dev_other" +ref_txt="" # ground truth transcript path +psd_txt="" # pseudo transcript path +get_best_wer=true +dec_name="decode" +graph_name="graph" +kenlm_path=/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o6.bin + +. ./cmd.sh +. ./path.sh +. parse_options.sh + +exp_root=$1 +unsup_args="" +if [ $# -ge 2 ]; then + unsup_args=$2 +fi + +set -eu + +if [ ! 
-z $ref_txt ] && $get_best_wer; then + echo "==== WER w.r.t. real transcript (select based on unsupervised metric)" + for x in $exp_root/*/${dec_name}_${split}*; do + lang=$(dirname $x)/$graph_name + + ( + for tra in $x/scoring/*.tra; do + cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:::g' | sed 's:::g' > $tra.txt + python local/unsup_select.py $psd_txt $tra.txt --kenlm_path $kenlm_path --gt_tra $ref_txt $unsup_args + done 2>/dev/null | grep "score=" | sed 's/=/ /g' | sed 's/;//g' | sort -k3n | head -n1 + ) & + done +fi +wait + diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh new file mode 100644 index 0000000000000000000000000000000000000000..c10a6b8809b77bca2b2c02df8b8702725bdd51c7 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +split="dev_other" +ref_txt="" # ground truth transcript path +psd_txt="" # pseudo transcript path +get_best_wer=true +dec_name="decode" +graph_name="graph" +kenlm_path=/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o6.bin +phonemize_lexicon="" + +. ./cmd.sh +. ./path.sh +. parse_options.sh +. /private/home/wnhsu/unsup_asr/fairseq-py-unsup/env.sh + +exp_root=$1 + +set -eu + +if [ ! -z $ref_txt ] && $get_best_wer; then + echo "==== WER w.r.t. real transcript (select based on unsupervised metric)" + for x in $exp_root/*/${dec_name}_${split}*; do + lang=$(dirname $x)/$graph_name + + for tra in $x/scoring/*.tra; do + cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:\::g' > $tra.txt + python local/unsup_select.py $psd_txt $tra.txt \ + --kenlm_path $kenlm_path --gt_tra $ref_txt --phonemize \ + --phonemize_lexicon "$phonemize_lexicon" + done | grep "score=" | sed 's/=/ /g' | sed 's/;//g' | sort -k3n | head -n1 + done +fi + + diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh new file mode 100644 index 0000000000000000000000000000000000000000..af68715ab0d87ae40666596d9d877d593684f8e2 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh @@ -0,0 +1,175 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0 + +# Begin configuration. +stage=-4 # This allows restarting after partway, when something when wrong. +config= +cmd=run.pl +scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" +realign_iters="10 20 30"; +num_iters=35 # Number of iterations of training +max_iter_inc=25 # Last iter to increase #Gauss on. +beam=10 +careful=false +retry_beam=40 +boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment +power=0.25 # Exponent for number of gaussians according to occurrence counts +cluster_thresh=-1 # for build-tree control final bottom-up clustering of leaves +norm_vars=false # deprecated. Prefer --cmvn-opts "--norm-vars=true" + # use the option --cmvn-opts "--norm-means=false" +cmvn_opts= +delta_opts= +context_opts= # use"--context-width=5 --central-position=2" for quinphone +num_nonsil_states=3 +# End configuration. + +echo "$0 $@" # Print the command line for logging + +[ -f path.sh ] && . ./path.sh; +. 
parse_options.sh || exit 1; + +if [ $# != 6 ]; then + echo "Usage: steps/train_deltas.sh " + echo "e.g.: steps/train_deltas.sh 2000 10000 data/train_si84_half data/lang exp/mono_ali exp/tri1" + echo "main options (for others, see top of script file)" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --config # config containing options" + echo " --stage # stage to do partial re-run from." + exit 1; +fi + +numleaves=$1 +totgauss=$2 +data=$3 +lang=$4 +alidir=$5 +dir=$6 + +for f in $alidir/final.mdl $alidir/ali.1.gz $data/feats.scp $lang/phones.txt; do + [ ! -f $f ] && echo "train_deltas.sh: no such file $f" && exit 1; +done + +numgauss=$numleaves +incgauss=$[($totgauss-$numgauss)/$max_iter_inc] # per-iter increment for #Gauss +oov=`cat $lang/oov.int` || exit 1; +ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; +nj=`cat $alidir/num_jobs` || exit 1; +mkdir -p $dir/log +echo $nj > $dir/num_jobs + +utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1; +cp $lang/phones.txt $dir || exit 1; + +sdata=$data/split$nj; +split_data.sh $data $nj || exit 1; + + +[ $(cat $alidir/cmvn_opts 2>/dev/null | wc -c) -gt 1 ] && [ -z "$cmvn_opts" ] && \ + echo "$0: warning: ignoring CMVN options from source directory $alidir" +$norm_vars && cmvn_opts="--norm-vars=true $cmvn_opts" +echo $cmvn_opts > $dir/cmvn_opts # keep track of options to CMVN. +[ ! -z $delta_opts ] && echo $delta_opts > $dir/delta_opts + +feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |" + +rm $dir/.error 2>/dev/null + +if [ $stage -le -3 ]; then + echo "$0: accumulating tree stats" + $cmd JOB=1:$nj $dir/log/acc_tree.JOB.log \ + acc-tree-stats $context_opts \ + --ci-phones=$ciphonelist $alidir/final.mdl "$feats" \ + "ark:gunzip -c $alidir/ali.JOB.gz|" $dir/JOB.treeacc || exit 1; + sum-tree-stats $dir/treeacc $dir/*.treeacc 2>$dir/log/sum_tree_acc.log || exit 1; + rm $dir/*.treeacc +fi + +if [ $stage -le -2 ]; then + echo "$0: getting questions for tree-building, via clustering" + # preparing questions, roots file... + cluster-phones --pdf-class-list=$(($num_nonsil_states / 2)) $context_opts \ + $dir/treeacc $lang/phones/sets.int \ + $dir/questions.int 2> $dir/log/questions.log || exit 1; + cat $lang/phones/extra_questions.int >> $dir/questions.int + compile-questions $context_opts $lang/topo $dir/questions.int \ + $dir/questions.qst 2>$dir/log/compile_questions.log || exit 1; + + echo "$0: building the tree" + $cmd $dir/log/build_tree.log \ + build-tree $context_opts --verbose=1 --max-leaves=$numleaves \ + --cluster-thresh=$cluster_thresh $dir/treeacc $lang/phones/roots.int \ + $dir/questions.qst $lang/topo $dir/tree || exit 1; + + $cmd $dir/log/init_model.log \ + gmm-init-model --write-occs=$dir/1.occs \ + $dir/tree $dir/treeacc $lang/topo $dir/1.mdl || exit 1; + if grep 'no stats' $dir/log/init_model.log; then + echo "** The warnings above about 'no stats' generally mean you have phones **" + echo "** (or groups of phones) in your phone set that had no corresponding data. **" + echo "** You should probably figure out whether something went wrong, **" + echo "** or whether your data just doesn't happen to have examples of those **" + echo "** phones. **" + fi + + gmm-mixup --mix-up=$numgauss $dir/1.mdl $dir/1.occs $dir/1.mdl 2>$dir/log/mixup.log || exit 1; + rm $dir/treeacc +fi + +if [ $stage -le -1 ]; then + # Convert the alignments. 
+ echo "$0: converting alignments from $alidir to use current tree" + $cmd JOB=1:$nj $dir/log/convert.JOB.log \ + convert-ali $alidir/final.mdl $dir/1.mdl $dir/tree \ + "ark:gunzip -c $alidir/ali.JOB.gz|" "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; +fi + +if [ $stage -le 0 ]; then + echo "$0: compiling graphs of transcripts" + $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \ + compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $dir/tree $dir/1.mdl $lang/L.fst \ + "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $sdata/JOB/text |" \ + "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1; +fi + +x=1 +while [ $x -lt $num_iters ]; do + echo "$0: training pass $x" + if [ $stage -le $x ]; then + if echo $realign_iters | grep -w $x >/dev/null; then + echo "$0: aligning data" + mdl="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $dir/$x.mdl - |" + $cmd JOB=1:$nj $dir/log/align.$x.JOB.log \ + gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$retry_beam --careful=$careful "$mdl" \ + "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" \ + "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; + fi + $cmd JOB=1:$nj $dir/log/acc.$x.JOB.log \ + gmm-acc-stats-ali $dir/$x.mdl "$feats" \ + "ark,s,cs:gunzip -c $dir/ali.JOB.gz|" $dir/$x.JOB.acc || exit 1; + $cmd $dir/log/update.$x.log \ + gmm-est --mix-up=$numgauss --power=$power \ + --write-occs=$dir/$[$x+1].occs $dir/$x.mdl \ + "gmm-sum-accs - $dir/$x.*.acc |" $dir/$[$x+1].mdl || exit 1; + rm $dir/$x.mdl $dir/$x.*.acc + rm $dir/$x.occs + fi + [ $x -le $max_iter_inc ] && numgauss=$[$numgauss+$incgauss]; + x=$[$x+1]; +done + +rm $dir/final.mdl $dir/final.occs 2>/dev/null +ln -s $x.mdl $dir/final.mdl +ln -s $x.occs $dir/final.occs + +steps/diagnostic/analyze_alignments.sh --cmd "$cmd" $lang $dir + +# Summarize warning messages... +utils/summarize_warnings.pl $dir/log + +steps/info/gmm_dir_info.pl $dir + +echo "$0: Done training system with delta+delta-delta features in $dir" + +exit 0 diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh new file mode 100644 index 0000000000000000000000000000000000000000..9d8c319ce848e431ec47a3548156347ae3b50ced --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh @@ -0,0 +1,239 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# +# LDA+MLLT refers to the way we transform the features after computing +# the MFCCs: we splice across several frames, reduce the dimension (to 40 +# by default) using Linear Discriminant Analysis), and then later estimate, +# over multiple iterations, a diagonalizing transform known as MLLT or STC. +# See http://kaldi-asr.org/doc/transform.html for more explanation. +# +# Apache 2.0. + +# Begin configuration. +cmd=run.pl +config= +stage=-5 +scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" +realign_iters="10 20 30"; +mllt_iters="2 4 6 12"; +num_iters=35 # Number of iterations of training +max_iter_inc=25 # Last iter to increase #Gauss on. +dim=40 +beam=10 +retry_beam=40 +careful=false +boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment +power=0.25 # Exponent for number of gaussians according to occurrence counts +randprune=4.0 # This is approximately the ratio by which we will speed up the + # LDA and MLLT calculations via randomized pruning. 
+splice_opts= +cluster_thresh=-1 # for build-tree control final bottom-up clustering of leaves +norm_vars=false # deprecated. Prefer --cmvn-opts "--norm-vars=false" +cmvn_opts= +context_opts= # use "--context-width=5 --central-position=2" for quinphone. +# End configuration. +train_tree=true # if false, don't actually train the tree. +use_lda_mat= # If supplied, use this LDA[+MLLT] matrix. +num_nonsil_states=3 + +echo "$0 $@" # Print the command line for logging + +[ -f path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# != 6 ]; then + echo "Usage: steps/train_lda_mllt.sh [options] <#leaves> <#gauss> " + echo " e.g.: steps/train_lda_mllt.sh 2500 15000 data/train_si84 data/lang exp/tri1_ali_si84 exp/tri2b" + echo "Main options (for others, see top of script file)" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --config # config containing options" + echo " --stage # stage to do partial re-run from." + exit 1; +fi + +numleaves=$1 +totgauss=$2 +data=$3 +lang=$4 +alidir=$5 +dir=$6 + +for f in $alidir/final.mdl $alidir/ali.1.gz $data/feats.scp $lang/phones.txt; do + [ ! -f $f ] && echo "train_lda_mllt.sh: no such file $f" && exit 1; +done + +numgauss=$numleaves +incgauss=$[($totgauss-$numgauss)/$max_iter_inc] # per-iter #gauss increment +oov=`cat $lang/oov.int` || exit 1; +nj=`cat $alidir/num_jobs` || exit 1; +silphonelist=`cat $lang/phones/silence.csl` || exit 1; +ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; + +mkdir -p $dir/log + +utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1; +cp $lang/phones.txt $dir || exit 1; + +echo $nj >$dir/num_jobs +echo "$splice_opts" >$dir/splice_opts # keep track of frame-splicing options + # so that later stages of system building can know what they were. + + +[ $(cat $alidir/cmvn_opts 2>/dev/null | wc -c) -gt 1 ] && [ -z "$cmvn_opts" ] && \ + echo "$0: warning: ignoring CMVN options from source directory $alidir" +$norm_vars && cmvn_opts="--norm-vars=true $cmvn_opts" +echo $cmvn_opts > $dir/cmvn_opts # keep track of options to CMVN. + +sdata=$data/split$nj; +split_data.sh $data $nj || exit 1; + +splicedfeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- |" +# Note: $feats gets overwritten later in the script. +feats="$splicedfeats transform-feats $dir/0.mat ark:- ark:- |" + + + +if [ $stage -le -5 ]; then + if [ -z "$use_lda_mat" ]; then + echo "$0: Accumulating LDA statistics." + rm $dir/lda.*.acc 2>/dev/null + $cmd JOB=1:$nj $dir/log/lda_acc.JOB.log \ + ali-to-post "ark:gunzip -c $alidir/ali.JOB.gz|" ark:- \| \ + weight-silence-post 0.0 $silphonelist $alidir/final.mdl ark:- ark:- \| \ + acc-lda --rand-prune=$randprune $alidir/final.mdl "$splicedfeats" ark,s,cs:- \ + $dir/lda.JOB.acc || exit 1; + est-lda --write-full-matrix=$dir/full.mat --dim=$dim $dir/0.mat $dir/lda.*.acc \ + 2>$dir/log/lda_est.log || exit 1; + rm $dir/lda.*.acc + else + echo "$0: Using supplied LDA matrix $use_lda_mat" + cp $use_lda_mat $dir/0.mat || exit 1; + [ ! 
-z "$mllt_iters" ] && \ + echo "$0: Warning: using supplied LDA matrix $use_lda_mat but we will do MLLT," && \ + echo " which you might not want; to disable MLLT, specify --mllt-iters ''" && \ + sleep 5 + fi +fi + +cur_lda_iter=0 + +if [ $stage -le -4 ] && $train_tree; then + echo "$0: Accumulating tree stats" + $cmd JOB=1:$nj $dir/log/acc_tree.JOB.log \ + acc-tree-stats $context_opts \ + --ci-phones=$ciphonelist $alidir/final.mdl "$feats" \ + "ark:gunzip -c $alidir/ali.JOB.gz|" $dir/JOB.treeacc || exit 1; + [ `ls $dir/*.treeacc | wc -w` -ne "$nj" ] && echo "$0: Wrong #tree-accs" && exit 1; + $cmd $dir/log/sum_tree_acc.log \ + sum-tree-stats $dir/treeacc $dir/*.treeacc || exit 1; + rm $dir/*.treeacc +fi + + +if [ $stage -le -3 ] && $train_tree; then + echo "$0: Getting questions for tree clustering." + # preparing questions, roots file... + cluster-phones --pdf-class-list=$(($num_nonsil_states / 2)) $context_opts $dir/treeacc $lang/phones/sets.int \ + $dir/questions.int 2> $dir/log/questions.log || exit 1; + cat $lang/phones/extra_questions.int >> $dir/questions.int + compile-questions $context_opts $lang/topo $dir/questions.int \ + $dir/questions.qst 2>$dir/log/compile_questions.log || exit 1; + + echo "$0: Building the tree" + $cmd $dir/log/build_tree.log \ + build-tree $context_opts --verbose=1 --max-leaves=$numleaves \ + --cluster-thresh=$cluster_thresh $dir/treeacc $lang/phones/roots.int \ + $dir/questions.qst $lang/topo $dir/tree || exit 1; +fi + +if [ $stage -le -2 ]; then + echo "$0: Initializing the model" + if $train_tree; then + gmm-init-model --write-occs=$dir/1.occs \ + $dir/tree $dir/treeacc $lang/topo $dir/1.mdl 2> $dir/log/init_model.log || exit 1; + grep 'no stats' $dir/log/init_model.log && echo "This is a bad warning."; + rm $dir/treeacc + else + cp $alidir/tree $dir/ || exit 1; + $cmd JOB=1 $dir/log/init_model.log \ + gmm-init-model-flat $dir/tree $lang/topo $dir/1.mdl \ + "$feats subset-feats ark:- ark:-|" || exit 1; + fi +fi + + +if [ $stage -le -1 ]; then + # Convert the alignments. 
+ echo "$0: Converting alignments from $alidir to use current tree" + $cmd JOB=1:$nj $dir/log/convert.JOB.log \ + convert-ali $alidir/final.mdl $dir/1.mdl $dir/tree \ + "ark:gunzip -c $alidir/ali.JOB.gz|" "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; +fi + +if [ $stage -le 0 ] && [ "$realign_iters" != "" ]; then + echo "$0: Compiling graphs of transcripts" + $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \ + compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $dir/tree $dir/1.mdl $lang/L.fst \ + "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $data/split$nj/JOB/text |" \ + "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1; +fi + + +x=1 +while [ $x -lt $num_iters ]; do + echo Training pass $x + if echo $realign_iters | grep -w $x >/dev/null && [ $stage -le $x ]; then + echo Aligning data + mdl="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $dir/$x.mdl - |" + $cmd JOB=1:$nj $dir/log/align.$x.JOB.log \ + gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$retry_beam --careful=$careful "$mdl" \ + "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" \ + "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; + fi + if echo $mllt_iters | grep -w $x >/dev/null; then + if [ $stage -le $x ]; then + echo "$0: Estimating MLLT" + $cmd JOB=1:$nj $dir/log/macc.$x.JOB.log \ + ali-to-post "ark:gunzip -c $dir/ali.JOB.gz|" ark:- \| \ + weight-silence-post 0.0 $silphonelist $dir/$x.mdl ark:- ark:- \| \ + gmm-acc-mllt --rand-prune=$randprune $dir/$x.mdl "$feats" ark:- $dir/$x.JOB.macc \ + || exit 1; + est-mllt $dir/$x.mat.new $dir/$x.*.macc 2> $dir/log/mupdate.$x.log || exit 1; + gmm-transform-means $dir/$x.mat.new $dir/$x.mdl $dir/$x.mdl \ + 2> $dir/log/transform_means.$x.log || exit 1; + compose-transforms --print-args=false $dir/$x.mat.new $dir/$cur_lda_iter.mat $dir/$x.mat || exit 1; + rm $dir/$x.*.macc + fi + feats="$splicedfeats transform-feats $dir/$x.mat ark:- ark:- |" + cur_lda_iter=$x + fi + + if [ $stage -le $x ]; then + $cmd JOB=1:$nj $dir/log/acc.$x.JOB.log \ + gmm-acc-stats-ali $dir/$x.mdl "$feats" \ + "ark,s,cs:gunzip -c $dir/ali.JOB.gz|" $dir/$x.JOB.acc || exit 1; + $cmd $dir/log/update.$x.log \ + gmm-est --write-occs=$dir/$[$x+1].occs --mix-up=$numgauss --power=$power \ + $dir/$x.mdl "gmm-sum-accs - $dir/$x.*.acc |" $dir/$[$x+1].mdl || exit 1; + rm $dir/$x.mdl $dir/$x.*.acc $dir/$x.occs + fi + [ $x -le $max_iter_inc ] && numgauss=$[$numgauss+$incgauss]; + x=$[$x+1]; +done + +rm $dir/final.{mdl,mat,occs} 2>/dev/null +ln -s $x.mdl $dir/final.mdl +ln -s $x.occs $dir/final.occs +ln -s $cur_lda_iter.mat $dir/final.mat + +steps/diagnostic/analyze_alignments.sh --cmd "$cmd" $lang $dir + +# Summarize warning messages... +utils/summarize_warnings.pl $dir/log + +steps/info/gmm_dir_info.pl $dir + +echo "$0: Done training system with LDA+MLLT features in $dir" + +exit 0 diff --git a/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_sat.sh b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_sat.sh new file mode 100644 index 0000000000000000000000000000000000000000..f75afafb1c4ad04ee71ab8541064ab0477430616 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_sat.sh @@ -0,0 +1,281 @@ +#!/usr/bin/env bash +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. + + +# This does Speaker Adapted Training (SAT), i.e. train on +# fMLLR-adapted features. It can be done on top of either LDA+MLLT, or +# delta and delta-delta features. 
If there are no transforms supplied +# in the alignment directory, it will estimate transforms itself before +# building the tree (and in any case, it estimates transforms a number +# of times during training). + + +# Begin configuration section. +stage=-5 +exit_stage=-100 # you can use this to require it to exit at the + # beginning of a specific stage. Not all values are + # supported. +fmllr_update_type=full +cmd=run.pl +scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" +beam=10 +retry_beam=40 +careful=false +boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment +context_opts= # e.g. set this to "--context-width 5 --central-position 2" for quinphone. +realign_iters="10 20 30"; +fmllr_iters="2 4 6 12"; +silence_weight=0.0 # Weight on silence in fMLLR estimation. +num_iters=35 # Number of iterations of training +max_iter_inc=25 # Last iter to increase #Gauss on. +power=0.2 # Exponent for number of gaussians according to occurrence counts +cluster_thresh=-1 # for build-tree control final bottom-up clustering of leaves +phone_map= +train_tree=true +tree_stats_opts= +cluster_phones_opts= +compile_questions_opts= +# End configuration section. +num_nonsil_states=3 + +echo "$0 $@" # Print the command line for logging + +[ -f path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# != 6 ]; then + echo "Usage: steps/train_sat.sh <#leaves> <#gauss> " + echo " e.g.: steps/train_sat.sh 2500 15000 data/train_si84 data/lang exp/tri2b_ali_si84 exp/tri3b" + echo "Main options (for others, see top of script file)" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --config # config containing options" + echo " --stage # stage to do partial re-run from." + exit 1; +fi + +numleaves=$1 +totgauss=$2 +data=$3 +lang=$4 +alidir=$5 +dir=$6 + +for f in $data/feats.scp $lang/phones.txt $alidir/final.mdl $alidir/ali.1.gz; do + [ ! -f $f ] && echo "train_sat.sh: no such file $f" && exit 1; +done + +numgauss=$numleaves +incgauss=$[($totgauss-$numgauss)/$max_iter_inc] # per-iter #gauss increment +oov=`cat $lang/oov.int` +nj=`cat $alidir/num_jobs` || exit 1; +silphonelist=`cat $lang/phones/silence.csl` +ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; +sdata=$data/split$nj; +splice_opts=`cat $alidir/splice_opts 2>/dev/null` # frame-splicing options. +cmvn_opts=`cat $alidir/cmvn_opts 2>/dev/null` +delta_opts=`cat $alidir/delta_opts 2>/dev/null` +phone_map_opt= +[ ! -z "$phone_map" ] && phone_map_opt="--phone-map='$phone_map'" + +mkdir -p $dir/log +cp $alidir/splice_opts $dir 2>/dev/null # frame-splicing options. +cp $alidir/cmvn_opts $dir 2>/dev/null # cmn/cmvn option. +cp $alidir/delta_opts $dir 2>/dev/null # delta option. + +utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1; +cp $lang/phones.txt $dir || exit 1; + +echo $nj >$dir/num_jobs +[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; + +# Set up features. + +if [ -f $alidir/final.mat ]; then feat_type=lda; else feat_type=delta; fi +echo "$0: feature type is $feat_type" + +## Set up speaker-independent features. 
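+# $sifeats are the speaker-independent features (delta or LDA+MLLT, matching
+# whatever the alignment directory used); the speaker-adapted features used
+# for training are obtained by applying per-speaker fMLLR transforms on top.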
+case $feat_type in + delta) sifeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |";; + lda) sifeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |" + cp $alidir/final.mat $dir + cp $alidir/full.mat $dir 2>/dev/null + ;; + *) echo "$0: invalid feature type $feat_type" && exit 1; +esac + +## Get initial fMLLR transforms (possibly from alignment dir) +if [ -f $alidir/trans.1 ]; then + echo "$0: Using transforms from $alidir" + feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark,s,cs:$alidir/trans.JOB ark:- ark:- |" + cur_trans_dir=$alidir +else + if [ $stage -le -5 ]; then + echo "$0: obtaining initial fMLLR transforms since not present in $alidir" + # The next line is necessary because of $silphonelist otherwise being incorrect; would require + # old $lang dir which would require another option. Not needed anyway. + [ ! -z "$phone_map" ] && \ + echo "$0: error: you must provide transforms if you use the --phone-map option." && exit 1; + $cmd JOB=1:$nj $dir/log/fmllr.0.JOB.log \ + ali-to-post "ark:gunzip -c $alidir/ali.JOB.gz|" ark:- \| \ + weight-silence-post $silence_weight $silphonelist $alidir/final.mdl ark:- ark:- \| \ + gmm-est-fmllr --fmllr-update-type=$fmllr_update_type \ + --spk2utt=ark:$sdata/JOB/spk2utt $alidir/final.mdl "$sifeats" \ + ark:- ark:$dir/trans.JOB || exit 1; + fi + feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark,s,cs:$dir/trans.JOB ark:- ark:- |" + cur_trans_dir=$dir +fi + +if [ $stage -le -4 ] && $train_tree; then + # Get tree stats. + echo "$0: Accumulating tree stats" + $cmd JOB=1:$nj $dir/log/acc_tree.JOB.log \ + acc-tree-stats $context_opts $tree_stats_opts $phone_map_opt --ci-phones=$ciphonelist $alidir/final.mdl "$feats" \ + "ark:gunzip -c $alidir/ali.JOB.gz|" $dir/JOB.treeacc || exit 1; + [ "`ls $dir/*.treeacc | wc -w`" -ne "$nj" ] && echo "$0: Wrong #tree-accs" && exit 1; + $cmd $dir/log/sum_tree_acc.log \ + sum-tree-stats $dir/treeacc $dir/*.treeacc || exit 1; + rm $dir/*.treeacc +fi + +if [ $stage -le -3 ] && $train_tree; then + echo "$0: Getting questions for tree clustering." + # preparing questions, roots file... 
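+ # cluster-phones derives data-driven questions (sets of acoustically similar
+ # phones) from the tree stats; these are merged with extra_questions and
+ # compiled into the form build-tree expects.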
+ cluster-phones --pdf-class-list=$(($num_nonsil_states / 2)) \ + $cluster_phones_opts $context_opts \ + $dir/treeacc $lang/phones/sets.int $dir/questions.int 2>$dir/log/questions.log || exit 1; + cat $lang/phones/extra_questions.int >> $dir/questions.int + compile-questions $context_opts $compile_questions_opts $lang/topo $dir/questions.int $dir/questions.qst 2>$dir/log/compile_questions.log || exit 1; + + echo "$0: Building the tree" + $cmd $dir/log/build_tree.log \ + build-tree $context_opts --verbose=1 --max-leaves=$numleaves \ + --cluster-thresh=$cluster_thresh $dir/treeacc $lang/phones/roots.int \ + $dir/questions.qst $lang/topo $dir/tree || exit 1; +fi + +if [ $stage -le -2 ]; then + echo "$0: Initializing the model" + if $train_tree; then + gmm-init-model --write-occs=$dir/1.occs \ + $dir/tree $dir/treeacc $lang/topo $dir/1.mdl 2> $dir/log/init_model.log || exit 1; + grep 'no stats' $dir/log/init_model.log && echo "This is a bad warning."; + rm $dir/treeacc + else + cp $alidir/tree $dir/ || exit 1; + $cmd JOB=1 $dir/log/init_model.log \ + gmm-init-model-flat $dir/tree $lang/topo $dir/1.mdl \ + "$feats subset-feats ark:- ark:-|" || exit 1; + fi +fi + +if [ $stage -le -1 ]; then + # Convert the alignments. + echo "$0: Converting alignments from $alidir to use current tree" + $cmd JOB=1:$nj $dir/log/convert.JOB.log \ + convert-ali $phone_map_opt $alidir/final.mdl $dir/1.mdl $dir/tree \ + "ark:gunzip -c $alidir/ali.JOB.gz|" "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; +fi + +[ "$exit_stage" -eq 0 ] && echo "$0: Exiting early: --exit-stage $exit_stage" && exit 0; + +if [ $stage -le 0 ] && [ "$realign_iters" != "" ]; then + echo "$0: Compiling graphs of transcripts" + $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \ + compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $dir/tree $dir/1.mdl $lang/L.fst \ + "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $sdata/JOB/text |" \ + "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1; +fi + +x=1 +while [ $x -lt $num_iters ]; do + echo Pass $x + if echo $realign_iters | grep -w $x >/dev/null && [ $stage -le $x ]; then + echo Aligning data + mdl="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $dir/$x.mdl - |" + $cmd JOB=1:$nj $dir/log/align.$x.JOB.log \ + gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$retry_beam --careful=$careful "$mdl" \ + "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" \ + "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; + fi + + if echo $fmllr_iters | grep -w $x >/dev/null; then + if [ $stage -le $x ]; then + echo Estimating fMLLR transforms + # We estimate a transform that's additional to the previous transform; + # we'll compose them. + $cmd JOB=1:$nj $dir/log/fmllr.$x.JOB.log \ + ali-to-post "ark:gunzip -c $dir/ali.JOB.gz|" ark:- \| \ + weight-silence-post $silence_weight $silphonelist $dir/$x.mdl ark:- ark:- \| \ + gmm-est-fmllr --fmllr-update-type=$fmllr_update_type \ + --spk2utt=ark:$sdata/JOB/spk2utt $dir/$x.mdl \ + "$feats" ark:- ark:$dir/tmp_trans.JOB || exit 1; + for n in `seq $nj`; do + ! 
( compose-transforms --b-is-affine=true \ + ark:$dir/tmp_trans.$n ark:$cur_trans_dir/trans.$n ark:$dir/composed_trans.$n \ + && mv $dir/composed_trans.$n $dir/trans.$n && \ + rm $dir/tmp_trans.$n ) 2>$dir/log/compose_transforms.$x.log \ + && echo "$0: Error composing transforms" && exit 1; + done + fi + feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$dir/trans.JOB ark:- ark:- |" + cur_trans_dir=$dir + fi + + if [ $stage -le $x ]; then + $cmd JOB=1:$nj $dir/log/acc.$x.JOB.log \ + gmm-acc-stats-ali $dir/$x.mdl "$feats" \ + "ark,s,cs:gunzip -c $dir/ali.JOB.gz|" $dir/$x.JOB.acc || exit 1; + [ `ls $dir/$x.*.acc | wc -w` -ne "$nj" ] && echo "$0: Wrong #accs" && exit 1; + $cmd $dir/log/update.$x.log \ + gmm-est --power=$power --write-occs=$dir/$[$x+1].occs --mix-up=$numgauss $dir/$x.mdl \ + "gmm-sum-accs - $dir/$x.*.acc |" $dir/$[$x+1].mdl || exit 1; + rm $dir/$x.mdl $dir/$x.*.acc + rm $dir/$x.occs + fi + [ $x -le $max_iter_inc ] && numgauss=$[$numgauss+$incgauss]; + x=$[$x+1]; +done + + +if [ $stage -le $x ]; then + # Accumulate stats for "alignment model"-- this model is + # computed with the speaker-independent features, but matches Gaussian-for-Gaussian + # with the final speaker-adapted model. + $cmd JOB=1:$nj $dir/log/acc_alimdl.JOB.log \ + ali-to-post "ark:gunzip -c $dir/ali.JOB.gz|" ark:- \| \ + gmm-acc-stats-twofeats $dir/$x.mdl "$feats" "$sifeats" \ + ark,s,cs:- $dir/$x.JOB.acc || exit 1; + [ `ls $dir/$x.*.acc | wc -w` -ne "$nj" ] && echo "$0: Wrong #accs" && exit 1; + # Update model. + $cmd $dir/log/est_alimdl.log \ + gmm-est --power=$power --remove-low-count-gaussians=false $dir/$x.mdl \ + "gmm-sum-accs - $dir/$x.*.acc|" $dir/$x.alimdl || exit 1; + rm $dir/$x.*.acc +fi + +rm $dir/final.{mdl,alimdl,occs} 2>/dev/null +ln -s $x.mdl $dir/final.mdl +ln -s $x.occs $dir/final.occs +ln -s $x.alimdl $dir/final.alimdl + + +steps/diagnostic/analyze_alignments.sh --cmd "$cmd" $lang $dir + +utils/summarize_warnings.pl $dir/log +( + echo "$0: Likelihood evolution:" + for x in `seq $[$num_iters-1]`; do + tail -n 30 $dir/log/acc.$x.*.log | awk '/Overall avg like/{l += $(NF-3)*$(NF-1); t += $(NF-1); } + /Overall average logdet/{d += $(NF-3)*$(NF-1); t2 += $(NF-1);} + END{ d /= t2; l /= t; printf("%s ", d+l); } ' + done + echo +) | tee $dir/log/summary.log + + +steps/info/gmm_dir_info.pl $dir + +echo "$0: done training SAT system in $dir" + +exit 0 diff --git a/fairseq/examples/wav2vec/unsupervised/models/__init__.py b/fairseq/examples/wav2vec/unsupervised/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3039b7081a9e3228c8abefb6391a75b4864439 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/models/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .wav2vec_u import Wav2vec_U + + +__all__ = [ + "Wav2vec_U", +] diff --git a/fairseq/examples/wav2vec/unsupervised/models/wav2vec_u.py b/fairseq/examples/wav2vec/unsupervised/models/wav2vec_u.py new file mode 100644 index 0000000000000000000000000000000000000000..8a1e9055e3a2872cbb6bf58f9c711bcd444de8d0 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/models/wav2vec_u.py @@ -0,0 +1,687 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
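+# Overview: this module implements the wav2vec-U GAN. The Generator is a
+# convolutional network that maps (optionally segmented) speech features to
+# distributions over phone units; the Discriminator is a convolutional network
+# trained to distinguish those outputs from one-hot encodings of unpaired
+# phonemized text. Parameters are tagged with "generator"/"discriminator"
+# param_group attributes and the two groups are updated on alternating steps
+# (see discrim_step / get_groups_for_update below).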
+ +from dataclasses import dataclass +from enum import Enum, auto +import math +import numpy as np +from typing import Tuple, List, Optional, Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import autograd + +from fairseq import checkpoint_utils, utils +from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.modules import ( + SamePad, + TransposeLast, +) + + +class SegmentationType(Enum): + NONE = auto() + RANDOM = auto() + UNIFORM_RANDOM = auto() + UNIFORM_RANDOM_JOIN = auto() + JOIN = auto() + + +@dataclass +class SegmentationConfig(FairseqDataclass): + type: SegmentationType = SegmentationType.NONE + subsample_rate: float = 0.25 + mean_pool: bool = True + mean_pool_join: bool = False + remove_zeros: bool = False + + +@dataclass +class Wav2vec_UConfig(FairseqDataclass): + discriminator_kernel: int = 3 + discriminator_dilation: int = 1 + discriminator_dim: int = 256 + discriminator_causal: bool = True + discriminator_linear_emb: bool = False + discriminator_depth: int = 1 + discriminator_max_pool: bool = False + discriminator_act_after_linear: bool = False + discriminator_dropout: float = 0.0 + discriminator_spectral_norm: bool = False + discriminator_weight_norm: bool = False + + generator_kernel: int = 4 + generator_dilation: int = 1 + generator_stride: int = 1 + generator_pad: int = -1 + generator_bias: bool = False + generator_dropout: float = 0.0 + generator_batch_norm: int = 0 + generator_residual: bool = False + + blank_weight: float = 0 + blank_mode: str = "add" + blank_is_sil: bool = False + no_softmax: bool = False + + smoothness_weight: float = 0.0 + smoothing: float = 0.0 + smoothing_one_sided: bool = False + gradient_penalty: float = 0.0 + probabilistic_grad_penalty_slicing: bool = False + code_penalty: float = 0.0 + mmi_weight: float = 0.0 + target_dim: int = 64 + target_downsample_rate: int = 2 + gumbel: bool = False + hard_gumbel: bool = True + temp: Tuple[float, float, float] = (2, 0.1, 0.99995) + input_dim: int = 128 + + segmentation: SegmentationConfig = SegmentationConfig() + + +class Segmenter(nn.Module): + cfg: SegmentationConfig + + def __init__(self, cfg: SegmentationConfig): + super().__init__() + self.cfg = cfg + self.subsample_rate = cfg.subsample_rate + + def pre_segment(self, dense_x, dense_padding_mask): + return dense_x, dense_padding_mask + + def logit_segment(self, logits, padding_mask): + return logits, padding_mask + + +class RandomSegmenter(Segmenter): + def pre_segment(self, dense_x, dense_padding_mask): + target_num = math.ceil(dense_x.size(1) * self.subsample_rate) + ones = torch.ones(dense_x.shape[:-1], device=dense_x.device) + indices, _ = ones.multinomial(target_num).sort(dim=-1) + indices_ld = indices.unsqueeze(-1).expand(-1, -1, dense_x.size(-1)) + dense_x = dense_x.gather(1, indices_ld) + dense_padding_mask = dense_padding_mask.gather(1, index=indices) + return dense_x, dense_padding_mask + + +class UniformRandomSegmenter(Segmenter): + def pre_segment(self, dense_x, dense_padding_mask): + bsz, tsz, fsz = dense_x.shape + + target_num = math.ceil(tsz * self.subsample_rate) + + rem = tsz % target_num + + if rem > 0: + dense_x = F.pad(dense_x, [0, 0, 0, target_num - rem]) + dense_padding_mask = F.pad( + dense_padding_mask, [0, target_num - rem], value=True + ) + + dense_x = dense_x.view(bsz, target_num, -1, fsz) + dense_padding_mask = dense_padding_mask.view(bsz, target_num, -1) + + if self.cfg.mean_pool: + dense_x = dense_x.mean(dim=-2) + 
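+ # a pooled step counts as padding only if every frame merged into it was padding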
dense_padding_mask = dense_padding_mask.all(dim=-1) + else: + ones = torch.ones((bsz, dense_x.size(2)), device=dense_x.device) + indices = ones.multinomial(1) + indices = indices.unsqueeze(-1).expand(-1, target_num, -1) + indices_ld = indices.unsqueeze(-1).expand(-1, -1, -1, fsz) + dense_x = dense_x.gather(2, indices_ld).reshape(bsz, -1, fsz) + dense_padding_mask = dense_padding_mask.gather(2, index=indices).reshape( + bsz, -1 + ) + return dense_x, dense_padding_mask + + +class JoinSegmenter(Segmenter): + def logit_segment(self, logits, padding_mask): + preds = logits.argmax(dim=-1) + + if padding_mask.any(): + preds[padding_mask] = -1 # mark pad + uniques = [] + + bsz, tsz, csz = logits.shape + + for p in preds: + uniques.append( + p.cpu().unique_consecutive(return_inverse=True, return_counts=True) + ) + + new_tsz = max(u[0].numel() for u in uniques) + new_logits = logits.new_zeros(bsz, new_tsz, csz) + new_pad = padding_mask.new_zeros(bsz, new_tsz) + + for b in range(bsz): + u, idx, c = uniques[b] + keep = u != -1 + + if self.cfg.remove_zeros: + keep.logical_and_(u != 0) + + if self.training and not self.cfg.mean_pool_join: + u[0] = 0 + u[1:] = c.cumsum(0)[:-1] + m = c > 1 + r = torch.rand(m.sum()) + o = (c[m] * r).long() + u[m] += o + new_logits[b, : u.numel()] = logits[b, u] + else: + new_logits[b].index_add_( + dim=0, index=idx.to(new_logits.device), source=logits[b] + ) + new_logits[b, : c.numel()] /= c.unsqueeze(-1).to(new_logits.device) + + new_sz = keep.sum() + if not keep.all(): + kept_logits = new_logits[b, : c.numel()][keep] + new_logits[b, :new_sz] = kept_logits + + if new_sz < new_tsz: + pad = new_tsz - new_sz + new_logits[b, -pad:] = 0 + new_pad[b, -pad:] = True + + return new_logits, new_pad + + +class UniformRandomJoinSegmenter(UniformRandomSegmenter, JoinSegmenter): + pass + + +SEGMENT_FACTORY = { + SegmentationType.NONE: Segmenter, + SegmentationType.RANDOM: RandomSegmenter, + SegmentationType.UNIFORM_RANDOM: UniformRandomSegmenter, + SegmentationType.UNIFORM_RANDOM_JOIN: UniformRandomJoinSegmenter, + SegmentationType.JOIN: JoinSegmenter, +} + + +class Discriminator(nn.Module): + def __init__(self, dim, cfg: Wav2vec_UConfig): + super().__init__() + + inner_dim = cfg.discriminator_dim + kernel = cfg.discriminator_kernel + dilation = cfg.discriminator_dilation + self.max_pool = cfg.discriminator_max_pool + + if cfg.discriminator_causal: + padding = kernel - 1 + else: + padding = kernel // 2 + + def make_conv(in_d, out_d, k, p=0, has_dilation=True): + conv = nn.Conv1d( + in_d, + out_d, + kernel_size=k, + padding=p, + dilation=dilation if has_dilation else 1, + ) + if cfg.discriminator_spectral_norm: + conv = nn.utils.spectral_norm(conv) + elif cfg.discriminator_weight_norm: + conv = nn.utils.weight_norm(conv) + return conv + + inner_net = [ + nn.Sequential( + make_conv(inner_dim, inner_dim, kernel, padding), + SamePad(kernel_size=kernel, causal=cfg.discriminator_causal), + nn.Dropout(cfg.discriminator_dropout), + nn.GELU(), + ) + for _ in range(cfg.discriminator_depth - 1) + ] + [ + make_conv(inner_dim, 1, kernel, padding, has_dilation=False), + SamePad(kernel_size=kernel, causal=cfg.discriminator_causal), + ] + + if cfg.discriminator_linear_emb: + emb_net = [make_conv(dim, inner_dim, 1)] + else: + emb_net = [ + make_conv(dim, inner_dim, kernel, padding), + SamePad(kernel_size=kernel, causal=cfg.discriminator_causal), + ] + + if cfg.discriminator_act_after_linear: + emb_net.append(nn.GELU()) + + self.net = nn.Sequential( + *emb_net, + nn.Dropout(cfg.discriminator_dropout), + 
*inner_net, + ) + + def forward(self, x, padding_mask): + x = x.transpose(1, 2) # BTC -> BCT + x = self.net(x) + x = x.transpose(1, 2) + x_sz = x.size(1) + if padding_mask is not None and padding_mask.any() and padding_mask.dim() > 1: + padding_mask = padding_mask[:, : x.size(1)] + x[padding_mask] = float("-inf") if self.max_pool else 0 + x_sz = x_sz - padding_mask.sum(dim=-1) + x = x.squeeze(-1) + if self.max_pool: + x, _ = x.max(dim=-1) + else: + x = x.sum(dim=-1) + x = x / x_sz + return x + + +class Generator(nn.Module): + def __init__(self, input_dim, output_dim, cfg: Wav2vec_UConfig): + super().__init__() + + self.cfg = cfg + self.output_dim = output_dim + self.stride = cfg.generator_stride + self.dropout = nn.Dropout(cfg.generator_dropout) + self.batch_norm = cfg.generator_batch_norm != 0 + self.residual = cfg.generator_residual + + padding = ( + cfg.generator_kernel // 2 if cfg.generator_pad < 0 else cfg.generator_pad + ) + self.proj = nn.Sequential( + TransposeLast(), + nn.Conv1d( + input_dim, + output_dim, + kernel_size=cfg.generator_kernel, + stride=cfg.generator_stride, + dilation=cfg.generator_dilation, + padding=padding, + bias=cfg.generator_bias, + ), + TransposeLast(), + ) + + if self.batch_norm: + self.bn = nn.BatchNorm1d(input_dim) + self.bn.weight.data.fill_(cfg.generator_batch_norm) + if self.residual: + self.in_proj = nn.Linear(input_dim, input_dim) + + def forward(self, dense_x, tokens, dense_padding_mask): + result = {} + + if self.batch_norm: + dense_x = self.bn_padded_data(dense_x, dense_padding_mask) + if self.residual: + inter_x = self.in_proj(self.dropout(dense_x)) + dense_x = dense_x + inter_x + result["inter_x"] = inter_x + + dense_x = self.dropout(dense_x) + + dense_x = self.proj(dense_x) + if self.stride > 1: + dense_padding_mask = dense_padding_mask[:, :: self.stride] + + if dense_padding_mask.size(1) != dense_x.size(1): + new_padding = dense_padding_mask.new_zeros(dense_x.shape[:-1]) + diff = new_padding.size(1) - dense_padding_mask.size(1) + + if diff > 0: + new_padding[:, diff:] = dense_padding_mask + else: + assert diff < 0 + new_padding = dense_padding_mask[:, :diff] + + dense_padding_mask = new_padding + + token_x = None + if tokens is not None: + token_x = dense_x.new_zeros(tokens.numel(), self.output_dim) + token_x.scatter_(1, tokens.view(-1, 1).long(), 1) + token_x = token_x.view(tokens.shape + (self.output_dim,)) + + result["dense_x"] = dense_x + result["token_x"] = token_x + result["dense_padding_mask"] = dense_padding_mask + + return result + + def bn_padded_data(self, feature, padding_mask): + normed_feature = feature.clone() + normed_feature[~padding_mask] = self.bn( + feature[~padding_mask].unsqueeze(-1) + ).squeeze(-1) + return normed_feature + + +@register_model("wav2vec_u", dataclass=Wav2vec_UConfig) +class Wav2vec_U(BaseFairseqModel): + def calc_gradient_penalty(self, real_data, fake_data): + + b_size = min(real_data.size(0), fake_data.size(0)) + t_size = min(real_data.size(1), fake_data.size(1)) + + if self.cfg.probabilistic_grad_penalty_slicing: + + def get_slice(data, dim, target_size): + + size = data.size(dim) + diff = size - target_size + if diff <= 0: + return data + + start = np.random.randint(0, diff + 1) + return data.narrow(dim=dim, start=start, length=target_size) + + real_data = get_slice(real_data, 0, b_size) + real_data = get_slice(real_data, 1, t_size) + fake_data = get_slice(fake_data, 0, b_size) + fake_data = get_slice(fake_data, 1, t_size) + + else: + real_data = real_data[:b_size, :t_size] + fake_data = 
fake_data[:b_size, :t_size] + + alpha = torch.rand(real_data.size(0), 1, 1) + alpha = alpha.expand(real_data.size()) + alpha = alpha.to(real_data.device) + + interpolates = alpha * real_data + ((1 - alpha) * fake_data) + + disc_interpolates = self.discriminator(interpolates, None) + + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones(disc_interpolates.size(), device=real_data.device), + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + gradient_penalty = (gradients.norm(2, dim=1) - 1) ** 2 + return gradient_penalty + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self.update_num = num_updates + self.curr_temp = max( + self.max_temp * self.temp_decay ** num_updates, self.min_temp + ) + + def discrim_step(self, num_updates): + return num_updates % 2 == 1 + + def get_groups_for_update(self, num_updates): + return "discriminator" if self.discrim_step(num_updates) else "generator" + + def __init__(self, cfg: Wav2vec_UConfig, target_dict): + super().__init__() + + self.cfg = cfg + self.zero_index = target_dict.index("") if "" in target_dict else 0 + self.smoothness_weight = cfg.smoothness_weight + + output_size = len(target_dict) + self.pad = target_dict.pad() + self.eos = target_dict.eos() + self.smoothing = cfg.smoothing + self.smoothing_one_sided = cfg.smoothing_one_sided + self.no_softmax = cfg.no_softmax + self.gumbel = cfg.gumbel + self.hard_gumbel = cfg.hard_gumbel + self.last_acc = None + + self.gradient_penalty = cfg.gradient_penalty + self.code_penalty = cfg.code_penalty + self.mmi_weight = cfg.mmi_weight + self.blank_weight = cfg.blank_weight + self.blank_mode = cfg.blank_mode + self.blank_index = target_dict.index("") if cfg.blank_is_sil else 0 + assert self.blank_index != target_dict.unk() + + self.discriminator = Discriminator(output_size, cfg) + for p in self.discriminator.parameters(): + p.param_group = "discriminator" + + self.pca_A = self.pca_b = None + d = cfg.input_dim + + self.segmenter = SEGMENT_FACTORY[cfg.segmentation.type](cfg.segmentation) + + self.generator = Generator(d, output_size, cfg) + + for p in self.generator.parameters(): + p.param_group = "generator" + + for p in self.segmenter.parameters(): + p.param_group = "generator" + + self.max_temp, self.min_temp, self.temp_decay = cfg.temp + self.curr_temp = self.max_temp + self.update_num = 0 + + if self.mmi_weight > 0: + self.target_downsample_rate = cfg.target_downsample_rate + self.decoder = nn.Linear(d, cfg.target_dim) + for p in self.decoder.parameters(): + p.param_group = "generator" + + @classmethod + def build_model(cls, cfg, task): + return cls(cfg, task.target_dictionary) + + def get_logits( + self, + net_output: Optional[Dict[str, List[Optional[torch.Tensor]]]], + normalize: bool = False, + ): + logits = net_output["logits"] + + if self.blank_weight != 0: + if self.blank_mode == "add": + logits[..., self.blank_index] += self.blank_weight + elif self.blank_mode == "set": + logits[..., self.blank_index] = self.blank_weight + else: + raise Exception(f"invalid blank mode {self.blank_mode}") + + padding = net_output["padding_mask"] + if padding.any(): + logits[padding] = float("-inf") + logits[padding][..., self.blank_index] = float("inf") + + if normalize: + logits = utils.log_softmax(logits.float(), dim=-1) + + return logits.transpose(0, 1) + + def get_normalized_probs( + self, + net_output: Tuple[ + torch.Tensor, Optional[Dict[str, List[Optional[torch.Tensor]]]] + ], + log_probs: bool, + sample: 
Optional[Dict[str, torch.Tensor]] = None, + ): + logits = self.get_logits(net_output) + + probs = super().get_normalized_probs(logits, log_probs, sample) + # BTC -> TBC for ctc + probs = probs.transpose(0, 1) + return probs + + def normalize(self, dense_x): + + bsz, tsz, csz = dense_x.shape + + if dense_x.numel() == 0: + raise Exception(dense_x.shape) + _, k = dense_x.max(-1) + hard_x = ( + dense_x.new_zeros(bsz * tsz, csz) + .scatter_(-1, k.view(-1, 1), 1.0) + .view(-1, csz) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + code_perplexity = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ) + + avg_probs = torch.softmax(dense_x.reshape(-1, csz).float(), dim=-1).mean(dim=0) + prob_perplexity = torch.exp( + -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1) + ) + + if not self.no_softmax: + if self.training and self.gumbel: + dense_x = F.gumbel_softmax( + dense_x.float(), tau=self.curr_temp, hard=self.hard_gumbel + ).type_as(dense_x) + else: + dense_x = dense_x.softmax(-1) + + return dense_x, code_perplexity, prob_perplexity + + def forward( + self, + features, + padding_mask, + random_label=None, + dense_x_only=False, + segment=True, + aux_target=None, + ): + if segment: + features, padding_mask = self.segmenter.pre_segment(features, padding_mask) + + orig_size = features.size(0) * features.size(1) - padding_mask.sum() + + gen_result = self.generator(features, random_label, padding_mask) + + orig_dense_x, token_x = gen_result["dense_x"], gen_result["token_x"] + orig_dense_padding_mask = gen_result["dense_padding_mask"] + + if segment: + dense_x, dense_padding_mask = self.segmenter.logit_segment( + orig_dense_x, orig_dense_padding_mask + ) + else: + dense_x = orig_dense_x + dense_padding_mask = orig_dense_padding_mask + + dense_logits = dense_x + prob_perplexity = None + code_perplexity = None + + if not (self.no_softmax and dense_x_only): + dense_x, code_perplexity, prob_perplexity = self.normalize(dense_logits) + + if dense_x_only or self.discriminator is None: + return { + "logits": dense_x, + "padding_mask": dense_padding_mask, + } + + token_padding_mask = random_label == self.pad + + dense_y = self.discriminator(dense_x, dense_padding_mask) + token_y = self.discriminator(token_x, token_padding_mask) + + sample_size = features.size(0) + + d_step = self.discrim_step(self.update_num) + + fake_smooth = self.smoothing + real_smooth = self.smoothing + if self.smoothing_one_sided: + fake_smooth = 0 + + zero_loss = None + smoothness_loss = None + code_pen = None + mmi_loss = None + + if d_step: + loss_dense = F.binary_cross_entropy_with_logits( + dense_y, + dense_y.new_ones(dense_y.shape) - fake_smooth, + reduction="sum", + ) + loss_token = F.binary_cross_entropy_with_logits( + token_y, + token_y.new_zeros(token_y.shape) + real_smooth, + reduction="sum", + ) + if self.training and self.gradient_penalty > 0: + grad_pen = self.calc_gradient_penalty(token_x, dense_x) + grad_pen = grad_pen.sum() * self.gradient_penalty + else: + grad_pen = None + else: + grad_pen = None + loss_token = None + loss_dense = F.binary_cross_entropy_with_logits( + dense_y, + dense_y.new_zeros(dense_y.shape) + fake_smooth, + reduction="sum", + ) + num_vars = dense_x.size(-1) + if prob_perplexity is not None: + code_pen = (num_vars - prob_perplexity) / num_vars + code_pen = code_pen * sample_size * self.code_penalty + + if self.smoothness_weight > 0: + smoothness_loss = F.mse_loss( + dense_logits[:, :-1], dense_logits[:, 1:], reduction="none" + ) + 
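+ # mask out padded positions so they do not contribute to the smoothness penalty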
smoothness_loss[dense_padding_mask[:, 1:]] = 0 + smoothness_loss = ( + smoothness_loss.mean() * sample_size * self.smoothness_weight + ) + + if (self.mmi_weight > 0) and (aux_target is not None): + inter_x = self.decoder(gen_result["inter_x"]) + if self.target_downsample_rate > 1: + aux_target = aux_target[:, :: self.target_downsample_rate] + max_t_len = min(aux_target.shape[1], inter_x.shape[1]) + mmi_loss = F.cross_entropy( + inter_x[:, :max_t_len].transpose(1, 2), + aux_target[:, :max_t_len], + ignore_index=-1, + reduction="none", + ) + mmi_loss = mmi_loss.mean() * mmi_loss.shape[0] * self.mmi_weight + + result = { + "losses": { + "grad_pen": grad_pen, + "code_pen": code_pen, + "smoothness": smoothness_loss, + "mmi": mmi_loss, + }, + "temp": self.curr_temp, + "code_ppl": code_perplexity, + "prob_ppl": prob_perplexity, + "d_steps": int(d_step), + "sample_size": sample_size, + } + + suff = "_d" if d_step else "_g" + result["losses"]["dense" + suff] = loss_dense + result["losses"]["token" + suff] = loss_token + + return result diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/apply_pca.py b/fairseq/examples/wav2vec/unsupervised/scripts/apply_pca.py new file mode 100644 index 0000000000000000000000000000000000000000..10ad6ce47cfdf0a87ba089b299fe9551b29fa167 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/apply_pca.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import os.path as osp +import math +import numpy as np +import tqdm +import torch +from shutil import copyfile + +from npy_append_array import NpyAppendArray + + +def get_parser(): + parser = argparse.ArgumentParser( + description="transforms features via a given pca and stored them in target dir" + ) + # fmt: off + parser.add_argument('source', help='directory with features') + parser.add_argument('--split', help='which split to read', required=True) + parser.add_argument('--save-dir', help='where to save the output', required=True) + parser.add_argument('--pca-path', type=str, help='pca location. 
will append _A.npy and _b.npy', required=True) + parser.add_argument('--batch-size', type=int, default=2048000, help='batch size') + parser.add_argument('--unfiltered', action='store_true', help='process the unfiltered version') + # fmt: on + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + source_path = osp.join(args.source, args.split) + data_poth = source_path + "_unfiltered" if args.unfiltered else source_path + + print(f"data path: {data_poth}") + + features = np.load(data_poth + ".npy", mmap_mode="r") + pca_A = torch.from_numpy(np.load(args.pca_path + "_A.npy")).cuda() + pca_b = torch.from_numpy(np.load(args.pca_path + "_b.npy")).cuda() + + os.makedirs(args.save_dir, exist_ok=True) + save_path = osp.join(args.save_dir, args.split) + + copyfile(source_path + ".tsv", save_path + ".tsv") + copyfile(data_poth + ".lengths", save_path + ".lengths") + + if osp.exists(source_path + ".phn"): + copyfile(source_path + ".phn", save_path + ".phn") + + if osp.exists(source_path + ".wrd"): + copyfile(source_path + ".wrd", save_path + ".wrd") + + if osp.exists(save_path + ".npy"): + os.remove(save_path + ".npy") + npaa = NpyAppendArray(save_path + ".npy") + + batches = math.ceil(features.shape[0] / args.batch_size) + + with torch.no_grad(): + for b in tqdm.trange(batches): + start = b * args.batch_size + end = start + args.batch_size + x = torch.from_numpy(features[start:end]).cuda() + x = torch.matmul(x, pca_A) + pca_b + npaa.append(x.cpu().numpy()) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/copy_labels.py b/fairseq/examples/wav2vec/unsupervised/scripts/copy_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..989868388eefccc37c82d7602f709632035c7aa1 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/copy_labels.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +for idx, line in enumerate(sys.stdin): + print(f"utt{idx:010d} {line}", end="") diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/filter_lexicon.py b/fairseq/examples/wav2vec/unsupervised/scripts/filter_lexicon.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf3e51e7a50ac3f07cc41739198cde946dc79aa --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/filter_lexicon.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
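+# Reads a lexicon (word followed by its phones) from stdin and prints only the
+# entries whose phones all occur in the given unit dictionary. Illustrative
+# usage (paths are examples): python filter_lexicon.py -d dict.phn.txt < lexicon.lst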
+ +import argparse +import sys + +from fairseq.data import Dictionary + + +def get_parser(): + parser = argparse.ArgumentParser( + description="filters a lexicon given a unit dictionary" + ) + parser.add_argument("-d", "--unit-dict", help="unit dictionary", required=True) + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + d = Dictionary.load(args.unit_dict) + symbols = set(d.symbols) + + for line in sys.stdin: + items = line.rstrip().split() + skip = len(items) < 2 + for x in items[1:]: + if x not in symbols: + skip = True + break + if not skip: + print(line, end="") + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/filter_tsv.py b/fairseq/examples/wav2vec/unsupervised/scripts/filter_tsv.py new file mode 100644 index 0000000000000000000000000000000000000000..a09d79acf31414ea3eae82db59cf9f105aefcdf1 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/filter_tsv.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import argparse +import sys + + +parser = argparse.ArgumentParser() +parser.add_argument("--tsv", required=True, type=str) +parser.add_argument("--no-skip", action="store_true") +parser.add_argument("--keep", action="store_true") +params = parser.parse_args() + + +def get_fname(line): + p = os.path.basename(line.split("\t")[0]) + p = os.path.splitext(p)[0] + return p + + +# filenames to exclude +seen = set() +with open(params.tsv) as f: + if not params.no_skip: + root = next(f).rstrip() + for line in f: + seen.add(get_fname(line)) + +for i, line in enumerate(sys.stdin): + exists = get_fname(line) in seen + keep = (exists and params.keep) or (not exists and not params.keep) + if i == 0 or keep: + print(line, end="") diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py b/fairseq/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py new file mode 100644 index 0000000000000000000000000000000000000000..2e31c307bd67d10941150160c7fb8c9e085ac5d9 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
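+# Converts whitespace-separated words read from stdin into phone sequences with
+# g2p_en, caching per-word results; --compact strips stress digits from the
+# phones (e.g. "AH0" becomes "AH").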
+ +import argparse +import sys + +from g2p_en import G2p + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--compact", + action="store_true", + help="if set, compacts phones", + ) + args = parser.parse_args() + + compact = args.compact + + wrd_to_phn = {} + g2p = G2p() + for line in sys.stdin: + words = line.strip().split() + phones = [] + for w in words: + if w not in wrd_to_phn: + wrd_to_phn[w] = g2p(w) + if compact: + wrd_to_phn[w] = [ + p[:-1] if p[-1].isnumeric() else p for p in wrd_to_phn[w] + ] + phones.extend(wrd_to_phn[w]) + try: + print(" ".join(phones)) + except: + print(wrd_to_phn, words, phones, file=sys.stderr) + raise + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/ltr_to_wrd.py b/fairseq/examples/wav2vec/unsupervised/scripts/ltr_to_wrd.py new file mode 100644 index 0000000000000000000000000000000000000000..36c85d1e2f60487494a92207feb4685e78db8aa2 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/ltr_to_wrd.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + + +def main(): + for line in sys.stdin: + print(line.replace(" ", "").replace("|", " ").strip()) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/mean_pool.py b/fairseq/examples/wav2vec/unsupervised/scripts/mean_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..4eea048ef3455cb3c897e74c18778c78fdc9fcbf --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/mean_pool.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
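+# Mean-pools frame-level features into roughly uniform chunks whose number is
+# controlled by --subsample-rate, writing new .npy/.lengths files (and copying
+# the .tsv/.phn/.wrd side files) into --save-dir.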
+ +import argparse +import os +import os.path as osp +import math +import numpy as np +import tqdm +import torch +import torch.nn.functional as F +from shutil import copyfile + +from npy_append_array import NpyAppendArray + + +def get_parser(): + parser = argparse.ArgumentParser( + description="mean pools representations by compressing uniform splits of the data" + ) + # fmt: off + parser.add_argument('source', help='directory with features') + parser.add_argument('--split', help='which split to read', required=True) + parser.add_argument('--save-dir', help='where to save the output', required=True) + parser.add_argument('--subsample-rate', type=float, default=0.5, help='size to subsample data to') + + parser.add_argument('--remove-extra', action='store_true', help='if true, removes extra states that cant be pooled, otherwise pads with 0s') + # fmt: on + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + source_path = osp.join(args.source, args.split) + + print(f"data path: {source_path}") + + features = np.load(source_path + ".npy", mmap_mode="r") + + os.makedirs(args.save_dir, exist_ok=True) + save_path = osp.join(args.save_dir, args.split) + + copyfile(source_path + ".tsv", save_path + ".tsv") + + if os.path.exists(source_path + ".phn"): + copyfile(source_path + ".phn", save_path + ".phn") + if os.path.exists(source_path + ".wrd"): + copyfile(source_path + ".wrd", save_path + ".wrd") + + if os.path.exists(osp.join(args.source, "dict.phn.txt")): + copyfile( + osp.join(args.source, "dict.phn.txt"), + osp.join(args.save_dir, "dict.phn.txt"), + ) + + if osp.exists(save_path + ".npy"): + os.remove(save_path + ".npy") + npaa = NpyAppendArray(save_path + ".npy") + + with open(source_path + ".lengths", "r") as lf: + lengths = lf.readlines() + + fsz = features.shape[-1] + start = 0 + with torch.no_grad(): + with open(save_path + ".lengths", "w") as lengths_out: + for length in tqdm.tqdm(lengths): + length = int(length) + end = start + length + feats = features[start:end] + start += length + x = torch.from_numpy(feats).cuda() + target_num = math.ceil(length * args.subsample_rate) + rem = length % target_num + + if rem > 0: + if args.remove_extra: + to_rem = target_num - rem + target_num -= 1 + x = x[:-to_rem] + else: + to_add = target_num - rem + x = F.pad(x, [0, 0, 0, to_add]) + x[-to_add:] = x[-to_add - 1] + + x = x.view(target_num, -1, fsz) + x = x.mean(dim=-2) + print(target_num, file=lengths_out) + npaa.append(x.cpu().numpy()) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/merge_clusters.py b/fairseq/examples/wav2vec/unsupervised/scripts/merge_clusters.py new file mode 100644 index 0000000000000000000000000000000000000000..2780f9d971d847b3ad0b59e9a33780553ebce902 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/merge_clusters.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
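+# Collapses consecutive frames that share the same cluster id into one vector
+# per run (the mean of the run, or a randomly sampled frame with
+# --pooling sample), producing segment-level features and updated .lengths.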
+ +import argparse +import os +import os.path as osp +import numpy as np +import tqdm +import torch +import random +from shutil import copyfile + +from npy_append_array import NpyAppendArray + + +def get_parser(): + parser = argparse.ArgumentParser( + description="transforms features via a given pca and stored them in target dir" + ) + # fmt: off + parser.add_argument('source', help='directory with features') + parser.add_argument('--split', help='which split to read', required=True) + parser.add_argument('--save-dir', help='where to save the output', required=True) + parser.add_argument('--cluster-dir', help='where the clusters are') + parser.add_argument('--pooling', type=str, default='mean', choices=['mean', 'sample'], help='how to pool') + # fmt: on + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + source_path = osp.join(args.source, args.split) + cluster_path = osp.join(args.cluster_dir, args.split + ".src") + print(f"data path: {source_path}") + + features = np.load(source_path + ".npy", mmap_mode="r") + sizes = [] + offsets = [] + offset = 0 + with open(source_path + ".lengths", "r") as len_f: + for line in len_f: + length = int(line.rstrip()) + sizes.append(length) + offsets.append(offset) + offset += length + + clusters = [] + with open(cluster_path, "r") as cf: + for line in cf: + line = line.rstrip() + items = line.split() + items = list(map(int, items)) + clusters.append(items) + + os.makedirs(args.save_dir, exist_ok=True) + save_path = osp.join(args.save_dir, args.split) + + copyfile(source_path + ".tsv", save_path + ".tsv") + + if os.path.exists(source_path + ".phn"): + copyfile(source_path + ".phn", save_path + ".phn") + if os.path.exists(osp.join(args.source, "dict.phn.txt")): + copyfile( + osp.join(args.source, "dict.phn.txt"), + osp.join(args.save_dir, "dict.phn.txt"), + ) + if os.path.exists(source_path + ".wrd"): + copyfile(source_path + ".wrd", save_path + ".wrd") + + if osp.exists(save_path + ".npy"): + os.remove(save_path + ".npy") + npaa = NpyAppendArray(save_path + ".npy") + + def merge(feats, clust): + feats = torch.from_numpy(feats.copy()) + clust = torch.LongTensor(clust) + _, counts = clust.unique_consecutive(return_counts=True) + curr = 0 + + merged = [] + for c in counts: + c = c.item() + start = curr + end = curr + c + curr += c + if args.pooling == "mean": + new_x = feats[start:end].mean(dim=0) + elif args.pooling == "sample": + new_x = feats[start + int(random.random() * c)] + else: + raise NotImplementedError() + merged.append(new_x) + + return torch.stack(merged, dim=0).numpy() + + with open(save_path + ".lengths", "w") as l_f: + for size, offset, clust in tqdm.tqdm( + zip(sizes, offsets, clusters), total=len(sizes) + ): + end = size + offset + feats = features[offset:end] + feats = merge(feats, clust) + print(len(feats), file=l_f) + npaa.append(feats) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py b/fairseq/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py new file mode 100644 index 0000000000000000000000000000000000000000..c2bd16efb530af5af3f72ab0edb3044b4e9fcd5c --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
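+# Normalizes text from stdin (keeping letters, digits, marks, apostrophes and
+# hyphens) and, when a fastText LID model is available, keeps a line only if
+# --lang is the top prediction or its probability is at least --lid-threshold.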
+ +import argparse +import fasttext as ft +import os +import regex +import sys + + +def get_parser(): + parser = argparse.ArgumentParser( + description="reads text from stdin and outputs normalized, lid-filtered version to stdout" + ) + parser.add_argument( + "--fasttext-model", + help="path to fasttext model", + default="lid.187.bin", + ) + parser.add_argument("--lang", help="language id", required=True) + parser.add_argument( + "--lid-threshold", + type=float, + help="threshold for this lang id probability", + default=0.4, + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + filter_r = regex.compile(r"[^\p{L}\p{N}\p{M}\' \-]") + + lg = args.lang.lower() + lg_label = f"__label__{lg}" + thresh = args.lid_threshold + + if os.path.exists(args.fasttext_model): + model = ft.load_model(args.fasttext_model) + else: + print( + f"fasttext language id model {args.fasttext_model} not found. Proceeding without language filtering. " + f"To enable language filtering, please download the latest language id model " + f"from https://fasttext.cc/docs/en/language-identification.html", + file=sys.stderr, + ) + model = None + + for line in sys.stdin: + line = line.strip() + line = filter_r.sub(" ", line) + line = " ".join(line.split()) + + if model is not None: + lid, prob = model.predict(line, k=100) + try: + target_idx = lid.index(lg_label) + except ValueError: + continue + if target_idx == 0 or prob[target_idx] >= thresh: + print(line) + else: + print(line) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/normalize_text.py b/fairseq/examples/wav2vec/unsupervised/scripts/normalize_text.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0ffeb27d038a6b82aaf0f6bdf208af565663f6 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/normalize_text.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import regex +import sys + + +def main(): + filter_r = regex.compile(r"[^\p{L}\p{N}\p{M}\' \-]") + + for line in sys.stdin: + line = line.strip() + line = filter_r.sub(" ", line) + line = " ".join(line.split()) + print(line) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/pca.py b/fairseq/examples/wav2vec/unsupervised/scripts/pca.py new file mode 100644 index 0000000000000000000000000000000000000000..948cf5319fd86ba1bccff65270b2881048faf9b1 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/pca.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
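+# Fits a faiss PCAMatrix on the feature array and saves the projection as
+# <dim>[_<eigen-power>]_pca_A.npy (transposed) and ..._pca_b.npy, later applied
+# as x @ A + b (see apply_pca.py above); --eigen-power -0.5 gives whitening.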
+ +import argparse +import os +import os.path as osp +import numpy as np + +import faiss + + + +def get_parser(): + parser = argparse.ArgumentParser( + description="compute a pca matrix given an array of numpy features" + ) + # fmt: off + parser.add_argument('data', help='numpy file containing features') + parser.add_argument('--output', help='where to save the pca matrix', required=True) + parser.add_argument('--dim', type=int, help='dim for pca reduction', required=True) + parser.add_argument('--eigen-power', type=float, default=0, help='eigen power, -0.5 for whitening') + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + print("Reading features") + x = np.load(args.data, mmap_mode="r") + + print("Computing PCA") + pca = faiss.PCAMatrix(x.shape[-1], args.dim, args.eigen_power) + pca.train(x) + b = faiss.vector_to_array(pca.b) + A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in) + + os.makedirs(args.output, exist_ok=True) + + prefix = str(args.dim) + if args.eigen_power != 0: + prefix += f"_{args.eigen_power}" + + np.save(osp.join(args.output, f"{prefix}_pca_A"), A.T) + np.save(osp.join(args.output, f"{prefix}_pca_b"), b) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py b/fairseq/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py new file mode 100644 index 0000000000000000000000000000000000000000..c6512d7322def67b27aba46e9e36da171db6963b --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
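# phonemize_with_sil.py (body below): reads word transcripts from stdin and maps each word to its
# phone sequence via --lexicon (one "word phone1 phone2 ..." entry per line). With --sil-prob p,
# a silence token is inserted between adjacent words with probability p, and --surround also pads
# the start and end of every utterance; lines containing out-of-lexicon words are skipped.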
+ +import argparse +import numpy as np +import sys + + +def get_parser(): + parser = argparse.ArgumentParser( + description="converts words to phones adding optional silences around in between words" + ) + parser.add_argument( + "--sil-prob", + "-s", + type=float, + default=0, + help="probability of inserting silence between each word", + ) + parser.add_argument( + "--surround", + action="store_true", + help="if set, surrounds each example with silence", + ) + parser.add_argument( + "--lexicon", + help="lexicon to convert to phones", + required=True, + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + sil_prob = args.sil_prob + surround = args.surround + sil = "" + + wrd_to_phn = {} + + with open(args.lexicon, "r") as lf: + for line in lf: + items = line.rstrip().split() + assert len(items) > 1, line + assert items[0] not in wrd_to_phn, items + wrd_to_phn[items[0]] = items[1:] + + for line in sys.stdin: + words = line.strip().split() + + if not all(w in wrd_to_phn for w in words): + continue + + phones = [] + if surround: + phones.append(sil) + + sample_sil_probs = None + if sil_prob > 0 and len(words) > 1: + sample_sil_probs = np.random.random(len(words) - 1) + + for i, w in enumerate(words): + phones.extend(wrd_to_phn[w]) + if ( + sample_sil_probs is not None + and i < len(sample_sil_probs) + and sample_sil_probs[i] < sil_prob + ): + phones.append(sil) + + if surround: + phones.append(sil) + print(" ".join(phones)) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio.sh b/fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio.sh new file mode 100644 index 0000000000000000000000000000000000000000..013f7a9b055a7693a29f9c5ba1e4003a9a25850e --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio.sh @@ -0,0 +1,78 @@ +#!/usr/bin/env zsh +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
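# prepare_audio.sh (body below) prepares wav2vec-U audio features end to end:
#   zsh prepare_audio.sh <tsv_dir> <output_dir> <wav2vec2_checkpoint> [pca_dim=512] [layer=14]
# It extracts wav2vec 2.0 features, clusters them with faiss (CLUS128), applies PCA, merges
# consecutive frames that share a cluster label (mean pooling), and mean-pools the merged segments.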
+ +source_dir=$1 +tgt_dir=$2 +model=$3 + +if [ -z "$4" ] + then + dim=512 + else + dim=$4 +fi + +echo "using $dim dim for PCA" + +if [ -z "$5" ] + then + layer=14 + else + layer=$5 +fi + +echo "extracting from layer $layer" + +train_split=train +valid_split=valid +test_split=test + +all_splits=($train_split) + +if [[ -f "$source_dir/valid.tsv" ]]; then + all_splits+=('valid') +fi + +if [[ -f "$source_dir/test.tsv" ]]; then + all_splits+=('test') +fi + +echo "processing splits: $all_splits" + +mkdir -p $tgt_dir + +cp $source_dir/*.tsv $tgt_dir +cp $source_dir/*.wrd $tgt_dir +cp $source_dir/*.ltr $tgt_dir +cp $source_dir/*.phn $tgt_dir +cp $source_dir/dict* $tgt_dir + +setopt shwordsplit + +for split in $all_splits; do + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py $source_dir --split $split \ + --save-dir $tgt_dir --checkpoint $model --layer $layer +done + +python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py $tgt_dir/${train_split}.tsv \ +--checkpoint $model --save-dir $tgt_dir -f "CLUS128" --sample-pct 1.0 + +for split in $all_splits; do + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py $tgt_dir \ + --checkpoint $model --path $tgt_dir/CLUS128 --split $split +done + +python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/pca.py $tgt_dir/${train_split}.npy --output $tgt_dir/pca --dim $dim + +for split in $all_splits; do + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/apply_pca.py $tgt_dir --split $split --save-dir $tgt_dir/precompute_pca$dim --pca-path $tgt_dir/pca/${dim}_pca --batch-size 1048000 + + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/merge_clusters.py $tgt_dir/precompute_pca$dim --cluster-dir $tgt_dir/CLUS128 \ + --split $split --save-dir $tgt_dir/precompute_pca${dim}_cls128_mean --pooling mean + + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/mean_pool.py $tgt_dir/precompute_pca${dim}_cls128_mean \ + --save-dir $tgt_dir/precompute_pca${dim}_cls128_mean_pooled --split $split +done diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio_v2.sh b/fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio_v2.sh new file mode 100644 index 0000000000000000000000000000000000000000..96a52c5c83845a539584c1ee2298b1f438b8ce70 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio_v2.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env zsh +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
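# prepare_audio_v2.sh (body below) mirrors the feature extraction of prepare_audio.sh but, instead
# of PCA and cluster pooling, derives MFCC k-means labels (64 clusters by default) as an auxiliary
# target by reusing the HuBERT simple_kmeans scripts:
#   zsh prepare_audio_v2.sh <tsv_dir> <output_dir> <wav2vec2_checkpoint> [n_clusters=64] [layer=14]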
+ +source_dir=$1 +tgt_dir=$2 +model=$3 + +if [ -z "$4" ] + then + dim=64 + else + dim=$4 +fi + +echo "using $dim clusters for auxilary target" + +if [ -z "$5" ] + then + layer=14 + else + layer=$5 +fi + +echo "extracting from layer $layer" + +train_split=train +valid_split=valid +test_split=test + +all_splits=($train_split) + +if [[ -f "$source_dir/valid.tsv" ]]; then + all_splits+=('valid') +fi + +if [[ -f "$source_dir/test.tsv" ]]; then + all_splits+=('test') +fi + +echo "processing splits: $all_splits" + +mkdir -p $tgt_dir + +cp $source_dir/*.tsv $tgt_dir +cp $source_dir/*.wrd $tgt_dir +cp $source_dir/*.ltr $tgt_dir +cp $source_dir/*.phn $tgt_dir +cp $source_dir/dict* $tgt_dir + +setopt shwordsplit + +for split in $all_splits; do + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py $source_dir --split $split \ + --save-dir $tgt_dir --checkpoint $model --layer $layer +done + + +mkdir -p $tgt_dir/mfcc + +# Consider spliting corpus into chuncks for large corpus, see HuBERT preprocessing for more details +python $FAIRSEQ_ROOT/examples/hubert/simple_kmeans/dump_mfcc_feature.py \ + $tgt_dir $train_split 1 0 $tgt_dir/mfcc +python $FAIRSEQ_ROOT/examples/hubert/simple_kmeans/dump_km_label.py \ + $tgt_dir/mfcc $train_split $tgt_dir/mfcc/cls$dim 1 0 $tgt_dir/mfcc/cls${dim}_idx +cp $tgt_dir/mfcc/cls${dim}_idx/${train_split}_0_1.km $tgt_dir/$train_split.km diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/prepare_text.sh b/fairseq/examples/wav2vec/unsupervised/scripts/prepare_text.sh new file mode 100644 index 0000000000000000000000000000000000000000..dbd17a2472f4f9903c96c4a3506d5b78d30afc55 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/prepare_text.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env zsh +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +lg=$1 +text_path=$2 +target_dir=$3 +min_phones=$4 +phonemizer=$5 +lid_path=$6 +sil_prob=$7 + +if [ -z "$lid_path" ]; then + lid_path="lid.187.bin" +fi + +ph_lg=${lg:l} +if test "$lg" = 'fr'; then + ph_lg='fr-fr' +elif test "$lg" = 'en'; then + ph_lg='en-us' +elif test "$lg" = 'pt'; then + ph_lg='pt-br' +fi + +ESPEAK_PATH='' +if test "$phonemizer" = 'espeak'; then + ESPEAK_PATH=$(which espeak) +elif test "$phonemizer" = 'espeak-ng'; then + ESPEAK_PATH=$(which espeak-ng) +elif test "$phonemizer" = 'G2P'; then + ESPEAK_PATH='' +else + echo "Unknown phonemizer $phonemizer. Valid options are espeak, espean-ng and G2P" + exit 1 +fi + +echo $lg +echo $ph_lg +echo $text_path +echo $target_dir +echo "min phone seen threshold is $min_phones" + +mkdir -p $target_dir +python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py --lang $lg --fasttext-model $lid_path < $text_path | grep -v '\-\-\-' >! $target_dir/lm.upper.lid.txt +python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/lm.upper.lid.txt --only-source --destdir $target_dir --thresholdsrc 2 --padding-factor 1 --dict-only +cut -f1 -d' ' $target_dir/dict.txt | grep -v -x '[[:punct:]]*' | grep -Pv '\d\d\d\d\d+' >! 
$target_dir/words.txt + + +if [ -z "$ESPEAK_PATH" ]; then + python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py --compact < $target_dir/words.txt > $target_dir/phones.txt +else + # echoing 1 into corpus will prevent the mismatch lines between lexicon and phones in case the phonemizer fails + one=$(echo "1" | PHONEMIZER_ESPEAK_PATH=$ESPEAK_PATH phonemize -p ' ' -w '' -l $ph_lg --language-switch remove-flags) + sed 's/$/ 1/' $target_dir/words.txt | PHONEMIZER_ESPEAK_PATH=$ESPEAK_PATH phonemize -o $target_dir/phones.txt -p ' ' -w '' -l $ph_lg -j 70 --language-switch remove-flags + echo "one is ${one}" + sed -i "s/${one}$//" $target_dir/phones.txt +fi + +paste $target_dir/words.txt $target_dir/phones.txt >! $target_dir/lexicon.lst + +python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/phones.txt --only-source --destdir $target_dir/phones --thresholdsrc $min_phones --padding-factor 1 --dict-only + +python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/filter_lexicon.py -d $target_dir/phones/dict.txt < $target_dir/lexicon.lst >! $target_dir/lexicon_filtered.lst +python $FAIRSEQ_ROOT/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py -s $sil_prob --surround --lexicon $target_dir/lexicon_filtered.lst < $target_dir/lm.upper.lid.txt >! $target_dir/phones/lm.phones.filtered.txt +cp $target_dir/phones/dict.txt $target_dir/phones/dict.phn.txt +echo " 0" >> $target_dir/phones/dict.phn.txt +python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $target_dir/phones/lm.phones.filtered.txt --workers 70 --only-source --destdir $target_dir/phones --srcdict $target_dir/phones/dict.phn.txt + +$KENLM_ROOT/lmplz -o 4 < $target_dir/lm.upper.lid.txt --discount_fallback --prune 0 0 0 3 >! $target_dir/kenlm.wrd.o40003.arpa +$KENLM_ROOT/build_binary $target_dir/kenlm.wrd.o40003.arpa $target_dir/kenlm.wrd.o40003.bin + +lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_words_sil lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones in_labels=phn "blank_symbol=''" +lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_words lm_arpa=$target_dir/kenlm.wrd.o40003.arpa wav2letter_lexicon=$target_dir/lexicon_filtered.lst data_dir=$target_dir/phones in_labels=phn + +$KENLM_ROOT/lmplz -o 4 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >! $target_dir/phones/lm.phones.filtered.04.arpa +$KENLM_ROOT/build_binary $target_dir/phones/lm.phones.filtered.04.arpa $target_dir/phones/lm.phones.filtered.04.bin +$KENLM_ROOT/lmplz -o 6 < $target_dir/phones/lm.phones.filtered.txt --discount_fallback >! 
$target_dir/phones/lm.phones.filtered.06.arpa +$KENLM_ROOT/build_binary $target_dir/phones/lm.phones.filtered.06.arpa $target_dir/phones/lm.phones.filtered.06.bin + +lg=$lg python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$target_dir/fst/phn_to_phn_sil lm_arpa=$target_dir/phones/lm.phones.filtered.06.arpa data_dir=$target_dir/phones in_labels=phn "blank_symbol=''" diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/prepare_timit.sh b/fairseq/examples/wav2vec/unsupervised/scripts/prepare_timit.sh new file mode 100644 index 0000000000000000000000000000000000000000..d8f5d596b4b4ec55f11a82dbbf83bad4a22c0b6c --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/prepare_timit.sh @@ -0,0 +1,79 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +timit_root=$1 # assume it is the upper-cased version +tgt_dir=$2 +model=$3 + +set -eu + +setups="matched unmatched" +splits="test valid train train_text" + +tgt_dir=$(realpath $tgt_dir) +sph2wav=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe +wav_dir=$tgt_dir/wav + + +mkdir -p $tgt_dir $wav_dir +find $timit_root/{TRAIN,TEST} -iname "*.WAV" > $tgt_dir/all_sph.flist +cat $tgt_dir/all_sph.flist | sed -e 's#//*#/#g' -e 's#.*/\([^/]*\)/\([^/]*\).WAV#\1_\2#g' > $tgt_dir/all.uid +paste -d' ' $tgt_dir/{all_sph.flist,all.uid} | \ + awk -v sph2wav=$sph2wav -v wav_dir=$wav_dir '{print sph2wav " -f wav " $1 " > " wav_dir "/" $2 ".wav"}' \ + > $tgt_dir/sph2wav.sh +bash $tgt_dir/sph2wav.sh +cat $tgt_dir/all.uid | awk -v wav_dir=$(pwd)/$wav_dir '{print $1" "wav_dir"/"$1".wav"}' | sort > $tgt_dir/all_wav.scp +cut -d' ' -f2 $tgt_dir/all_wav.scp | xargs -I{} soxi -s {} > $tgt_dir/all.dur +paste -d' ' $tgt_dir/{all_wav.scp,all.dur} > $tgt_dir/all_wav_dur.scp +rm $tgt_dir/{all.uid,all_sph.flist,sph2wav.sh} + +find $timit_root/{TRAIN,TEST} -iname "*.PHN" > $tgt_dir/all_phn60.flist +while read line; do + if [ ! 
-f $line ]; then + >&2 echo "Cannot find transcription file '$line'" && exit 1; + fi + cut -f3 -d' ' "$line" | tr '\n' ' ' | perl -ape 's: *$:\n:;' +done < $tgt_dir/all_phn60.flist > $tgt_dir/all.phn60 +cat $tgt_dir/all_phn60.flist | sed -e 's#//*#/#g' -e 's#.*/\([^/]*\)/\([^/]*\).PHN#\1_\2#g' | \ + paste -d' ' - $tgt_dir/all.phn60 | \ + $KALDI_ROOT/egs/timit/s5/local/timit_norm_trans.pl -i - -m $KALDI_ROOT/egs/timit/s5/conf/phones.60-48-39.map -to 39 | \ + sort > $tgt_dir/all.phn +echo "done preparing wav and 39-phone transcripts" + + +for s in $setups; do + mkdir -p $tgt_dir/$s + for x in $splits; do + uid_path=config/timit_${s}/${x}.uid + grep -w -f $uid_path $tgt_dir/all.phn | cut -d' ' -f2- > $tgt_dir/$s/$x.phn + ln -sf $(realpath $tgt_dir/$s/$x.phn) $tgt_dir/$s/$x.wrd + + echo "/" > $tgt_dir/$s/$x.tsv && grep -w -f $uid_path $tgt_dir/all_wav_dur.scp | cut -d' ' -f2- | sed 's# #\t#' >> $tgt_dir/$s/$x.tsv + done + + for x in $splits; do + cat $tgt_dir/$s/$x.phn + done | tr ' ' '\n' | sort -u | awk '{print $1" "1}' > $tgt_dir/$s/dict.phn.txt + ln -sf $(realpath $tgt_dir/$s/dict.phn.txt) $tgt_dir/$s/dict.wrd.txt +done +echo "done preparing unmatched and matched setups for TIMIT" + + +for s in $setups; do + zsh scripts/prepare_audio.sh $tgt_dir/$s $tgt_dir/$s/feat $model + + lm_dir=$tgt_dir/$s/phones + fst_dir=$tgt_dir/$s/fst/phn_to_phn + + python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $tgt_dir/$s/train_text.phn --workers 10 --only-source --destdir $lm_dir --srcdict $tgt_dir/$s/dict.phn.txt + $KENLM_ROOT/lmplz -o 3 < $tgt_dir/$s/train_text.phn --discount_fallback >$lm_dir/train_text_phn.03.arpa + $KENLM_ROOT/build_binary $lm_dir/train_text_phn.03.arpa $lm_dir/train_text_phn.03.bin + $KENLM_ROOT/lmplz -o 4 < $tgt_dir/$s/train_text.phn --discount_fallback >$lm_dir/train_text_phn.04.arpa + $KENLM_ROOT/build_binary $lm_dir/train_text_phn.04.arpa $lm_dir/train_text_phn.04.bin + + python $FAIRSEQ_ROOT/examples/speech_recognition/kaldi/kaldi_initializer.py kaldi_root=$KALDI_ROOT fst_dir=$fst_dir lm_arpa=$lm_dir/train_text_phn.03.arpa data_dir=$tgt_dir/$s in_labels=phn +done +echo "done preprocessing audio and text for wav2vec-U" diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/remove_silence.py b/fairseq/examples/wav2vec/unsupervised/scripts/remove_silence.py new file mode 100644 index 0000000000000000000000000000000000000000..fac88b989703262a84b242b2761df621bf02c739 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/remove_silence.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" +get intervals from .vads file, specify output data, and this script removes silences and saves the audio data in out path folder +paths=shards/train.tsv +vads=shards/train.vads +python remove_silence.py --paths $paths --vads $vads +""" + +import os +import argparse +import torch +import torchaudio +import tqdm + + +parser = argparse.ArgumentParser() +parser.add_argument("--tsv", default="", type=str) +parser.add_argument("--vads", default="", type=str) +parser.add_argument("--out", type=str) +params = parser.parse_args() + +# load paths +paths = [] +with open(params.tsv) as f: + root = next(f).rstrip() + for line in f: + paths.append(os.path.join(root, line.rstrip().split("\t")[0])) + +# load vads +list_intervals = [] +with open(params.vads) as f: + for line in f: + interval = [ + [int(w.split(":")[0]), int(w.split(":")[1])] for w in line.rstrip().split() + ] + list_intervals.append(interval) + + +# load audio and keep only intervals (i.e. remove silences) +for i in tqdm.trange(len(paths)): + data, _ = torchaudio.load(paths[i]) + if len(list_intervals[i]) > 0: + data_filtered = torch.cat( + [data[0][int(it[0]) : int(it[1])] for it in list_intervals[i]] + ).unsqueeze(0) + else: + data_filtered = data + + # YOU MAY NEED TO MODIFY THIS TO GET THE RIGHT SUBPATH + # outpath = params.out + '/'.join(paths[i].split('/')[-1]) + outpath = params.out + "/" + "/".join(paths[i].split("/")[-2:]) + + if not os.path.isdir("/".join(outpath.split("/")[:-1])): + os.makedirs("/".join(outpath.split("/")[:-1])) + if not os.path.exists(outpath): + torchaudio.save(outpath, data_filtered, sample_rate=16000) + else: + print(outpath, "exists!") diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/vads.py b/fairseq/examples/wav2vec/unsupervised/scripts/vads.py new file mode 100644 index 0000000000000000000000000000000000000000..2398da97d8c44b8f3f270b22d5508a003482b4d6 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/vads.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import sys + +from copy import deepcopy +from scipy.signal import lfilter + +import numpy as np +from tqdm import tqdm +import soundfile as sf +import os.path as osp + + +def get_parser(): + parser = argparse.ArgumentParser(description="compute vad segments") + parser.add_argument( + "--rvad-home", + "-r", + help="path to rvad home (see https://github.com/zhenghuatan/rVADfast)", + required=True, + ) + + return parser + + +def rvad(speechproc, path): + winlen, ovrlen, pre_coef, nfilter, nftt = 0.025, 0.01, 0.97, 20, 512 + ftThres = 0.5 + vadThres = 0.4 + opts = 1 + + data, fs = sf.read(path) + assert fs == 16_000, "sample rate must be 16khz" + ft, flen, fsh10, nfr10 = speechproc.sflux(data, fs, winlen, ovrlen, nftt) + + # --spectral flatness -- + pv01 = np.zeros(ft.shape[0]) + pv01[np.less_equal(ft, ftThres)] = 1 + pitch = deepcopy(ft) + + pvblk = speechproc.pitchblockdetect(pv01, pitch, nfr10, opts) + + # --filtering-- + ENERGYFLOOR = np.exp(-50) + b = np.array([0.9770, -0.9770]) + a = np.array([1.0000, -0.9540]) + fdata = lfilter(b, a, data, axis=0) + + # --pass 1-- + noise_samp, noise_seg, n_noise_samp = speechproc.snre_highenergy( + fdata, nfr10, flen, fsh10, ENERGYFLOOR, pv01, pvblk + ) + + # sets noisy segments to zero + for j in range(n_noise_samp): + fdata[range(int(noise_samp[j, 0]), int(noise_samp[j, 1]) + 1)] = 0 + + vad_seg = speechproc.snre_vad( + fdata, nfr10, flen, fsh10, ENERGYFLOOR, pv01, pvblk, vadThres + ) + return vad_seg, data + + +def main(): + parser = get_parser() + args = parser.parse_args() + + sys.path.append(args.rvad_home) + import speechproc + + stride = 160 + lines = sys.stdin.readlines() + root = lines[0].rstrip() + for fpath in tqdm(lines[1:]): + path = osp.join(root, fpath.split()[0]) + vads, wav = rvad(speechproc, path) + + start = None + vad_segs = [] + for i, v in enumerate(vads): + if start is None and v == 1: + start = i * stride + elif start is not None and v == 0: + vad_segs.append((start, i * stride)) + start = None + if start is not None: + vad_segs.append((start, len(wav))) + + print(" ".join(f"{v[0]}:{v[1]}" for v in vad_segs)) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py b/fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py new file mode 100644 index 0000000000000000000000000000000000000000..a5dd7ae6c15b358206e067385be260c94021bf20 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
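# wav2vec_apply_cluster_faiss.py (body below): loads the centroids (and optional PCA) written by
# wav2vec_cluster_faiss.py from --path, re-extracts wav2vec features for every utterance of --split,
# assigns each frame to its nearest centroid with a GPU faiss index, and writes the cluster-id
# sequences to <path>/<split>.src together with a matching .tsv (and label file, if labels exist).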
+ +import argparse +import os +import os.path as osp +import numpy as np +import tqdm +import torch +import sys + +import faiss +import torch.nn.functional as F + +from wav2vec_cluster_faiss import parse_faiss_specs, Wav2VecFeatureReader + + +def get_parser(): + parser = argparse.ArgumentParser(description="apply clusters") + # fmt: off + parser.add_argument('data', help='location of tsv files') + parser.add_argument('--split', help='split to process', required=True) + parser.add_argument('--labels', help='split to process', default="phn") + parser.add_argument('--path', help='path to pca and centroids', required=True) + parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec model (if using wav2vec features)', required=True) + parser.add_argument('--layer', '-l', type=int, help='which layer to read', default=14) + parser.add_argument('--max-tsz', type=int, help='batch kmeans up to this much', default=14) + # fmt: on + + return parser + + +def get_iterator(args): + label_path = osp.join(args.data, f"{args.split}.{args.labels}") + if osp.exists(label_path): + lp = open(label_path, "r") + else: + lp = None + + with open(osp.join(args.data, f"{args.split}.tsv"), "r") as fp: + lines = fp.read().split("\n") + root = lines.pop(0).strip() + files = [line.rstrip() for line in lines if len(line) > 0] + + if lp is not None: + lbls = [line.rstrip() for line in lp] + else: + lbls = [None] * len(files) + + num = len(files) + reader = Wav2VecFeatureReader(args.checkpoint, args.layer) + + def iterate(): + for fname, lbl in zip(files, lbls): + file = osp.join(root, fname.split("\t")[0]) + feats = reader.get_feats(file) + yield feats.data, fname, lbl + + return iterate, num, root + + +def main(): + parser = get_parser() + args = parser.parse_args() + + spec = osp.basename(args.path) + + try: + faiss_spec = parse_faiss_specs(spec.rstrip("/"))[0] + except: + print(spec) + raise + + print("Faiss Spec:", faiss_spec, file=sys.stderr) + + if faiss_spec.pca: + A = torch.from_numpy(np.load(osp.join(args.path, "pca_A.npy"))).cuda() + b = torch.from_numpy(np.load(osp.join(args.path, "pca_b.npy"))).cuda() + print("Loaded PCA", file=sys.stderr) + + centroids = np.load(osp.join(args.path, "centroids.npy")) + print("Loaded centroids", centroids.shape, file=sys.stderr) + + res = faiss.StandardGpuResources() + index_flat = ( + faiss.IndexFlatL2(centroids.shape[1]) + if not faiss_spec.sphere + else faiss.IndexFlatIP(centroids.shape[1]) + ) + faiss_index = faiss.index_cpu_to_gpu(res, 0, index_flat) + faiss_index.add(centroids) + + generator, num, root = get_iterator(args) + iterator = generator() + + had_labels = False + label_path = osp.join(args.path, f"{args.split}.{args.labels}") + + with torch.no_grad(): + with open(osp.join(args.path, f"{args.split}.src"), "w") as fp, open( + osp.join(args.path, f"{args.split}.tsv"), "w" + ) as pp, open(label_path, "w") as lp: + print(root, file=pp) + for f, fname, lbl in tqdm.tqdm(iterator, total=num): + if faiss_spec.pca: + f = torch.mm(f, A) + b + if faiss_spec.norm: + f = F.normalize(f, p=2, dim=-1) + + f = f.cpu().numpy() + + _, z = faiss_index.search(f, 1) + + print(" ".join(str(x.item()) for x in z), file=fp) + print(fname, file=pp) + + if lbl is not None: + print(lbl, file=lp) + had_labels = True + if not had_labels: + os.remove(label_path) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py b/fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py new file mode 100644 
index 0000000000000000000000000000000000000000..632a69e9f4bd98d33abb689c15557c818d0e35ea --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import gc +import os +import os.path as osp +import random +import numpy as np +import tqdm +import torch + +from collections import namedtuple + +import faiss + +import fairseq +import soundfile as sf + + +def get_parser(): + parser = argparse.ArgumentParser( + description="compute kmeans codebook from kaldi-computed feats" + ) + # fmt: off + parser.add_argument('data', help='location of tsv files') + parser.add_argument('--save-dir', help='where to save the output', required=True) + parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec model (if using wav2vec features)', required=True) + parser.add_argument('--sample-pct', '-r', type=float, help='percentage of timesteps to sample', default=0) + parser.add_argument('--layer', '-l', type=int, help='which layer to read', default=14) + parser.add_argument('--faiss-specs', '-f', type=str, + help='faiss index specs; separated by space ' + 'format is: PCAx_NORM_CLUSx_SPHERICAL -> ' + 'PCAx if exists first apply PCA ' + 'NORM if exists, normalize the vector by L2 norm ' + 'CLUSx must exist, cluster to x clusters ' + 'SPEHRICAL if exists, apply spherical kmeans', + default='l2') + # fmt: on + + return parser + + +faiss_spec = namedtuple("faiss_spec", ["pca", "norm", "n_clus", "sphere", "spec_str"]) + + +def parse_faiss_specs(specs_str): + specs = [] + for ss in specs_str.split(): + comps = ss.split("_") + pca = 0 + norm = False + n_clus = 0 + sphere = False + for c in comps: + if c.startswith("PCA"): + pca = int(c[3:]) + elif c == "NORM": + norm = True + elif c.startswith("CLUS"): + n_clus = int(c[4:]) + elif c == "SPHERICAL": + sphere = True + assert n_clus > 0 + specs.append( + faiss_spec(pca=pca, norm=norm, n_clus=n_clus, sphere=sphere, spec_str=ss) + ) + return specs + + +class Wav2VecFeatureReader(object): + def __init__(self, cp_file, layer): + state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(cp_file) + + self.layer = layer + + if "cfg" in state: + w2v_args = state["cfg"] + task = fairseq.tasks.setup_task(w2v_args.task) + model = task.build_model(w2v_args.model) + else: + w2v_args = state["args"] + task = fairseq.tasks.setup_task(w2v_args) + model = task.build_model(w2v_args) + model.load_state_dict(state["model"], strict=True) + model.eval() + model.cuda() + self.model = model + + def read_audio(self, fname): + """Load an audio file and return PCM along with the sample rate""" + wav, sr = sf.read(fname) + assert sr == 16e3 + + return wav + + def get_feats(self, loc): + x = self.read_audio(loc) + with torch.no_grad(): + source = torch.from_numpy(x).view(1, -1).float().cuda() + res = self.model( + source=source, mask=False, features_only=True, layer=self.layer + ) + return res["layer_results"][self.layer][0].squeeze(1) + + +def get_iterator(args): + with open(args.data, "r") as fp: + lines = fp.read().split("\n") + root = lines.pop(0).strip() + files = [osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0] + + if getattr(args, "sample_pct", 0) > 0: + files = random.sample(files, int(args.sample_pct * len(files))) + num = len(files) + reader = 
Wav2VecFeatureReader(args.checkpoint, args.layer) + + def iterate(): + for fname in files: + feats = reader.get_feats(fname) + yield feats.cpu().numpy() + + return iterate, num + + +def main(): + parser = get_parser() + args = parser.parse_args() + + faiss_specs = parse_faiss_specs(args.faiss_specs) + print("Faiss Specs:", faiss_specs) + + feat_path = osp.join(args.save_dir, "features") + if osp.exists(feat_path + ".npy"): + feats = np.load(feat_path + ".npy") + else: + generator, num = get_iterator(args) + iterator = generator() + + feats = [] + for f in tqdm.tqdm(iterator, total=num): + feats.append(f) + + del iterator + del generator + + feats = np.concatenate(feats) + + print(feats.shape) + + os.makedirs(args.save_dir, exist_ok=True) + # np.save(feat_path, feats) + + gc.collect() + torch.cuda.empty_cache() + + reload = False + for spec in faiss_specs: + print("Processing spec", spec) + + if reload: + print("Reloading...") + del feats + gc.collect() + feats = np.load(feat_path + ".npy") + + save_path = osp.join(args.save_dir, spec.spec_str) + os.makedirs(save_path, exist_ok=True) + d = feats.shape[-1] + x = feats + if spec.pca > 0: + print("Computing PCA") + pca = faiss.PCAMatrix(d, spec.pca) + pca.train(x) + d = spec.pca + b = faiss.vector_to_array(pca.b) + A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in) + np.save(osp.join(save_path, "pca_A"), A.T) + np.save(osp.join(save_path, "pca_b"), b) + print("Applying PCA") + x = pca.apply_py(x) + + if spec.norm: + reload = spec.pca <= 0 + print("Normalizing") + faiss.normalize_L2(x) + + print("Computing kmeans") + kmeans = faiss.Kmeans( + d, + spec.n_clus, + niter=50, + verbose=True, + spherical=spec.sphere, + max_points_per_centroid=feats.shape[0], + gpu=True, + nredo=3, + ) + kmeans.train(x) + np.save(osp.join(save_path, "centroids"), kmeans.centroids) + del kmeans + del x + gc.collect() + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py b/fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py new file mode 100644 index 0000000000000000000000000000000000000000..b07e274d202414ce40d00aa64a27cf97bb49c1c3 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
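# wav2vec_extract_features.py (body below): dumps representations from layer --layer of a
# wav2vec 2.0 checkpoint for every utterance listed in <data>/<split>.tsv, appending them to
# <save-dir>/<split>.npy via NpyAppendArray and recording per-utterance frame counts in
# <split>.lengths.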
+ +import argparse +import os +import os.path as osp +import tqdm +import torch +import torch.nn.functional as F +from shutil import copyfile + +from npy_append_array import NpyAppendArray + +import fairseq +import soundfile as sf + + +def get_parser(): + parser = argparse.ArgumentParser( + description="compute kmeans codebook from kaldi-computed feats" + ) + # fmt: off + parser.add_argument('data', help='location of tsv files') + parser.add_argument('--split', help='which split to read', required=True) + parser.add_argument('--save-dir', help='where to save the output', required=True) + parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec ctc model', required=True) + parser.add_argument('--layer', type=int, default=14, help='which layer to use') + # fmt: on + + return parser + + +class Wav2VecFeatureReader(object): + def __init__(self, cp_file, layer): + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [cp_file] + ) + model = model[0] + model.eval() + model.cuda() + self.model = model + self.task = task + self.layer = layer + + def read_audio(self, fname): + """Load an audio file and return PCM along with the sample rate""" + wav, sr = sf.read(fname) + assert sr == 16e3 + + return wav + + def get_feats(self, loc): + x = self.read_audio(loc) + with torch.no_grad(): + source = torch.from_numpy(x).float().cuda() + if self.task.cfg.normalize: + assert source.dim() == 1, source.dim() + with torch.no_grad(): + source = F.layer_norm(source, source.shape) + source = source.view(1, -1) + + m_res = self.model(source=source, mask=False, features_only=True, layer=self.layer) + return m_res["x"].squeeze(0).cpu() + + +def get_iterator(args): + with open(osp.join(args.data, args.split) + ".tsv", "r") as fp: + lines = fp.read().split("\n") + root = lines.pop(0).strip() + files = [osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0] + + num = len(files) + reader = Wav2VecFeatureReader(args.checkpoint, args.layer) + + def iterate(): + for fname in files: + w2v_feats = reader.get_feats(fname) + yield w2v_feats + + return iterate, num + + +def main(): + parser = get_parser() + args = parser.parse_args() + + os.makedirs(args.save_dir, exist_ok=True) + + def create_files(dest): + copyfile(osp.join(args.data, args.split) + ".tsv", dest + ".tsv") + if osp.exists(osp.join(args.data, args.split) + ".wrd"): + copyfile(osp.join(args.data, args.split) + ".wrd", dest + ".wrd") + if osp.exists(osp.join(args.data, args.split) + ".phn"): + copyfile(osp.join(args.data, args.split) + ".phn", dest + ".phn") + + if osp.exists(dest + ".npy"): + os.remove(dest + ".npy") + npaa = NpyAppendArray(dest + ".npy") + return npaa + + save_path = osp.join(args.save_dir, args.split) + npaa = create_files(save_path) + + generator, num = get_iterator(args) + iterator = generator() + + with open(save_path + ".lengths", "w") as l_f: + for w2v_feats in tqdm.tqdm(iterator, total=num): + print(len(w2v_feats), file=l_f) + + if len(w2v_feats) > 0: + npaa.append(w2v_feats.numpy()) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/wer.py b/fairseq/examples/wav2vec/unsupervised/scripts/wer.py new file mode 100644 index 0000000000000000000000000000000000000000..613ab50d39019f6edf67c56c2353646be2a2f17d --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/wer.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Implement unsupervised metric for decoding hyperparameter selection: + $$ alpha * LM_PPL + ViterbitUER(%) * 100 $$ +""" +import argparse +import logging +import sys + +import editdistance + +logging.root.setLevel(logging.INFO) +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("-s", "--hypo", help="hypo transcription", required=True) + parser.add_argument( + "-r", "--reference", help="reference transcription", required=True + ) + return parser + + +def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p): + d_cnt = 0 + w_cnt = 0 + w_cnt_h = 0 + for uid in hyp_uid_to_tra: + ref = ref_uid_to_tra[uid].split() + if g2p is not None: + hyp = g2p(hyp_uid_to_tra[uid]) + hyp = [p for p in hyp if p != "'" and p != " "] + hyp = [p[:-1] if p[-1].isnumeric() else p for p in hyp] + else: + hyp = hyp_uid_to_tra[uid].split() + d_cnt += editdistance.eval(ref, hyp) + w_cnt += len(ref) + w_cnt_h += len(hyp) + wer = float(d_cnt) / w_cnt + logger.debug( + ( + f"wer = {wer * 100:.2f}%; num. of ref words = {w_cnt}; " + f"num. of hyp words = {w_cnt_h}; num. of sentences = {len(ref_uid_to_tra)}" + ) + ) + return wer + + +def main(): + args = get_parser().parse_args() + + errs = 0 + count = 0 + with open(args.hypo, "r") as hf, open(args.reference, "r") as rf: + for h, r in zip(hf, rf): + h = h.rstrip().split() + r = r.rstrip().split() + errs += editdistance.eval(r, h) + count += len(r) + + logger.info(f"UER: {errs / count * 100:.2f}%") + + +if __name__ == "__main__": + main() + + +def load_tra(tra_path): + with open(tra_path, "r") as f: + uid_to_tra = {} + for line in f: + uid, tra = line.split(None, 1) + uid_to_tra[uid] = tra + logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}") + return uid_to_tra diff --git a/fairseq/examples/wav2vec/unsupervised/scripts/wrd_to_ltr.py b/fairseq/examples/wav2vec/unsupervised/scripts/wrd_to_ltr.py new file mode 100644 index 0000000000000000000000000000000000000000..f83471409a434556cab70086ca9e2d72d4bdddd5 --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/scripts/wrd_to_ltr.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + + +def main(): + for line in sys.stdin: + print(" ".join(list(line.strip().replace(" ", "|"))) + " |") + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/wav2vec/unsupervised/tasks/__init__.py b/fairseq/examples/wav2vec/unsupervised/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7dd625e09451be671908578f93148f371f53cd --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/tasks/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
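# tasks/__init__.py (body below) imports and exposes the UnpairedAudioText task, defined in
# unpaired_audio_text.py, which pairs extracted audio features with unpaired, phonemized text
# for wav2vec-U training and evaluation.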
+ +from .unpaired_audio_text import UnpairedAudioText + + +__all__ = [ + "UnpairedAudioText", +] diff --git a/fairseq/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py b/fairseq/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b65d5c495a9a1afeeca57ce2fae5d1c4b3da5a --- /dev/null +++ b/fairseq/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py @@ -0,0 +1,452 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from dataclasses import dataclass, field +import logging +import math +import os +from typing import Optional +import torch + +from fairseq.logging import metrics +from fairseq.tasks import FairseqTask, register_task +from ..data import ExtractedFeaturesDataset, RandomInputDataset + +from fairseq.data import ( + Dictionary, + data_utils, + StripTokenDataset, +) +from fairseq.dataclass import FairseqDataclass +from fairseq.distributed.utils import get_data_parallel_world_size +from omegaconf import MISSING + +from examples.speech_recognition.kaldi.kaldi_decoder import ( + KaldiDecoder, + KaldiDecoderConfig, +) + + +logger = logging.getLogger(__name__) + + +@dataclass +class DecodingConfig(FairseqDataclass): + kenlm_path: Optional[str] = None + lm_weight: float = 0 + blank_weight: float = 0 + + +@dataclass +class UnpairedAudioTextConfig(FairseqDataclass): + data: str = field( + default=MISSING, metadata={"help": "path to data directory containing audio"} + ) + text_data: str = field( + default=MISSING, metadata={"help": "path to data directory containing text"} + ) + max_length: Optional[int] = None + labels: Optional[str] = field( + default=None, + metadata={"help": "extension of the label file to load, used for fine-tuning"}, + ) + aux_target_postfix: Optional[str] = field( + default=None, + metadata={"help": "auxaliry target filename extension"}, + ) + unfiltered: bool = field( + default=False, metadata={"help": "load data with _unfiltered suffix"} + ) + ctc_eval: bool = field( + default=False, metadata={"help": "eval UER as if computed by CTC"} + ) + sort_by_length: bool = field( + default=True, metadata={"help": "sort examples by length of audio timesteps"} + ) + shuffle: bool = field(default=True, metadata={"help": "shuffle examples"}) + append_eos: bool = field(default=False, metadata={"help": "append eos"}) + uppercase: Optional[bool] = field( + default=False, metadata={"help": "uppercase for LM score computation"} + ) + skipwords: Optional[str] = field( + default="", + metadata={ + "help": "comma-separated words to be removed for LM score computation" + }, + ) + kenlm_path: Optional[str] = None + vocab_usage_power: float = 2 + + word_decoder_config: Optional[KaldiDecoderConfig] = None + word_kenlm_path: Optional[str] = None + + decoding_config: DecodingConfig = DecodingConfig() + + +@register_task("unpaired_audio_text", dataclass=UnpairedAudioTextConfig) +class UnpairedAudioText(FairseqTask): + """ """ + + cfg: UnpairedAudioTextConfig + + def __init__( + self, + cfg: UnpairedAudioTextConfig, + source_dictionary=None, + target_dictionary=None, + ): + super().__init__(cfg) + + self._target_dictionary = target_dictionary + self._source_dictionary = source_dictionary + self.num_symbols = ( + len([s for s in target_dictionary.symbols if not 
s.startswith("madeup")]) + - target_dictionary.nspecial + ) + self.sil_id = ( + target_dictionary.index("") if "" in target_dictionary else -1 + ) + self.kenlm = None + if cfg.kenlm_path is not None: + import kenlm + + self.kenlm = kenlm.Model(cfg.kenlm_path) + + self.word_kenlm = None + if cfg.word_kenlm_path is not None: + import kenlm + + self.word_kenlm = kenlm.Model(cfg.word_kenlm_path) + + self.uppercase = cfg.uppercase + self.skipwords = set(cfg.skipwords.split(",")) + + def str_postprocess(s): + s = " ".join(w for w in s.split() if w not in self.skipwords) + s = s.upper() if self.uppercase else s + return s + + self.str_postprocess = str_postprocess + self.compute_lm_score = lambda s: self.kenlm.score(self.str_postprocess(s)) + + self.compute_word_score = None + if cfg.word_decoder_config is not None: + self.kaldi_decoder = KaldiDecoder(cfg.word_decoder_config, beam=10) + + def compute_word_score(logits, padding): + res = self.kaldi_decoder.decode(logits, padding) + for r in res: + r = r.result() + assert len(r) == 1 + r = r[0] + yield r["score"], r["words"] + + self.compute_word_score = compute_word_score + + @classmethod + def setup_task(cls, cfg: UnpairedAudioTextConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (AudioPretrainingConfig): configuration of this task + """ + + dict_path = os.path.join(cfg.text_data, "dict.txt") + if os.path.exists(dict_path): + target_dictionary = Dictionary.load(dict_path) + else: + dict_path = os.path.join(cfg.data, f"dict.{cfg.labels}.txt") + target_dictionary = Dictionary.load(dict_path) + + return cls(cfg, target_dictionary=target_dictionary) + + def optimizer_step(self, optimizer, model, update_num): + if hasattr(model, "get_groups_for_update"): + groups = model.get_groups_for_update(update_num) + optimizer.step(groups={groups}) + else: + optimizer.step() + + def valid_step(self, sample, model, criterion): + res = model( + **sample["net_input"], + dense_x_only=True, + ) + + dense_x = res["logits"] + padding_mask = res["padding_mask"] + + word_scores = None + if self.compute_word_score is not None: + word_scores = self.compute_word_score(dense_x.cpu(), padding_mask.cpu()) + + z = dense_x.argmax(-1) + z[padding_mask] = self.target_dictionary.pad() + + vocab_seen = torch.zeros(self.num_symbols, dtype=torch.bool) + + import editdistance + + c_err = 0 + c_len = 0 + pred_c_len = 0 + lm_score_sum = 0 + for i, (x, t, id) in enumerate( + zip( + z, + sample["target"] if "target" in sample else [None] * len(z), + sample["id"], + ) + ): + + if t is not None: + t = t[(t >= self.target_dictionary.nspecial)] + x = x[ + (x >= self.target_dictionary.nspecial) + & (x < (self.num_symbols + self.target_dictionary.nspecial)) + ] + if self.sil_id >= 0: + x = x[x != self.sil_id] + + vocab_seen[x - self.target_dictionary.nspecial] = True + + pred_units_arr = x + if self.cfg.ctc_eval: + pred_units_arr = pred_units_arr.unique_consecutive() + pred_units_arr = pred_units_arr[pred_units_arr != 0] + + if id == 0: + if t is not None: + logger.info(f"REF: {self.target_dictionary.string(t)}") + logger.info(f"HYP: {self.target_dictionary.string(pred_units_arr)}") + + if self.kenlm is not None: + if t is not None: + ref_lm_s = self.compute_lm_score( + self.target_dictionary.string(t) + ) + logger.info( + f"LM [REF]: {ref_lm_s}, {math.pow(10, -ref_lm_s / (len(t) + 1))}" + ) + + hyp_lm_s = self.compute_lm_score( + self.target_dictionary.string(pred_units_arr) + ) + logger.info( + f"LM [HYP]: {hyp_lm_s}, {math.pow(10, -hyp_lm_s / 
(len(pred_units_arr) + 1))}" + ) + + pred_units_arr = pred_units_arr.tolist() + + pred_c_len += len(pred_units_arr) + + if t is not None: + t = t.tolist() + c_err += editdistance.eval(pred_units_arr, t) + c_len += len(t) + else: + c_len = pred_c_len + + if self.kenlm is not None: + pred_str = self.target_dictionary.string(pred_units_arr) + lm_score = self.compute_lm_score(pred_str) + lm_score_sum += lm_score + + kaldi_score_sum = 0 + word_lm_sum = 0 + num_words = 0 + if word_scores is not None: + for score, words in word_scores: + kaldi_score_sum += score + num_words += len(words) + if self.word_kenlm is not None: + word_lm_sum += self.kenlm.score(" ".join(words)) + + try: + world_size = get_data_parallel_world_size() + except: + world_size = 1 + + logging_output = { + "loss": c_err, + "_num_char_errors": c_err, + "_num_chars": c_len, + "_num_pred_chars": pred_c_len, + "ntokens": c_len, + "nsentences": z.size(0), + "sample_size": c_len, + "_world_size": world_size, + "_lm_score_sum": lm_score_sum, + "_kaldi_score_sum": kaldi_score_sum, + "_word_lm_sum": word_lm_sum, + "_num_words": num_words, + "_vocab_seen": vocab_seen, + } + + return c_err, c_len, logging_output + + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + task_cfg = task_cfg or self.cfg + + has_unpaired_text = os.path.exists( + os.path.join(self.cfg.text_data, f"{split}.idx") + ) + + self.datasets[split] = ExtractedFeaturesDataset( + path=data_path, + split=split, + min_length=3, + max_length=task_cfg.max_length, + labels=None if has_unpaired_text else task_cfg.labels, + label_dict=self.target_dictionary, + shuffle=getattr(task_cfg, "shuffle", True), + sort_by_length=task_cfg.sort_by_length, + aux_target_postfix=task_cfg.aux_target_postfix, + ) + + logger.info(f"split {split} has unpaired text? 
{has_unpaired_text}") + if has_unpaired_text: + text_dataset = data_utils.load_indexed_dataset( + os.path.join(self.cfg.text_data, split), self.target_dictionary + ) + text_dataset = StripTokenDataset(text_dataset, self.target_dictionary.eos()) + self.datasets[split] = RandomInputDataset( + self.datasets[split], + text_dataset, + ["random_label"], + add_to_input=True, + pad_idx=self.target_dictionary.pad(), + ) + + @property + def source_dictionary(self): + return self._source_dictionary + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self._target_dictionary + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return None + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + zero = torch.scalar_tensor(0.0) + num_char_errors = sum( + log.get("_num_char_errors", zero) for log in logging_outputs + ) + num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) + num_word_errors = sum( + log.get("_num_word_errors", zero) for log in logging_outputs + ) + num_words = sum(log.get("_num_words", zero) for log in logging_outputs) + num_pred_chars = sum( + log.get("_num_pred_chars", zero) for log in logging_outputs + ) + + lm_score_sum = sum(log.get("_lm_score_sum", zero) for log in logging_outputs) + vocab_seen = ( + sum(log.get("_vocab_seen", zero) for log in logging_outputs) + .bool() + .sum() + .item() + ) + kaldi_score_sum = sum( + log.get("_kaldi_score_sum", zero) for log in logging_outputs + ) + word_lm_sum = sum(log.get("_word_lm_sum", zero) for log in logging_outputs) + + metrics.log_scalar_sum("_num_char_errors", num_char_errors) + metrics.log_scalar_sum("_num_chars", num_chars) + metrics.log_scalar_sum("_num_word_errors", num_word_errors) + metrics.log_scalar_sum("_num_words", num_words) + + metrics.log_scalar_sum("lm_score_sum", lm_score_sum) + metrics.log_scalar_sum("num_pred_chars", num_pred_chars) + + if self.cfg.word_kenlm_path is not None: + metrics.log_scalar_sum("kaldi_score_sum", kaldi_score_sum) + metrics.log_scalar_sum("word_lm_sum", word_lm_sum) + + if num_chars > 0: + metrics.log_derived( + "uer", + lambda meters: meters["_num_char_errors"].sum + * 100.0 + / meters["_num_chars"].sum + if meters["_num_chars"].sum > 0 + else float("nan"), + ) + + if lm_score_sum < 0 and vocab_seen > 0: + metrics.log_scalar("vocab_seen_pct", vocab_seen / self.num_symbols) + + metrics.log_derived( + "weighted_lm_ppl", + lambda meters: math.pow( + 10, + -meters["lm_score_sum"].sum + / ( + meters["num_pred_chars"].sum + meters["nsentences"].sum + ), # account for + ) + / meters["vocab_seen_pct"].avg ** self.cfg.vocab_usage_power, + ) + + metrics.log_derived( + "lm_ppl", + lambda meters: math.pow( + 10, + -meters["lm_score_sum"].sum + / ( + meters["num_pred_chars"].sum + meters["nsentences"].sum + ), # account for + ), + ) + else: + metrics.log_derived("weighted_lm_ppl", lambda meters: float("inf")) + + if num_words > 0: + if word_lm_sum != 0: + metrics.log_derived( + "word_lm_ppl", + lambda meters: math.pow( + 10, + -meters["word_lm_sum"].sum + / ( + meters["_num_words"].sum + meters["nsentences"].sum + ), # account for + ), + ) + metrics.log_derived( + "weighted_word_lm_ppl", + lambda meters: math.pow( + 10, + -meters["word_lm_sum"].sum + / ( + meters["_num_words"].sum + meters["nsentences"].sum + ), # account for + ) + / meters["vocab_seen_pct"].avg ** self.cfg.vocab_usage_power, + ) + + if 
self.cfg.word_kenlm_path is not None: + metrics.log_derived( + "kaldi_score", + lambda meters: meters["kaldi_score_sum"].sum + / meters["nsentences"].sum, + ) + + def build_model(self, cfg: FairseqDataclass, from_checkpoint=False): + model = super().build_model(cfg) + + return model diff --git a/fairseq/examples/wav2vec/xlsr/README.md b/fairseq/examples/wav2vec/xlsr/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e0a7c4ef3f53ece0bef1c121e8529469f13b3481 --- /dev/null +++ b/fairseq/examples/wav2vec/xlsr/README.md @@ -0,0 +1,95 @@ +# XLS-R + +XLS-R is a set of large-scale models for self-supervised cross-lingual speech representation learning based on wav2vec 2.0. It was pretrained on 128 languages and approximately 436K hours of unlabeled speech data. With finetuning, these models achieve state-of-the-art performance in speech translation, speech recognition and language identification. We evaluate the models across multiple benchmarks such as CoVoST-2 for speech translation, BABEL / MLS / CommonVoice / VoxPopuli for automatic speech recognition, and VoxLingua107 for language identification, as well as VoxCeleb1 for speaker identification. More details about this work can be found in our [paper](https://arxiv.org/pdf/2111.09296.pdf) and download links can be found below. + +Model | Link +|------|------ +XLS-R 300M | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr2_300m.pt) +XLS-R 1B | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr2_960m_1000k.pt) +XLS-R 2B | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr2_2B_1000k.pt) + +You can also download these models [here](https://huggingface.co/models?other=xls_r) and read more about them in the [blog post](https://huggingface.co/blog/fine-tune-xlsr-wav2vec2) from Hugging Face. + +## Speech Translation Finetuned Models + +We multilingually finetune XLS-R models on [CoVoST 2](https://github.com/facebookresearch/covost), which has 21 +into-English and 15 out-of-English directions. + +Model | Directions | Link +|------|------|------ +XLS-R 300M | 21 langs → En | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xls_r_300m_21_en.pt) +XLS-R 300M | En → 15 langs | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xls_r_300m_en_15.pt) +XLS-R 1B | 21 langs → En | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xls_r_1b_21_en.pt) +XLS-R 1B | En → 15 langs | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xls_r_1b_en_15.pt) +XLS-R 2B | 21 langs → En | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xls_r_2b_21_en.pt) +XLS-R 2B | En → 15 langs | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xls_r_2b_en_15.pt) +XLS-R 2B | 21 langs → En + En → 15 langs | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xls_r_2b_22_16.pt) + +## ASR Finetuning + +You can refer to the original wav2vec documentation [here](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec#fine-tune-a-pre-trained-model-with-ctc) for detailed instructions on how to finetune a pretrained model with CTC. Below is an example command, followed by the hyperparameter values needed to reproduce the results in our paper.
+ +```shell script +$ fairseq-hydra-train \ + distributed_training.distributed_port=$PORT \ + task.data=/path/to/data \ + model.w2v_path=/path/to/model.pt \ + --config-dir /path/to/fairseq-py/examples/wav2vec/xlsr/config \ + --config-name finetune +``` + +For finetuning the 300M as well as the 1B model, we use the same hyperparameter setting defined in `finetune.yaml`. We vary `optimization.max_update` as described in the table below, and `optimization.lr` is picked from the interval [2e-5, 3e-4] based on dev word error rate. + +Benchmark | Total Number of Updates +|------|------ +Babel | 26000 +Common Voice | 13000 +VoxPopuli | 50000 +MLS 10h | 20000 + +For finetuning the 2B model, we make some additional changes to `finetune.yaml`. We use the fully_sharded `distributed_training.ddp_backend` provided by the [fairscale](https://github.com/facebookresearch/fairscale) library and set `model.activation_checkpoint` to true. We also increase `dataset.max_tokens` to 2560000 and use a total effective batch size of 2560000*24. We sweep for the best `optimization.lr` within the interval [3e-6, 3e-5] using dev error rate. For the Common Voice dataset, we pick `model.mask_prob` for different languages from {0.30, 0.40} based on best dev error rate. + +## LID Inference + +Model | Link +|------|------ +XLS-R 300M + ft Voxlingua107 | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_300m_voxlingua107_ft.pt) + +How to run inference & calculate accuracy (step-by-step): +1. Download the Voxlingua107 checkpoint from the table above. +2. Use this Python script to extract logits/embeddings from the XLSR model: https://github.com/fairinternal/fairseq-py/blob/xlsr2/examples/wav2vec/gen_audio_embedding.py +```shell command +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python3 examples/wav2vec/gen_audio_embedding.py \ + /fsx/data/VoxLingua107/manifest --path "/path/to/checkpoint.pt" \ + --task audio_classification --batch-size 90 --gen-subset test \ + --infer-manifest /fsx/data/VoxLingua107/manifest/test.tsv \ + --infer-xtimes 10 --infer-max-sample-size 160000 --output-path /tmp/tmp_voxling_infer.npz +``` + +3. Calculate the overall accuracy, 0-5 seconds and 5-20 seconds: +```shell command +PYTHONPATH='.' 
python examples/wav2vec/eval_speaker_clf_task.py \ + --task cls --merge mean_logit --data /tmp/tmp_voxling_infer.npz + +Output: +| run classification evaluation +| acc = 94.34% -- err = 5.66% -- correct=1518 total=1609 +| acc 0to5 = 90.91% -- err = 9.09% -- c_5=230.0 t_5=253 +| acc 5to20 = 94.99% -- err = 5.01% -- c_20=1288.0 t_20=1356 +``` + +## Citation + +Please cite as: + +``` bibtex +@article{babu2021xlsr, + title={XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale}, + author={Arun Babu and Changhan Wang and Andros Tjandra and Kushal Lakhotia and Qiantong Xu and Naman Goyal and Kritika Singh and Patrick von Platen and Yatharth Saraf and Juan Pino and Alexei Baevski and Alexis Conneau and Michael Auli}, + year={2021}, + volume={abs/2111.09296}, + journal={arXiv}, +} +``` + + diff --git a/fairseq/examples/wav2vec/xlsr/config/finetune.yaml b/fairseq/examples/wav2vec/xlsr/config/finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8736e101c5047f74ee24c6198f0f9523c9e9eef6 --- /dev/null +++ b/fairseq/examples/wav2vec/xlsr/config/finetune.yaml @@ -0,0 +1,66 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + +checkpoint: + save_interval: 1000 + save_interval_updates: 1000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval_updates: 1000 + valid_subset: valid + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: ??? + lr: [0.0003] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.75 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + + checkpoint_activations: false diff --git a/fairseq/examples/wav2vec/xlsr/scripts/eval_speaker_clf_task.py b/fairseq/examples/wav2vec/xlsr/scripts/eval_speaker_clf_task.py new file mode 100644 index 0000000000000000000000000000000000000000..16d07516f8c249c15e4ddd11b6db15c7de3f8f08 --- /dev/null +++ b/fairseq/examples/wav2vec/xlsr/scripts/eval_speaker_clf_task.py @@ -0,0 +1,173 @@ +""" +Usage: + This scripts it to evaluate the classification accuracy/error rate from the embedding extracted + by gen_audio_embedding.py + Example (LID classification) + + PYTHONPATH='.' python examples/wav2vec/eval_speaker_clf_task.py \ + --data /fsx/androstj/exps/lid_voxlingua/infer/atj_xlsr2_100pct_300M_mean_fast_upd_100k_new.npz \ + --task cls --merge mean_logit +""" +import numpy as np +import sklearn +from sklearn.metrics.pairwise import cosine_similarity +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm +import ipdb +import logging +import argparse +from scipy.special import softmax + +log=logging.getLogger(__name__) +log.setLevel(logging.INFO) + +def calculate_eer(y_label, y_score): + # y denotes groundtruth scores, + # y_score denotes the prediction scores. 
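    # The equal error rate (EER) is the operating point where the false-positive rate equals the
    # false-negative rate (1 - TPR); brentq below locates that point by root-finding on the
    # interpolated ROC curve, and the matching decision threshold is read off the same curve.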
+ from scipy.optimize import brentq + from sklearn.metrics import roc_curve + from scipy.interpolate import interp1d + + fpr, tpr, thresholds = roc_curve(y_label, y_score, pos_label=1) + eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.) + optimal_threshold = interp1d(fpr, thresholds)(eer) + return eer, optimal_threshold + +def calculate_minDCF(y_label, y_score, p_target=0.01, c_miss=1, c_fa=1): + # https://github.com/kaldi-asr/kaldi/blob/master/egs/sre08/v1/sid/compute_min_dcf.py + from sklearn.metrics import det_curve + fpr, fnr, thresholds = det_curve(y_label, y_score, pos_label=1) + min_c_det = float("inf") + min_c_det_threshold = thresholds[0] + for i in range(0, len(fpr)): + # See Equation (2). it is a weighted sum of false negative + # and false positive errors. + c_det = c_miss * fnr[i] * p_target + c_fa * fpr[i] * (1 - p_target) + if c_det < min_c_det: + min_c_det = c_det + min_c_det_threshold = thresholds[i] + # See Equations (3) and (4). Now we normalize the cost. + c_def = min(c_miss * p_target, c_fa * (1 - p_target)) + min_dcf = min_c_det / c_def + return min_dcf, min_c_det_threshold + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data', help='npz contains name & latent file') + parser.add_argument('--task', choices=['cls', 'veri', 'cls_voxlingua']) + parser.add_argument('--merge', choices=['mean_logit', 'first_logit', 'mean_latent_sim', 'first_latent_sim', 'mean_logit_sim', 'first_logit_sim']) + parser.add_argument('--veri-pair', help='verification file contains 1/0 utt_x utt_y') + parser.add_argument('--scaler', type=str, choices=['mean_var']) + parser.add_argument('--compress-method', choices=['pca']) + parser.add_argument('--compress-dim', type=int) + args = parser.parse_args() + + if args.task in ['cls', 'cls_voxlingua']: + print('| run classification evaluation') + data = np.load(args.data) + data_logit = data['logit'] + data_target = data['target'] + data_src_len = data['src_len'] + assert data_logit.shape[0] == data_target.shape[0] + B = data_logit.shape[0] + correct = 0 + total = 0 + data_prob = softmax(data_logit, axis=2) + correct_vs_len = np.empty((B, 2)) + for ii in range(B): + _target = data_target[ii] + if args.merge == 'mean_logit': + _prob = np.mean(data_prob[ii], axis=0) + top_1 = np.argmax(_prob) + elif args.merge == 'first_logit': + _prob = data_prob[ii][0] + top_1 = np.argmax(_prob) + else : + raise ValueError() + is_top_1 = (1 if top_1 == _target else 0) + correct += is_top_1 + total += 1 + _src_len = data_src_len[ii] / 16000 + correct_vs_len[ii] = [is_top_1, _src_len] + + acc = correct / total * 100 + t_5 = correct_vs_len[:, 1] <= 5 + t_20 = correct_vs_len[:, 1] > 5 + c_5 = correct_vs_len[t_5, 0].sum() + c_20 = correct_vs_len[t_20, 0].sum() + t_5 = t_5.sum() + t_20 = t_20.sum() + acc_5 = c_5 / t_5 * 100 + acc_20 = c_20 / t_20 * 100 + print(f'| acc = {acc:.2f}% -- err = {100-acc:.2f}% -- {correct=} {total=}') + print(f'| acc 0to5 = {acc_5:.2f}% -- err = {100-acc_5:.2f}% -- {c_5=} {t_5=}') + print(f'| acc 5to20 = {acc_20:.2f}% -- err = {100-acc_20:.2f}% -- {c_20=} {t_20=}') + + + + if args.task == 'veri': + print('| run verification evaluation') + veri_pairs = [] + with open(args.veri_pair) as ff: + for fi in ff: + a,b,c = fi.split() + a = int(a) + veri_pairs.append([a,b,c]) + + data = np.load(args.data) + if 'logit' in args.merge: + data_latent = data['logit'] + elif 'latent' in args.merge: + data_latent = data['latent'] + else : + raise ValueError() + + data_name = data['name'] + assert 
len(data_name) == len(data_latent) + map_name_latent = {} + + from sklearn.pipeline import make_pipeline + pipe = [] + if args.scaler == 'mean_var': + print(f'| apply StandardScaler') + pipe.append(StandardScaler()) + + if args.compress_method == 'pca': + n_comp = args.compress_dim + print(f'| apply PCA with {n_comp=}') + from sklearn.decomposition import PCA + pipe.append(PCA(n_components=n_comp)) + if len(pipe) > 0 : + pipe = make_pipeline(*pipe) + data_latent_2d = data_latent.reshape(-1, data_latent.shape[-1]) + pipe.fit(data_latent_2d) + data_latent_2d = pipe.transform(data_latent_2d) + data_latent = data_latent_2d.reshape(data_latent.shape[0], data_latent.shape[1], -1) + + for ii in range(len(data_name)): + map_name_latent[data_name[ii]] = data_latent[ii] + labels = [] + scores = [] + for lbl, pair_a, pair_b in tqdm(veri_pairs): + labels.append(lbl) + pair_a = map_name_latent[pair_a] + pair_b = map_name_latent[pair_b] + assert pair_a.ndim == pair_b.ndim == 2 + score = cosine_similarity(pair_a, pair_b) + if args.merge.startswith('mean'): + score = np.mean(score) + elif args.merge.startswith('first'): + score = score[0, 0] + else : + raise ValueError() + scores.append(score) + labels = np.array(labels) + scores = np.array(scores) + eer, eer_threshold = calculate_eer(labels, scores) + minDCF, minDCF_threshold = calculate_minDCF(labels, scores) + print('='*40) + print(f'| EER = {eer*100:.2f}%\tthreshold = {eer_threshold:.2f}') + print(f'| minDCF = {minDCF:.2f}\tthreshold = {minDCF_threshold:.2f}') + + diff --git a/fairseq/examples/wav2vec/xlsr/scripts/gen_audio_embedding.py b/fairseq/examples/wav2vec/xlsr/scripts/gen_audio_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..e5de1d5efd1f8e493dd2aba0e36b619a0e20db49 --- /dev/null +++ b/fairseq/examples/wav2vec/xlsr/scripts/gen_audio_embedding.py @@ -0,0 +1,222 @@ +""" +Usage: + This script is used to extract the embedding / logit for speech classification task. + 1. Set fdir into your model checkpoint directory + 2. 
Run the following command (preferrably on GPU machine to speed up the inference process) + + CUDA_VISIBLE_DEVICES=0 python3 examples/wav2vec/gen_audio_embedding.py /fsx/data/VoxLingua107/manifest --path ${fdir} \ + --task audio_classification --batch-size 90 --gen-subset test \ + --infer-manifest /fsx/data/VoxLingua107/manifest/test.tsv \ + --infer-xtimes 10 --infer-max-sample-size 160000 --output-path $odir + + Example: + Case: LID logit extraction + fdir='/fsx/androstj/exps/voxlingua_lid_train_all/ckpt_100pct_300m_voxling-act_linear-pool_mean_fast-lr_1e-4-phase_0.1_0.4_0.5-maxupd_100000-ufreq_1-mprob_0.5-fz_0-cr_softmax/0/checkpoints/checkpoint_best.pt' + python3 examples/wav2vec/gen_audio_embedding.py /fsx/data/VoxLingua107/manifest --path ${fdir} \ + --task audio_classification --batch-size 90 --gen-subset test \ + --infer-manifest /fsx/data/VoxLingua107/manifest/test.tsv \ + --infer-xtimes 10 --infer-max-sample-size 160000 --output-path $odir + +""" +import torch +from fairseq import checkpoint_utils, distributed_utils, options, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging import metrics, progress_bar +from fairseq import checkpoint_utils, data, options, tasks +from fairseq.data import FileAudioDataset, AddTargetDataset, Dictionary +from fairseq.tasks.audio_classification import LabelEncoder +import ipdb +import copy +import sys +from tqdm import tqdm +import tempfile +import numpy as np +import sklearn + +def subset_manifest(infer_manifest, veri_pair): + with open(infer_manifest) as ff, open(veri_pair) as gg, \ + tempfile.NamedTemporaryFile('w', delete=False) as ww: + fnames = ff.read().strip().split("\n") + basedir = fnames[0] + needed_fname = [] + for gi in gg.read().strip().split('\n'): + _, x1, x2 = gi.split() + needed_fname.append(x1) + needed_fname.append(x2) + needed_fname = set(needed_fname) + + ww.write(basedir+'\n') + for ii in range(1, len(fnames)): + x1,x2 = fnames[ii].split() + if x1 in needed_fname: + ww.write(fnames[ii]+'\n') + print(f'| subset manifest for verification: {ww.name}') + return ww.name + +def wrap_target_dataset(infer_manifest, dataset, task): + label_path = infer_manifest.replace(".tsv", ".label") + with open(label_path, "r") as f: + labels = f.read().strip().split("\n") + assert len(labels) == len(dataset) + process_label = LabelEncoder(task.target_dictionary) + dataset = AddTargetDataset(dataset, labels, + pad=task.target_dictionary.pad(), + eos=task.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + add_to_input=False) + return dataset + +def resample_data(source, padding_mask, n_sample, max_sample_len): + # source: BxT + # padding_mask: BxT + B = source.shape[0] + T = source.shape[1] + sources = [] + padding_masks = [] + seq_len = (~padding_mask).sum(1) + for jj in range(n_sample): + new_source = source.new_zeros(B, max_sample_len) + new_padding_mask = padding_mask.new_zeros(B, max_sample_len) + for ii in range(B): + if seq_len[ii] > max_sample_len: + start = np.random.randint(0, seq_len[ii]-max_sample_len+1) + end = start + max_sample_len + else : + start = 0 + end = seq_len[ii] + new_source[ii, 0:end-start] = source[ii, start:end] + new_padding_mask[ii, end-start+1:] = True + sources.append(new_source) + padding_masks.append(new_padding_mask) + return sources, padding_masks + +def resample_sample(sample, n_sample, max_sample_len): + new_sources, new_padding_masks = resample_data(sample['net_input']['source'], sample['net_input']['padding_mask'], n_sample, 
max_sample_len) + new_samples = [] + for ii in range(n_sample): + new_sample = copy.deepcopy(sample) + new_sample['net_input']['source'] = new_sources[ii] + new_sample['net_input']['padding_mask'] = new_padding_masks[ii] + new_samples.append(new_sample) + return new_samples + +if __name__ == '__main__': + np.random.seed(123) + # Parse command-line arguments for generation + parser = options.get_generation_parser(default_task='audio_classification') + # parser.add_argument('--infer-merge', type=str, default='mean') + parser.add_argument('--infer-xtimes', type=int, default=1) + parser.add_argument('--infer-max-sample-size', type=int, default=5*16000) # 5 secs + parser.add_argument('--infer-manifest', type=str) + parser.add_argument('--verification-pair', type=str, required=False, + help=''' + a file that contains pairs of utts to evaluated if they are from same speaker or not + format: (following voxceleb) + 1/0 + ''') + parser.add_argument('--output-path', type=str) + # parser.add_argument('--infer-xtimes', type=int, default=1) + + args = options.parse_args_and_arch(parser) + # Setup task + # task = tasks.setup_task(args) + use_cuda = not args.cpu + + # Load model & task + print('| loading model from {}'.format(args.path)) + arg_overrides = { + 'data': args.data, + # 'mask_prob': 0 + #'max_sample_size': sys.maxsize, + #'min_sample_size': 0, + } + state = checkpoint_utils.load_checkpoint_to_cpu(args.path) + # move to AWS + state['cfg']['model']['w2v_path'] = state['cfg']['model']['w2v_path'].replace('/checkpoint/arbabu/XLSR2/model_versions/', '/fsx/data/model_versions/').replace('/checkpoint/kushall/final_model_checkpoints/wav2vec2/', '/fsx/data/wav2vec_ckpt/') + state['cfg']['task']['data'] = state['cfg']['task']['data'].replace('/checkpoint/kushall/data/', '/fsx/data/') + + models, _model_args, task = checkpoint_utils.load_model_ensemble_and_task([args.path], + arg_overrides=arg_overrides, + task=None, + state=state) + model = models[0] + model.eval() + if use_cuda: + model.cuda() + + + # Load dataset + task.load_dataset(args.gen_subset) + dataset = task.dataset(args.gen_subset) + infer_manifest = args.infer_manifest + # only decode needed utts + # infer_manifest = subset_manifest(infer_manifest, + # args.verification_pair) + infer_dataset = FileAudioDataset(infer_manifest, + sample_rate=task.cfg.sample_rate, + max_sample_size=10**10, #task.cfg.max_sample_size, + min_sample_size=1, #task.cfg.min_sample_size, + pad=True, + normalize=task.cfg.normalize) + # add target (if needed) + infer_dataset = wrap_target_dataset(infer_manifest, infer_dataset, task) + itr = task.get_batch_iterator( + dataset=infer_dataset, + max_sentences=args.batch_size, + ).next_epoch_itr(shuffle=False) + + + # correct = 0 + # total = 0 + list_uttname = [] + list_latent = [] + list_logit = [] + list_target = [] + list_src_len = [] + with torch.no_grad(): + for _, sample in tqdm(enumerate(itr)): + # resample if needed + samples = resample_sample(sample, args.infer_xtimes, args.infer_max_sample_size) + list_uttname.extend(sample['name']) + list_target.extend(sample['target'][:, 0].cpu().numpy()) + list_src_len.extend((~sample['net_input']['padding_mask']).sum(1).cpu().numpy()) + latents = [] + logits = [] + for sample in samples: + sample = utils.move_to_cuda(sample) if use_cuda else sample + try: + latent = model.forward_latent(**sample['net_input']) + latents.append(latent.detach().cpu().numpy()) + except: + latent = None + logit = model.forward(**sample['net_input']) + logits.append(logit.detach().cpu().numpy()) + + 
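+            # Each utterance was cropped --infer-xtimes times above; stack the per-crop
+            # outputs as (B, X, ...) so the evaluation script can merge them (e.g. mean_logit)
+            # across crops.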
if len(latents) > 0: + latents = np.stack(latents, 1) # B,X,D + logits = np.stack(logits, 1) # B,X,Cls + list_latent.extend(latents) + list_logit.extend(logits) + + # create big npz + list_uttname = np.array(list_uttname) + list_latent = np.array(list_latent) + list_target = np.array(list_target) + list_logit = np.array(list_logit) + list_src_len = np.array(list_src_len) + # save to npz + output_path = args.output_path + if (output_path is None): + output_path = tempfile.NamedTemporaryFile('wb', delete=False).name + + with open(output_path, 'wb') as ww: + np.savez(ww, name=list_uttname, + latent=list_latent, + target=list_target, + logit=list_logit, + src_len=list_src_len) + + print("="*10 + " REPORT " + "="*10) + print(f'| latent saved in {output_path}') + print(f'| {list_uttname.shape=}, {list_latent.shape=}, {list_target.shape=}, {list_logit.shape=}, {list_src_len.shape=}') diff --git a/fairseq/examples/wmt19/README.md b/fairseq/examples/wmt19/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5c90d0e6c4ae8d043ca622e70c5828dca6f9c2f2 --- /dev/null +++ b/fairseq/examples/wmt19/README.md @@ -0,0 +1,85 @@ +# WMT 19 + +This page provides pointers to the models of Facebook-FAIR's WMT'19 news translation task submission [(Ng et al., 2019)](https://arxiv.org/abs/1907.06616). + +## Pre-trained models + +Model | Description | Download +---|---|--- +`transformer.wmt19.en-de` | En->De Ensemble | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz) +`transformer.wmt19.de-en` | De->En Ensemble | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz) +`transformer.wmt19.en-ru` | En->Ru Ensemble | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz) +`transformer.wmt19.ru-en` | Ru->En Ensemble | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz) +`transformer_lm.wmt19.en` | En Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.gz) +`transformer_lm.wmt19.de` | De Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.gz) +`transformer_lm.wmt19.ru` | Ru Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.gz) + +## Pre-trained single models before finetuning + +Model | Description | Download +---|---|--- +`transformer.wmt19.en-de` | En->De Single, no finetuning | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.ffn8192.tar.gz) +`transformer.wmt19.de-en` | De->En Single, no finetuning | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.ffn8192.tar.gz) +`transformer.wmt19.en-ru` | En->Ru Single, no finetuning | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ffn8192.tar.gz) +`transformer.wmt19.ru-en` | Ru->En Single, no finetuning | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ffn8192.tar.gz) + +## Example usage (torch.hub) + +#### Requirements + +We require a few additional Python dependencies for preprocessing: +```bash +pip install fastBPE sacremoses +``` + +#### Translation + +```python +import torch + +# English to German translation +en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt', + tokenizer='moses', bpe='fastbpe') +en2de.translate("Machine 
learning is great!") # 'Maschinelles Lernen ist großartig!' + +# German to English translation +de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt', + tokenizer='moses', bpe='fastbpe') +de2en.translate("Maschinelles Lernen ist großartig!") # 'Machine learning is great!' + +# English to Russian translation +en2ru = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-ru', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt', + tokenizer='moses', bpe='fastbpe') +en2ru.translate("Machine learning is great!") # 'Машинное обучение - это здорово!' + +# Russian to English translation +ru2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.ru-en', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt', + tokenizer='moses', bpe='fastbpe') +ru2en.translate("Машинное обучение - это здорово!") # 'Machine learning is great!' +``` + +#### Language Modeling + +```python +# Sample from the English LM +en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe') +en_lm.sample("Machine learning is") # 'Machine learning is the future of computing, says Microsoft boss Satya Nadella ...' + +# Sample from the German LM +de_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.de', tokenizer='moses', bpe='fastbpe') +de_lm.sample("Maschinelles lernen ist") # 'Maschinelles lernen ist das A und O (neues-deutschland.de) Die Arbeitsbedingungen für Lehrerinnen und Lehrer sind seit Jahren verbesserungswürdig ...' + +# Sample from the Russian LM +ru_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.ru', tokenizer='moses', bpe='fastbpe') +ru_lm.sample("машинное обучение это") # 'машинное обучение это то, что мы называем "искусственным интеллектом".' +``` + +## Citation +```bibtex +@inproceedings{ng2019facebook}, + title = {Facebook FAIR's WMT19 News Translation Task Submission}, + author = {Ng, Nathan and Yee, Kyra and Baevski, Alexei and Ott, Myle and Auli, Michael and Edunov, Sergey}, + booktitle = {Proc. of WMT}, + year = 2019, +} +``` diff --git a/fairseq/examples/wmt20/README.md b/fairseq/examples/wmt20/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b4f2874652f8be19998a65faa1d9276d8017ec59 --- /dev/null +++ b/fairseq/examples/wmt20/README.md @@ -0,0 +1,72 @@ +# WMT 20 + +This page provides pointers to the models of Facebook-FAIR's WMT'20 news translation task submission [(Chen et al., 2020)](https://arxiv.org/abs/2011.08298). 
+ +## Single best MT models (after finetuning on part of WMT20 news dev set) + +Model | Description | Download +---|---|--- +`transformer.wmt20.ta-en` | Ta->En | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz) +`transformer.wmt20.en-ta` | En->Ta | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz) +`transformer.wmt20.iu-en.news` | Iu->En (News domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz) +`transformer.wmt20.en-iu.news` | En->Iu (News domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz) +`transformer.wmt20.iu-en.nh` | Iu->En (Nunavut Hansard domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz) +`transformer.wmt20.en-iu.nh` | En->Iu (Nunavut Hansard domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz) + +## Language models +Model | Description | Download +---|---|--- +`transformer_lm.wmt20.en` | En Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en.tar.gz) +`transformer_lm.wmt20.ta` | Ta Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta.tar.gz) +`transformer_lm.wmt20.iu.news` | Iu Language Model (News domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu.news.tar.gz) +`transformer_lm.wmt20.iu.nh` | Iu Language Model (Nunavut Hansard domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu.nh.tar.gz) + +## Example usage (torch.hub) + +#### Translation + +```python +import torch + +# English to Tamil translation +en2ta = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.en-ta') +en2ta.translate("Machine learning is great!") # 'இயந்திரக் கற்றல் அருமை!' + +# Tamil to English translation +ta2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.ta-en') +ta2en.translate("இயந்திரக் கற்றல் அருமை!") # 'Machine learning is great!' + +# English to Inuktitut translation +en2iu = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.en-iu.news') +en2iu.translate("machine learning is great!") # 'ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ ᐱᐅᔪᒻᒪᕆᒃ!' + +# Inuktitut to English translation +iu2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.iu-en.news') +iu2en.translate("ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ ᐱᐅᔪᒻᒪᕆᒃ!") # 'Machine learning excellence!' +``` + +#### Language Modeling + +```python +# Sample from the English LM +en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt20.en') +en_lm.sample("Machine learning is") # 'Machine learning is a type of artificial intelligence that uses machine learning to learn from data and make predictions.' + +# Sample from the Tamil LM +ta_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt20.ta') +ta_lm.sample("இயந்திரக் கற்றல் என்பது செயற்கை நுண்ணறிவின்") # 'இயந்திரக் கற்றல் என்பது செயற்கை நுண்ணறிவின் ஒரு பகுதியாகும்.' + +# Sample from the Inuktitut LM +iu_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt20.iu.news') +iu_lm.sample("ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ") # 'ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ, ᐊᒻᒪᓗ ᓯᓚᐅᑉ ᐊᓯᙳᖅᐸᓪᓕᐊᓂᖓᓄᑦ ᖃᓄᐃᓕᐅᕈᑎᒃᓴᑦ, ᐃᓚᖃᖅᖢᑎᒃ ᐅᑯᓂᖓ:' +``` + +## Citation +```bibtex +@inproceedings{chen2020facebook + title={Facebook AI's WMT20 News Translation Task Submission}, + author={Peng-Jen Chen and Ann Lee and Changhan Wang and Naman Goyal and Angela Fan and Mary Williamson and Jiatao Gu}, + booktitle={Proc. 
of WMT}, + year={2020}, +} +``` diff --git a/fairseq/examples/wmt21/README.md b/fairseq/examples/wmt21/README.md new file mode 100644 index 0000000000000000000000000000000000000000..524fffb7247bce6ab2f54e2124bcc3d680702778 --- /dev/null +++ b/fairseq/examples/wmt21/README.md @@ -0,0 +1,25 @@ +# WMT 21 + +This page provides pointers to the models of Facebook AI's WMT'21 news translation task submission [(Tran et al., 2021)](https://arxiv.org/abs/2108.03265). + +## Single best dense models + +Model | Description | Download +---|---|--- +`wmt21.dense-24-wide.X-En` | X-En | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt21.dense-24-wide.X-En.tar.gz) +`wmt21.dense-24-wide.En-X` | En-X | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt21.dense-24-wide.En-X.tar.gz) + +## Example usage + +See eval.sh + + +## Citation +```bibtex +@inproceedings{tran2021facebook + title={Facebook AI’s WMT21 News Translation Task Submission}, + author={Chau Tran and Shruti Bhosale and James Cross and Philipp Koehn and Sergey Edunov and Angela Fan}, + booktitle={Proc. of WMT}, + year={2021}, +} +``` diff --git a/fairseq/examples/wmt21/eval.sh b/fairseq/examples/wmt21/eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..b36d934c51b948a54f11a62cbc486a17431eea89 --- /dev/null +++ b/fairseq/examples/wmt21/eval.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +SRC=en +TGT=is +MODEL_NAME=wmt21.dense-24-wide.En-X + +PATH_TO_FAIRSEQ_PY=. +TMP_DIR=generation_tmp +mkdir -p $TMP_DIR + +REPLACE_UNICODE_PUNCT=$PATH_TO_FAIRSEQ_PY/examples/wmt21/scripts/replace-unicode-punctuation.perl +NORM_PUNCT=$PATH_TO_FAIRSEQ_PY/examples/wmt21/scripts/normalize-punctuation.perl +if [ ! -d "${TMP_DIR}/${MODEL_NAME}" ]; then + wget https://dl.fbaipublicfiles.com/fairseq/models/${MODEL_NAME}.tar.gz -P $TMP_DIR/ + tar -xvf $TMP_DIR/${MODEL_NAME}.tar.gz -C $TMP_DIR +fi +MODEL_DIR=$TMP_DIR/${MODEL_NAME} +if [ ! 
-d "${TMP_DIR}/wmt21-news-systems" ]; then + git clone https://github.com/wmt-conference/wmt21-news-systems $TMP_DIR/wmt21-news-systems +fi + +DOMAIN_TAG="wmtdata newsdomain" +INPUT_FILE=$TMP_DIR/wmt21-news-systems/txt/sources/newstest2021.${SRC}-${TGT}.src.${SRC} +REF_FILE=$TMP_DIR/wmt21-news-systems/txt/references/newstest2021.${SRC}-${TGT}.ref.A.${TGT} + +# Translate +cat ${INPUT_FILE} | sed "s/^/${DOMAIN_TAG} /" | $REPLACE_UNICODE_PUNCT | $NORM_PUNCT -l ${SRC} | python $PATH_TO_FAIRSEQ_PY/fairseq_cli/interactive.py $MODEL_DIR \ + --path ${MODEL_DIR}/checkpoint.pt \ + --task translation_multi_simple_epoch \ + --langs "en,ha,is,ja,cs,ru,zh,de" \ + --lang-pairs $SRC-$TGT \ + --bpe "sentencepiece" \ + --sentencepiece-model ${MODEL_DIR}/sentencepiece.model \ + --buffer-size 1024 \ + --batch-size 10 -s $SRC -t $TGT \ + --decoder-langtok \ + --encoder-langtok src \ + --beam 5 \ + --lenpen 1.0 \ + --fp16 > $TMP_DIR/${SRC}-${TGT}.gen_log + +cat $TMP_DIR/$SRC-$TGT.gen_log | grep -P "^D-" | cut -f3 > $TMP_DIR/$SRC-$TGT.hyp + +# Calculate BLEU score +sacrebleu -l $SRC-$TGT $REF_FILE < $TMP_DIR/$SRC-$TGT.hyp diff --git a/fairseq/examples/wmt21/scripts/normalize-punctuation.perl b/fairseq/examples/wmt21/scripts/normalize-punctuation.perl new file mode 100644 index 0000000000000000000000000000000000000000..a7c0750f5844ab3a10a497714ce55dcf8680cf49 --- /dev/null +++ b/fairseq/examples/wmt21/scripts/normalize-punctuation.perl @@ -0,0 +1,90 @@ +#!/usr/bin/env perl +# +# This file is part of moses. Its use is licensed under the GNU Lesser General +# Public License version 2.1 or, at your option, any later version. + +use warnings; +use strict; + +my $language = "en"; +my $PENN = 0; + +while (@ARGV) { + $_ = shift; + /^-b$/ && ($| = 1, next); # not buffered (flush each line) + /^-l$/ && ($language = shift, next); + /^[^\-]/ && ($language = $_, next); + /^-penn$/ && ($PENN = 1, next); +} + +while() { + s/\r//g; + # remove extra spaces + s/\(/ \(/g; + s/\)/\) /g; s/ +/ /g; + s/\) ([\.\!\:\?\;\,])/\)$1/g; + s/\( /\(/g; + s/ \)/\)/g; + s/(\d) \%/$1\%/g; + s/ :/:/g; + s/ ;/;/g; + # normalize unicode punctuation + if ($PENN == 0) { + s/\`/\'/g; + s/\'\'/ \" /g; + } + + s/„/\"/g; + s/“/\"/g; + s/”/\"/g; + s/–/-/g; + s/—/ - /g; s/ +/ /g; + s/´/\'/g; + s/([a-z])‘([a-z])/$1\'$2/gi; + s/([a-z])’([a-z])/$1\'$2/gi; + s/‘/\'/g; + s/‚/\'/g; + s/’/\"/g; + s/''/\"/g; + s/´´/\"/g; + s/…/.../g; + # French quotes + s/ « / \"/g; + s/« /\"/g; + s/«/\"/g; + s/ » /\" /g; + s/ »/\"/g; + s/»/\"/g; + # handle pseudo-spaces + s/ \%/\%/g; + s/nº /nº /g; + s/ :/:/g; + s/ ºC/ ºC/g; + s/ cm/ cm/g; + s/ \?/\?/g; + s/ \!/\!/g; + s/ ;/;/g; + s/, /, /g; s/ +/ /g; + + # English "quotation," followed by comma, style + if ($language eq "en") { + s/\"([,\.]+)/$1\"/g; + } + # Czech is confused + elsif ($language eq "cs" || $language eq "cz") { + } + # German/Spanish/French "quotation", followed by comma, style + else { + s/,\"/\",/g; + s/(\.+)\"(\s*[^<])/\"$1$2/g; # don't fix period at end of sentence + } + + + if ($language eq "de" || $language eq "es" || $language eq "cz" || $language eq "cs" || $language eq "fr") { + s/(\d) (\d)/$1,$2/g; + } + else { + s/(\d) (\d)/$1.$2/g; + } + print $_; +} diff --git a/fairseq/examples/wmt21/scripts/replace-unicode-punctuation.perl b/fairseq/examples/wmt21/scripts/replace-unicode-punctuation.perl new file mode 100644 index 0000000000000000000000000000000000000000..faed2cd9d86af7ad96300571b3a7b1037cc21e79 --- /dev/null +++ b/fairseq/examples/wmt21/scripts/replace-unicode-punctuation.perl @@ -0,0 +1,55 @@ 
+#!/usr/bin/env perl +# +# This file is part of moses. Its use is licensed under the GNU Lesser General +# Public License version 2.1 or, at your option, any later version. + +use warnings; +use strict; + +while (@ARGV) { + $_ = shift; + /^-b$/ && ($| = 1, next); # not buffered (flush each line) +} + +#binmode(STDIN, ":utf8"); +#binmode(STDOUT, ":utf8"); + +while() { + s/,/,/g; + s/。 */. /g; + s/、/,/g; + s/”/"/g; + s/“/"/g; + s/∶/:/g; + s/:/:/g; + s/?/\?/g; + s/《/"/g; + s/》/"/g; + s/)/\)/g; + s/!/\!/g; + s/(/\(/g; + s/;/;/g; + s/1/1/g; + s/」/"/g; + s/「/"/g; + s/0/0/g; + s/3/3/g; + s/2/2/g; + s/5/5/g; + s/6/6/g; + s/9/9/g; + s/7/7/g; + s/8/8/g; + s/4/4/g; + s/. */. /g; + s/~/\~/g; + s/’/\'/g; + s/…/\.\.\./g; + s/━/\-/g; + s/〈/\/g; + s/【/\[/g; + s/】/\]/g; + s/%/\%/g; + print $_; +} diff --git a/fairseq/examples/womens_bios/README.md b/fairseq/examples/womens_bios/README.md new file mode 100644 index 0000000000000000000000000000000000000000..07d06468870c5d124f44f87d9c27158b6210cd58 --- /dev/null +++ b/fairseq/examples/womens_bios/README.md @@ -0,0 +1,81 @@ +# Wikipedia Biographies of Women + + +## Training: + +The training dataset is created based on WikiSum, a dataset created from the paper [Generating Wikipedia by Summarizing Long Sequences](https://arxiv.org/pdf/1801.10198.pdf). The dataset needs to be generated following the instructions in this [Github Repository](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wikisum). + +### How is the WikiSum dataset structured? + +Overall, the task in WikiSum was to generate the entire Wikipedia article based on the contents of the top 10 Google Search Results. The authors provide a way for people to recreate their work. In the WikiSum Github, there are two options for the dataset recreation --- the first is to use CommonCrawl (a static, open source crawl of the web) and the second to do Live Web Fetches. The second has higher coverage, but the content is subject to change and difficult to fetch. We used the static, Commoncrawl version. This can be downloaded following the Github repo instructions, though note it will require usage of Google Cloud. + +Note: in our experience, it also requires requesting that the resource limit of the Google Cloud instance be raised, which requires emailing. + +Note: Having higher coverage in the training dataset would be expected to improve the model quality. There are many instances in the dataset where the training input (web evidence) does not contain sufficient content for producing the desired Wikipedia article. This may harm the model's ability to learn to retrieve, look at the input evidence, and overall could contribute to increased challenges in generating verifiable Wikipedia biographies. + +### How do you go from WikiSum dataset to Biography dataset? + +The WikiSum dataset is for Wikipedia in general, not just biographies. We do this by querying WikiData to see if the Wikipedia article has an occupation, with the thought that all articles with occupations are probably biographies. 
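+
+As a rough illustration of this filtering step, the occupation check can be expressed as a single WikiData `ASK` query over the P106 (occupation) property. The sketch below mirrors the SPARQL pattern used by the occupation script in this example, but `has_occupation` and the final filtering loop are hypothetical helpers, not part of the released code.
+
+```python
+# Sketch: decide whether a Wikipedia URL likely points at a biography by
+# checking WikiData for an occupation (P106) statement on the linked item.
+import sys
+from SPARQLWrapper import SPARQLWrapper, JSON
+
+ENDPOINT_URL = "https://query.wikidata.org/sparql"
+USER_AGENT = "WDQS-example Python/%s.%s" % (sys.version_info[0], sys.version_info[1])
+
+def has_occupation(wikipedia_url):
+    query = f"""
+    ASK {{
+      <{wikipedia_url}> schema:about ?item .
+      ?item wdt:P106 ?occupation .
+    }}"""
+    sparql = SPARQLWrapper(ENDPOINT_URL, agent=USER_AGENT)
+    sparql.setQuery(query)
+    sparql.setReturnFormat(JSON)
+    # ASK queries return {"head": {}, "boolean": true|false}
+    return sparql.query().convert()["boolean"]
+
+# e.g. keep only the WikiSum articles whose subject has an occupation:
+# biography_urls = [url for url in wikisum_urls if has_occupation(url)]
+```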
+ + +## Evaluation: + +You can download the dataset and baseline model with the following command: + +``` +wget -N 'https://dl.fbaipublicfiles.com/fairseq/womenbios_dataset.zip' +wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json' +wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe' +wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt' +``` + +We provide the full text Wikipedia articles split into four categories: +- Women in Africa +- Women in Asia +- Women in Science +- Women +We note that these are not exhaustive intersectional categories and mainly stem from personal interest. + +We also provide the URL of the Wikipedia article. Note that Wikipedia articles are constantly being improved, edited, and changed. Thus, it's completely possible that the Wikipedia article on Wikipedia has been lovingly improved by other Wikipedia editors. + +To get the occupations of each biographical subject, we use WikiData. We provide a sample script to do this. We also provide the raw output of this query. + +The final part of the evaluation dataset is to query web evidence for each of the biographical subjects. This is the part of the evaluation dataset that requires the most improvement. As we discuss in our paper, one of the major reasons why it is difficult to write biographies for sometimes very well qualified women is that there is not information online about them. Further, the search engine may not find it. We encourage others to improve upon this part of the data, as even re-querying again on the internet may find new, updated sources of information as the web is constantly evolving. + +We use the search engine from [Internet-Augmented Dialogue Generation](https://arxiv.org/abs/2107.07566), see [project URL](https://parl.ai/projects/sea/) to do the search queries. Note: we remove wikipedia site sources from our query (or we'd query the data itself). However, it's possible Wikipedia information can be copied around in multiple forms on the web, linked with edits, etc. + + +## Section by Section Generation: + +Wikipedia articles are split into sections, which are usually separated by headings. These headings can be separated in the article text by looking for these equal signs (==), where the number of equal signs usually signals if you are looking at a toplevel heading or a subheading, etc. An example regex that you can use is: + +` +section_header_re = re.compile(r"(? schema:about ?uri . + ?uri wdt:P106 ?occupation . 
+ SERVICE wikibase:label {{ bd:serviceParam wikibase:language "en" }} + }}""" + user_agent = "WDQS-example Python/%s.%s" % (sys.version_info[0], sys.version_info[1]) + sparql = SPARQLWrapper(endpoint_url, agent=user_agent) + sparql.setQuery(query) + sparql.setReturnFormat(JSON) + return sparql.query().convert() + +all_occupations = [] +for URL in urls: + results = get_results(endpoint_url, URL) + occupations = [] + for result in results["results"]["bindings"]: + occupations.append(result['occupationLabel']['value']) + all_occupations.append(result['uriLabel']['value'] + ", " + ", ".join(occupations)) + +assert(len(all_occupations) == len(urls)) + +with open("/your/file/output/here", "w") as o: + for line in all_occupations: + o.write(line.strip() + "\n") \ No newline at end of file diff --git a/fairseq/examples/xformers/README.md b/fairseq/examples/xformers/README.md new file mode 100644 index 0000000000000000000000000000000000000000..400a74d5366b7cbd0eab63af1901a45346fa3de8 --- /dev/null +++ b/fairseq/examples/xformers/README.md @@ -0,0 +1,43 @@ +# Using xFormers with FairSeq + +[xFormers](https://github.com/facebookresearch/xformers) is a modular library for flexibly generating transformer architectures with interoperable and optimized building blocks. +The current integration allows FairSeq users to use the attention variants available in the xFormers repository. + +In order to enable xFormers, all that needs to be passed in is a string representing an [xFormers attention config](https://github.com/facebookresearch/xformers/blob/5f754129bfb1ea53747b1ab2077261ea762faa47/xformers/components/attention/base.py#L18). + +The various attention variants can be found [here](https://github.com/facebookresearch/xformers/tree/main/xformers/components/attention). +These include sparse attention and blocksparse attention. + +For example, you could pass in the following args: + ```python +decoder_xformers_att_config = '{"name": "scaled_dot_product"}' + +encoder_xformers_att_config = '{"name": "linformer", "seq_len": "256"}' + ``` + +In order to use blocksparse attention, you would have to additionally pass in a blocksparse layout and blocksize. For example: + + ```python + + xformers_att_config = '{"name": "scaled_dot_product"}' + xformers_blocksparse_blocksize = 16 + xformers_blocksparse_layout = torch.ones( + seq_len // xformers_blocksparse_blocksize, + seq_len // xformers_blocksparse_blocksize, + ) + + # embedding, num_heads, add_zero_attn and seq_len are assumed to be defined elsewhere + xf_blocksparse_mha = ( + MultiheadAttention( + embedding, + num_heads, + dropout=0.0, + add_zero_attn=add_zero_attn, + xformers_att_config=xformers_att_config, + xformers_blocksparse_layout=xformers_blocksparse_layout, + xformers_blocksparse_blocksize=xformers_blocksparse_blocksize, + ) + ) + + ``` + +The xFormers repository currently has benchmarks on the [runtime](https://github.com/facebookresearch/xformers/blob/main/docs/plots/runtime_vs_attention.png) +and [memory usage](https://github.com/facebookresearch/xformers/blob/main/docs/plots/memory_vs_attention.png) of the various attention variants. 
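+
+For reference, here is a self-contained sketch (assuming both fairseq and xFormers are installed) that routes fairseq's `MultiheadAttention` through an xFormers attention; the dimensions and hyperparameters below are arbitrary example values, not recommended settings.
+
+```python
+import torch
+from fairseq.modules.multihead_attention import MultiheadAttention
+
+embed_dim, num_heads, seq_len, bsz = 64, 4, 32, 2
+
+# Plain scaled dot-product attention executed by the xFormers backend.
+mha = MultiheadAttention(
+    embed_dim,
+    num_heads,
+    self_attention=True,
+    xformers_att_config='{"name": "scaled_dot_product"}',
+)
+
+x = torch.rand(seq_len, bsz, embed_dim)  # fairseq attention expects (time, batch, channel)
+out, _ = mha(query=x, key=x, value=x)
+assert out.shape == (seq_len, bsz, embed_dim)
+```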
diff --git a/fairseq/examples/xglm/README.md b/fairseq/examples/xglm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..914e297669a1f92096a39d43f366309d57c52854 --- /dev/null +++ b/fairseq/examples/xglm/README.md @@ -0,0 +1,195 @@ +# Few-shot Learning with Multilingual Language Models + +## Introduction + +In this work, we train a family of multilingual generative language models, dubbed XGLM, on a balanced corpus covering a diverse set of languages, and study their few- and zero-shot learning capabilities in a wide range of tasks. Our largest model with 7.5 billion parameters sets new state of the art in few-shot learning on more than 20 representative languages, outperforming GPT-3 of comparable size in multilingual commonsense reasoning (+7.4 accuracy points for 0-shot, +9.4 for 4-shot) and natural language inference (+5.4 for 0-shot, +5.4 for 4-shot). We have included a [model card](model_card.md) of XGLM for transparency and accountability. + +## Data and Languages +XGLM models are trained on a new multilingual corpus extracted from CommonCrawl (CC100-XL), a significantly larger multilingual dataset covering 68 Common Crawl (CC) snapshots (from [Summer 2013](http://commoncrawl.org/2013/11/new-crawl-data-available/) to [March/April 2020](https://commoncrawl.org/2020/04/march-april-2020-crawl-archive-now-available/) consisting of 134 languages. The detailed languages and data statistics are reported in the paper (Table A.1). + +## Pre-trained models + +Model | Layers | Model Dim | FFN Dim | Languages | Download +---|---|---|---|---|--- +`XGLM 564M` | 24 | 1024 | 4096 | trained on 30 languages| [xglm.564M.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.564M.tar.gz) +`XGLM 1.7B` | 24 | 2048 | 8192 | trained on 30 languages| [xglm.1.7B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.1.7B.tar.gz) +`XGLM 2.9B` | 48 | 2048 | 8192 | trained on 30 languages| [xglm.2.9B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.2.9B.tar.gz) +`XGLM 7.5B` | 32 | 4096 | 16384 | trained on 30 languages| [xglm.7.5B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.7.5B.tar.gz) +`XGLM 4.5B` | 48 | 2048 | 16384 | trained on 134 languages| [xglm.4.5B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.4.5B.tar.gz) + +## Pre-training Data Format +Our models were pre-trained with data in the following format (i.e. paragraphs are separated with new lines and documents were separated with double new lines). +``` + ... # X0: number of tokens in para0 of doc0 + ... # Y0: number of tokens in para1 of doc0 + + ... # X1: number of tokens in para0 of doc1 + ... # Y1: number of tokens in para1 of doc1 + +... +``` +Fairseq's preprocessing replaces newlines with the end-of-sentence symbol (``). As a result, the models never saw newline characters during pretraining and the same preprocessing should be run prior to few-shot inference to maximize performance. For example, our language model scoring function has `replace_newlines_with_eos` argument to trigger this preprocessing: +```python +from fairseq.models.transformer_lm import TransformerLanguageModel + +model_dir = 'path_to_decompressed_tar_gz_dir' +lm = TransformerLanguageModel.from_pretrained(model_dir, bpe='sentencepiece') + +text = """First paragraph of the first document. +Second paragraph of the first document. + +First paragraph of the second document. 
+""" +tokens = lm.score(text, replace_newlines_with_eos=True)['tokens'] +assert '\n' not in lm.decode(tokens) # no newlines were encoded +``` + +## Evaluation + +### Example (COPA) + +The following snippet show how to evaluate our models on the Choice of Plausible Alternatives (COPA) task, using examples in English, Chinese and Hindi. + +```python +data_samples = { + 'en': [ + { + "premise": "I wanted to conserve energy.", + "choice1": "I swept the floor in the unoccupied room.", + "choice2": "I shut off the light in the unoccupied room.", + "question": "effect", + "label": "1" + }, + { + "premise": "The flame on the candle went out.", + "choice1": "I blew on the wick.", + "choice2": "I put a match to the wick.", + "question": "cause", + "label": "0" + } + ], + 'zh': [ + { + "premise": "我想节约能源。", + "choice1": "我在空着的房间里扫了地板。", + "choice2": "我把空房间里的灯关了。", + "question": "effect", + "label": "1" + }, + { + "premise": "蜡烛上的火焰熄灭了。", + "choice1": "我吹灭了灯芯。", + "choice2": "我把一根火柴放在灯芯上。", + "question": "cause", + "label": "0" + } + ], + 'hi': [ + { + "premise": "M te vle konsève enèji.", + "choice1": "Mwen te fin baleye chanm lib la.", + "choice2": "Mwen te femen limyè nan chanm lib la.", + "question": "effect", + "label": "1" + }, + { + "premise": "Flam bouji a te etenn.", + "choice1": "Mwen te soufle bouji a.", + "choice2": "Mwen te limen mèch bouji a.", + "question": "cause", + "label": "0" + } + ] +} +``` +In this example, we format the examples use the non-verbal prompts `{premise}\n{choice1}` and `{premise}\n{choice2}`, which are shared by all three languages. +```python +from fairseq.models.transformer_lm import TransformerLanguageModel + +model_dir = 'path_to_decompressed_tar_gz_dir' +lm = TransformerLanguageModel.from_pretrained(model_dir, bpe='sentencepiece') +lm = lm.eval() +lm = lm.half() +lm = lm.cuda() + +def get_logprobs(prompt): + import re + prompt = re.sub('\n+' , '\n', prompt) # collapse repeated newlines, which indicate separate documents + return lm.score(prompt, replace_newlines_with_eos=True)['positional_scores'] + +# Zero-shot evaluation for the Choice of Plausible Alternatives (COPA) task. +# A return value of 0 indicates that the first alternative is more plausible, +# while 1 indicates that the second alternative is more plausible. +def COPA_eval(prompt, alternative1, alternative2): + lprob1 = get_logprobs(prompt + "\n" + alternative1).sum() + lprob2 = get_logprobs(prompt + "\n" + alternative2).sum() + return 0 if lprob1 > lprob2 else 1 + +for lang in ['en', 'zh', 'hi']: + for idx, example in enumerate(data_samples[lang]): + predict = COPA_eval(example["premise"], example["choice1"], example["choice2"]) + print(f'{lang}-{idx}', predict, example['label']) + +# en-0 1 1 +# en-1 0 0 +# zh-0 1 1 +# zh-1 0 0 +# hi-0 1 1 +# hi-1 0 0 +``` + +## XStoryCloze + +We release XStoryCloze, a new multilingual dataset intended for few-shot evaluation, alongside this paper. XStoryCloze consists of professional translation of the validation split of the [English StoryCloze dataset](https://cs.rochester.edu/nlp/rocstories/) (Spring 2016 version) to 10 other languages. It is opensourced under [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode), the same license as the English StoryCloze. + +You can download the dataset via [this link](https://dl.fbaipublicfiles.com/xstorycloze.zip). 
+ +Language | ar | es | eu | hi | id | my | ru | sw | te | zh +---|---|---|---|---|---|---|---|---|---|--- +Train size | 360 | 360 | 360 | 360 | 360 | 360 | 360 | 360 | 360 | 360 +Eval size | 1511 | 1511 | 1511 | 1511 | 1511 | 1511 | 1511 | 1511 | 1511 | 1511 + +Please refer to [the dataset doc](XStoryCloze.md) for more information. + + +## Publication +[Few-shot Learning with Multilingual Generative Language Models](https://arxiv.org/abs/2112.10668). +Xi Victoria Lin*, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li* (* Equal Contribution). +EMNLP 2022. + +## Citation +``` +@article{DBLP:journals/corr/abs-2112-10668, + author = {Xi Victoria Lin and + Todor Mihaylov and + Mikel Artetxe and + Tianlu Wang and + Shuohui Chen and + Daniel Simig and + Myle Ott and + Naman Goyal and + Shruti Bhosale and + Jingfei Du and + Ramakanth Pasunuru and + Sam Shleifer and + Punit Singh Koura and + Vishrav Chaudhary and + Brian O'Horo and + Jeff Wang and + Luke Zettlemoyer and + Zornitsa Kozareva and + Mona T. Diab and + Veselin Stoyanov and + Xian Li}, + title = {Few-shot Learning with Multilingual Language Models}, + journal = {CoRR}, + volume = {abs/2112.10668}, + year = {2021}, + url = {https://arxiv.org/abs/2112.10668}, + eprinttype = {arXiv}, + eprint = {2112.10668}, + timestamp = {Tue, 04 Jan 2022 15:59:27 +0100}, + biburl = {https://dblp.org/rec/journals/corr/abs-2112-10668.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` diff --git a/fairseq/examples/xglm/XStoryCloze.md b/fairseq/examples/xglm/XStoryCloze.md new file mode 100644 index 0000000000000000000000000000000000000000..9b0fce07159ab0f4886acc206da4afd3b450f453 --- /dev/null +++ b/fairseq/examples/xglm/XStoryCloze.md @@ -0,0 +1,57 @@ +XStoryCloze consists of professional translation of the validation split of the [English StoryCloze dataset](https://cs.rochester.edu/nlp/rocstories/) (Spring 2016 version) to 10 other languages. This dataset is released by FAIR (Fundamental Artificial Intelligence Research) alongside the paper [Few-shot Learning with Multilingual Generative Language Models. EMNLP 2022](https://arxiv.org/abs/2112.10668). + +# Languages +ru, zh (Simplified), es (Latin America), ar, hi, id, te, sw, eu, my. + +# Data Splits +This dataset is intended to be used for evaluating the zero- and few-shot learning capabilities of multlingual language models. We split the data for each language into train and test (360 vs. 1510 examples, respectively). The released data files for different languages maintain a line-by-line alignment. + +# Access English StoryCloze +Please request the original English StoryCloze dataset through the [official website](https://cs.rochester.edu/nlp/rocstories/). You can create a split of the en data following our data split scheme using the following commands: +``` +head -361 spring2016.val.tsv > spring2016.val.en.tsv.split_20_80_train.tsv + +head -1 spring2016.val.tsv > spring2016.val.en.tsv.split_20_80_eval.tsv # TSV header +tail -1511 spring2016.val.tsv >> spring2016.val.en.tsv.split_20_80_eval.tsv +``` + +# Licence +XStoryCloze is opensourced under [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode), the same license as the original English StoryCloze. 
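+
+For reference, a zero-shot evaluation loop analogous to the COPA example in the XGLM README can be run on an XStoryCloze split. The sketch below assumes the released TSVs keep the English StoryCloze column layout (`InputSentence1`-`4`, `RandomFifthSentenceQuiz1/2`, `AnswerRightEnding`) and uses a hypothetical file name; adjust both if they differ in your copy of the data.
+
+```python
+import csv
+from fairseq.models.transformer_lm import TransformerLanguageModel
+
+lm = TransformerLanguageModel.from_pretrained('path_to_decompressed_tar_gz_dir', bpe='sentencepiece')
+lm = lm.eval().half().cuda()
+
+def logprob(text):
+    # Sum of per-token log-probabilities under the LM, with newlines mapped to
+    # the end-of-sentence symbol to match the pretraining preprocessing.
+    return lm.score(text, replace_newlines_with_eos=True)['positional_scores'].sum()
+
+correct = total = 0
+with open('spring2016.val.ru.tsv.split_20_80_eval.tsv') as f:  # hypothetical file name
+    for row in csv.DictReader(f, delimiter='\t'):
+        context = ' '.join(row[f'InputSentence{i}'] for i in range(1, 5))
+        lp1 = logprob(context + ' ' + row['RandomFifthSentenceQuiz1'])
+        lp2 = logprob(context + ' ' + row['RandomFifthSentenceQuiz2'])
+        pred = '1' if lp1 > lp2 else '2'
+        correct += int(pred == row['AnswerRightEnding'])
+        total += 1
+
+print(f'accuracy: {correct / total:.4f}')
+```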
+ +# Citation +We hope this dataset is helpful for the research and wider NLP community. If you use XStoryCloze in your work, please cite +``` +@article{DBLP:journals/corr/abs-2112-10668, + author = {Xi Victoria Lin and + Todor Mihaylov and + Mikel Artetxe and + Tianlu Wang and + Shuohui Chen and + Daniel Simig and + Myle Ott and + Naman Goyal and + Shruti Bhosale and + Jingfei Du and + Ramakanth Pasunuru and + Sam Shleifer and + Punit Singh Koura and + Vishrav Chaudhary and + Brian O'Horo and + Jeff Wang and + Luke Zettlemoyer and + Zornitsa Kozareva and + Mona T. Diab and + Veselin Stoyanov and + Xian Li}, + title = {Few-shot Learning with Multilingual Language Models}, + journal = {CoRR}, + volume = {abs/2112.10668}, + year = {2021}, + url = {https://arxiv.org/abs/2112.10668}, + eprinttype = {arXiv}, + eprint = {2112.10668}, + timestamp = {Tue, 04 Jan 2022 15:59:27 +0100}, + biburl = {https://dblp.org/rec/journals/corr/abs-2112-10668.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` diff --git a/fairseq/examples/xglm/model_card.md b/fairseq/examples/xglm/model_card.md new file mode 100644 index 0000000000000000000000000000000000000000..2656ec5d63ea288df744dbe2e1d77e52898183c2 --- /dev/null +++ b/fairseq/examples/xglm/model_card.md @@ -0,0 +1,152 @@ +# XGLM multilingual model +## Version 1.0.0 + +### Model developer +FAIR (Fundamental Artificial Intelligence Research) + +### Model type +A family of multilingual autoregressive language models (ranging from 564 million to 7.5 billion parameters) trained on a balanced corpus of a diverse set of languages. The language model can learn tasks from natural language descriptions and a few examples. + +### Model Feedback Channel +https://github.com/pytorch/fairseq + +## Intended use +### Primary intended use +For research purposes only, e.g. reproducing model evaluation results. Generation is only used in a limited capacity for explanation/justification or for prompting/probing/priming for class labels. + +### Out of scope uses +The primary purpose of the model is not to generate language, although the model is capable of doing that. + +## Potential risks +This section lists the potential risks associated with using the model. + +### Relevant factors +Based on known problems with NLP technology, potential relevant factors include output correctness, robustness, bias (gender, profession, race and religion), etc. + +### Evaluation factors +The model was evaluated on hate speech detection and occupation identification. +* Hate speech detection (Huang et al. (2020)) - A safety task to test language models’ ability to identify hateful and offensive text. +* Occupation identification (De-Arteaga et al., 2019), (Zhao et al., 2020) - A bias task to study language models’ performance divergence between different gender groups on the task of occupation identification. + +## Metrics +### Model performance measures +The XGLM model was primarily evaluated on +1. Zero shot and few shot learning by looking at per-language performance on tasks spanning commonsense reasoning (XCOPA, XWinograd), natural language inference (XNLI) and paraphrasing (PAWS-X). The model is also evaluated on XStoryCloze, a new dataset created by FAIR (Fundamental Artificial Intelligence Research). +2. Cross lingual transfer through templates and few-shot examples. +3. Knowledge probing - Evaluate to what extent the XGLM model can effectively store factual knowledge in different languages using the mLAMA benchmark. +4. 
Translation - We report machine translation results on WMT benchmarks and a subset of FLORES-101 in the main paper. + +The model was also evaluated on hate speech datasets introduced by Huang et al. (2020) and an occupation identification dataset by De-Arteaga et al. 2019 to identify bias in the model. + +### Approaches to handle uncertainty +Report confidence intervals, variance metrics for the model performance metrics. Few-shot evaluation was conducted with different sampling with 5 seeds. We reported statistical significance. + +## Evaluation data +## Zero Shot and Few Shot evaluation + +### XNLI (Conneau et al., 2018) +#### Description +The Cross-lingual Natural Language Inference (XNLI) corpus is the extension of the Multi-Genre NLI (MultiNLI) corpus to 15 languages. The dataset was created by manually translating the validation and test sets of MultiNLI into each of those 15 languages. + +### XStoryCloze +#### Description +A new dataset created by FAIR along side this work by translating the validation split of the English StoryCloze dataset (Mostafazadeh et al., 2016) (Spring 2016 version) to 10 other typologically diverse languages (ru, zh Simplified, es Latin America, ar, hi, id, te, sw, eu, my). + +### XCOPA (Ponti et al., 2020) +#### Description +The Cross-lingual Choice of Plausible Alternatives (XCOPA) dataset is a benchmark to evaluate the ability of machine learning models to transfer commonsense reasoning across languages. The dataset is the translation and reannotation of the English COPA (Roemmele et al. 2011) and covers 11 languages from 11 families and several areas around the globe. + +### XWinograd (Tikhonov and Ryabinin, 2021) +#### Description +XWinograd is a multilingual collection of Winograd Schemas in six languages that can be used for evaluation of cross-lingual commonsense reasoning capabilities. + +### PAWS-X (Yang et al., 2019) +#### Description +PAWS-X contains 23,659 human translated PAWS evaluation pairs and 296,406 machine translated training pairs in six typologically distinct languages: French, Spanish, German, Chinese, Japanese, and Korean. All translated pairs are sourced from examples in PAWS-Wiki. + +## Responsible AI (RAI) evaluation +### Hate speech (Huang et al. 2020) +This is a multilingual Twitter corpus for the task of hate speech detection with inferred four author demographic factors: age, country, gender and race/ethnicity. The corpus covers five languages: English, Italian, Polish, Portuguese and Spanish. + +### Bias dataset (De-Arteaga et al. 2019) +The aim of this dataset is to study the gender bias of models that identify a person’s occupation from their bios. + +---- + +## Training data +### CC100-XL +#### Description +Following the recent success of multilingual self-supervised pre-training (Devlin et al., 2019; Lample and Conneau, 2019; Con; Xue et al., 2020; Goyal et al., 2021a; Liu et al., 2020), we train our language models on a mixture of monolingual text of different languages. We extended the pipeline used for mining the CC100 corpus to generate CC100-XL, a significantly larger multilingual dataset covering 68 Common Crawl snapshots (from Summer 2013 to March/April 2020) and 134 languages. + +More details on the CC100-XL dataset can be found in the Appendix section of the paper. + +## RAI Dimensions +### Fairness (Bias and inclusion) +The XGLM model was evaluated on Hate speech and bias identification datasets. 
For hate speech, we observe that across the 5 languages in the dataset, in context learning results are only slightly better than random (50%). Another interesting observation is that most few shot results are worse than zero-shot, which indicates that the model is not able to utilize examples using the templates described in the paper. For bias identification, the XGLM (6.7B) English only model achieves the best performance on English and Spanish, while the GPT-3 model of comparable size (6.7B) model achieves the best in French. On certain occupations (e.g. model and teacher), XGLM 6.7B En only model and GPT-3 (6.7B) have very significant bias while XGLM 7.5B is much less biased. + +### Privacy and security +The XGLM model did not have any special Privacy and Security considerations. The training data and evaluation data were both public and went through standard Meta privacy and licensing procedures. + +### Transparency and control +In the spirit of transparency and accountability we have created this model card and a data card for the CC100-XL which can be found in the Appendix section of the paper. + +### Efficiency (Green AI) +From an engineering perspective, XGLM pertains to a family of models that represent single unified models catering to many languages which have wide application across many applications. Such a unified single model saves on carbon footprint as well as energy consumption (comparing to the alternative: separate models for different languages) leading to more energy efficiency. A single model, despite having the risk of being a single point of failure, has the powerful incentive of being easier to maintain, access, distribute, and track. + +## References +Edoardo Maria Ponti, Goran Glavas, Olga Majewska, Qianchu Liu, Ivan Vulic, and Anna Korhonen. 2020. XCOPA: A multilingual dataset for causal commonsense reasoning. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing, EMNLP 2020, Online, November 16-20, 2020, pages 2362–2376. Association for Computational Linguistics. +XCOPA Dataset | Papers With Code + +Alexey Tikhonov and Max Ryabinin. 2021. It’s all in the heads: Using attention heads as a baseline for cross-lingual transfer in commonsense reasoning. In Findings of the Association for Computational Linguistics: ACL/IJCNLP 2021, Online Event, August 1-6, 2021, volume ACL/IJCNLP 2021 of Findings of ACL, pages 3534–3546. Association for Computational Linguistics. +XWINO Dataset | Papers With Code (XWinograd) + +Yinfei Yang, Yuan Zhang, Chris Tar, and Jason Baldridge. 2019. PAWS-X: A cross-lingual adversarial dataset for paraphrase identification. CoRR, abs/1908.11828. +PAWS-X Dataset | Papers With Code + +Alexis Conneau, Guillaume Lample, Ruty Rinott, Adina Williams, Samuel R. Bowman, Holger Schwenk, and Veselin Stoyanov. 2018. XNLI: evaluating cross-lingual sentence representations. CoRR, abs/1809.05053. +XNLI Dataset | Papers With Code + +Xiaolei Huang, Linzi Xing, Franck Dernoncourt, and Michael Paul. 2020. Multilingual twitter corpus and baselines for evaluating demographic bias in hate speech recognition. In Proceedings of the 12th Language Resources and Evaluation Conference, pages 1440–1448. + +Maria De-Arteaga, Alexey Romanov, Hanna Wallach, Jennifer Chayes, Christian Borgs, Alexandra Chouldechova, Sahin Geyik, Krishnaram Kenthapadi, and Adam Tauman Kalai. 2019. Bias in bios: A case study of semantic representation bias in a high-stakes setting. 
In proceedings of the Conference on Fairness, Accountability, and Transparency, pages 120–128. + +Nasrin Mostafazadeh, Nathanael Chambers, Xiaodong He, Devi Parikh, Dhruv Batra, Lucy Vanderwende, Pushmeet Kohli, James F. Allen. A Corpus and Evaluation Framework for Deeper Understanding of Commonsense Stories. CoRR abs/1604.01696. + +Jieyu Zhao, Subhabrata Mukherjee, Saghar Hosseini, Kai-Wei Chang, and Ahmed Hassan Awadallah. 2020. Gender bias in multilingual embeddings and crosslingual transfer. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 2896–2907. + +## Citation details +``` +@article{DBLP:journals/corr/abs-2112-10668, + author = {Xi Victoria Lin and + Todor Mihaylov and + Mikel Artetxe and + Tianlu Wang and + Shuohui Chen and + Daniel Simig and + Myle Ott and + Naman Goyal and + Shruti Bhosale and + Jingfei Du and + Ramakanth Pasunuru and + Sam Shleifer and + Punit Singh Koura and + Vishrav Chaudhary and + Brian O'Horo and + Jeff Wang and + Luke Zettlemoyer and + Zornitsa Kozareva and + Mona T. Diab and + Veselin Stoyanov and + Xian Li}, + title = {Few-shot Learning with Multilingual Language Models}, + journal = {CoRR}, + volume = {abs/2112.10668}, + year = {2021}, + url = {https://arxiv.org/abs/2112.10668}, + eprinttype = {arXiv}, + eprint = {2112.10668}, + timestamp = {Tue, 04 Jan 2022 15:59:27 +0100}, + biburl = {https://dblp.org/rec/journals/corr/abs-2112-10668.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` diff --git a/fairseq/examples/xlmr/README.md b/fairseq/examples/xlmr/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bba7910e303e5bc18b570b8a3877db59e1762d43 --- /dev/null +++ b/fairseq/examples/xlmr/README.md @@ -0,0 +1,144 @@ +# Unsupervised Cross-lingual Representation Learning at Scale (XLM-RoBERTa) +https://arxiv.org/pdf/1911.02116.pdf + +# Larger-Scale Transformers for Multilingual Masked Language Modeling +https://arxiv.org/pdf/2105.00572.pdf + + +## What's New: +- June 2021: `XLMR-XL` AND `XLMR-XXL` models released. + +## Introduction + +`XLM-R` (`XLM-RoBERTa`) is a generic cross lingual sentence encoder that obtains state-of-the-art results on many cross-lingual understanding (XLU) benchmarks. It is trained on `2.5T` of filtered CommonCrawl data in 100 languages (list below). 
+ + Language | Language|Language |Language | Language +---|---|---|---|--- +Afrikaans | Albanian | Amharic | Arabic | Armenian +Assamese | Azerbaijani | Basque | Belarusian | Bengali +Bengali Romanize | Bosnian | Breton | Bulgarian | Burmese +Burmese zawgyi font | Catalan | Chinese (Simplified) | Chinese (Traditional) | Croatian +Czech | Danish | Dutch | English | Esperanto +Estonian | Filipino | Finnish | French | Galician +Georgian | German | Greek | Gujarati | Hausa +Hebrew | Hindi | Hindi Romanize | Hungarian | Icelandic +Indonesian | Irish | Italian | Japanese | Javanese +Kannada | Kazakh | Khmer | Korean | Kurdish (Kurmanji) +Kyrgyz | Lao | Latin | Latvian | Lithuanian +Macedonian | Malagasy | Malay | Malayalam | Marathi +Mongolian | Nepali | Norwegian | Oriya | Oromo +Pashto | Persian | Polish | Portuguese | Punjabi +Romanian | Russian | Sanskrit | Scottish Gaelic | Serbian +Sindhi | Sinhala | Slovak | Slovenian | Somali +Spanish | Sundanese | Swahili | Swedish | Tamil +Tamil Romanize | Telugu | Telugu Romanize | Thai | Turkish +Ukrainian | Urdu | Urdu Romanize | Uyghur | Uzbek +Vietnamese | Welsh | Western Frisian | Xhosa | Yiddish + +## Pre-trained models + +Model | Description | #params | vocab size | Download +---|---|---|---|--- +`xlmr.base` | XLM-R using the BERT-base architecture | 250M | 250k | [xlm.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz) +`xlmr.large` | XLM-R using the BERT-large architecture | 560M | 250k | [xlm.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz) +`xlmr.xl` | XLM-R (`layers=36, model_dim=2560`) | 3.5B | 250k | [xlm.xl.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xl.tar.gz) +`xlmr.xxl` | XLM-R (`layers=48, model_dim=4096`) | 10.7B | 250k | [xlm.xxl.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xxl.tar.gz) + +## Results + +**[XNLI (Conneau et al., 2018)](https://arxiv.org/abs/1809.05053)** + +Model | average | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur +---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|--- +`roberta.large.mnli` _(TRANSLATE-TEST)_ | 77.8 | 91.3 | 82.9 | 84.3 | 81.2 | 81.7 | 83.1 | 78.3 | 76.8 | 76.6 | 74.2 | 74.1 | 77.5 | 70.9 | 66.7 | 66.8 +`xlmr.large` _(TRANSLATE-TRAIN-ALL)_ | 83.6 | 89.1 | 85.1 | 86.6 | 85.7 | 85.3 | 85.9 | 83.5 | 83.2 | 83.1 | 83.7 | 81.5 | 83.7 | 81.6 | 78.0 | 78.1 +`xlmr.xl` _(TRANSLATE-TRAIN-ALL)_ | 85.4 | 91.1 | 87.2 | 88.1 | 87.0 | 87.4 | 87.8 | 85.3 | 85.2 | 85.3 | 86.2 | 83.8 | 85.3 | 83.1 | 79.8 | 78.2 | 85.4 +`xlmr.xxl` _(TRANSLATE-TRAIN-ALL)_ | 86.0 | 91.5 | 87.6 | 88.7 | 87.8 | 87.4 | 88.2 | 85.6 | 85.1 | 85.8 | 86.3 | 83.9 | 85.6 | 84.6 | 81.7 | 80.6 + +**[MLQA (Lewis et al., 2018)](https://arxiv.org/abs/1910.07475)** + +Model | average | en | es | de | ar | hi | vi | zh +---|---|---|---|---|---|---|---|--- +`BERT-large` | - | 80.2/67.4 | - | - | - | - | - | - +`mBERT` | 57.7 / 41.6 | 77.7 / 65.2 | 64.3 / 46.6 | 57.9 / 44.3 | 45.7 / 29.8| 43.8 / 29.7 | 57.1 / 38.6 | 57.5 / 37.3 +`xlmr.large` | 70.7 / 52.7 | 80.6 / 67.8 | 74.1 / 56.0 | 68.5 / 53.6 | 63.1 / 43.5 | 69.2 / 51.6 | 71.3 / 50.9 | 68.0 / 45.4 +`xlmr.xl` | 73.4 / 55.3 | 85.1 / 72.6 | 66.7 / 46.2 | 70.5 / 55.5 | 74.3 / 56.9 | 72.2 / 54.7 | 74.4 / 52.9 | 70.9 / 48.5 +`xlmr.xxl` | 74.8 / 56.6 | 85.5 / 72.4 | 68.6 / 48.4 | 72.7 / 57.8 | 75.4 / 57.6 | 73.7 / 55.8 | 76.0 / 55.0 | 71.7 / 48.9 + + +## Example usage + +##### Load XLM-R from torch.hub (PyTorch >= 1.1): +```python +import torch +xlmr = 
torch.hub.load('pytorch/fairseq:main', 'xlmr.large') +xlmr.eval() # disable dropout (or leave in train mode to finetune) +``` + +##### Load XLM-R (for PyTorch 1.0 or custom models): +```python +# Download xlmr.large model +wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz +tar -xzvf xlmr.large.tar.gz + +# Load the model in fairseq +from fairseq.models.roberta import XLMRModel +xlmr = XLMRModel.from_pretrained('/path/to/xlmr.large', checkpoint_file='model.pt') +xlmr.eval() # disable dropout (or leave in train mode to finetune) +``` + +##### Apply sentence-piece-model (SPM) encoding to input text: +```python +en_tokens = xlmr.encode('Hello world!') +assert en_tokens.tolist() == [0, 35378, 8999, 38, 2] +xlmr.decode(en_tokens) # 'Hello world!' + +zh_tokens = xlmr.encode('你好,世界') +assert zh_tokens.tolist() == [0, 6, 124084, 4, 3221, 2] +xlmr.decode(zh_tokens) # '你好,世界' + +hi_tokens = xlmr.encode('नमस्ते दुनिया') +assert hi_tokens.tolist() == [0, 68700, 97883, 29405, 2] +xlmr.decode(hi_tokens) # 'नमस्ते दुनिया' + +ar_tokens = xlmr.encode('مرحبا بالعالم') +assert ar_tokens.tolist() == [0, 665, 193478, 258, 1705, 77796, 2] +xlmr.decode(ar_tokens) # 'مرحبا بالعالم' + +fr_tokens = xlmr.encode('Bonjour le monde') +assert fr_tokens.tolist() == [0, 84602, 95, 11146, 2] +xlmr.decode(fr_tokens) # 'Bonjour le monde' +``` + +##### Extract features from XLM-R: +```python +# Extract the last layer's features +last_layer_features = xlmr.extract_features(zh_tokens) +assert last_layer_features.size() == torch.Size([1, 6, 1024]) + +# Extract all layer's features (layer 0 is the embedding layer) +all_layers = xlmr.extract_features(zh_tokens, return_all_hiddens=True) +assert len(all_layers) == 25 +assert torch.all(all_layers[-1] == last_layer_features) +``` + +## Citation + +```bibtex +@article{conneau2019unsupervised, + title={Unsupervised Cross-lingual Representation Learning at Scale}, + author={Conneau, Alexis and Khandelwal, Kartikay and Goyal, Naman and Chaudhary, Vishrav and Wenzek, Guillaume and Guzm{\'a}n, Francisco and Grave, Edouard and Ott, Myle and Zettlemoyer, Luke and Stoyanov, Veselin}, + journal={arXiv preprint arXiv:1911.02116}, + year={2019} +} +``` + + +```bibtex +@article{goyal2021larger, + title={Larger-Scale Transformers for Multilingual Masked Language Modeling}, + author={Goyal, Naman and Du, Jingfei and Ott, Myle and Anantharaman, Giri and Conneau, Alexis}, + journal={arXiv preprint arXiv:2105.00572}, + year={2021} +} +``` diff --git a/fairseq/examples/xmod/README.md b/fairseq/examples/xmod/README.md new file mode 100644 index 0000000000000000000000000000000000000000..46958b81410c1cd2f9ca1ddb044467d859700c63 --- /dev/null +++ b/fairseq/examples/xmod/README.md @@ -0,0 +1,151 @@ +# X-MOD: Lifting the Curse of Multilinguality by Pre-training Modular Transformers + +https://arxiv.org/abs/2205.06266 + + +## Introduction + +X-MOD extends multilingual masked language models like XLM-R to include language-specific modular components, introduced at each transformer layer. Each module is only used by one language. For fine-tuning, the modular components are frozen, and replaced with the target language in cross-lingual transfer settings. 
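To make the modular design concrete, the sketch below shows per-language bottleneck adapters routed by a language ID on top of a shared block. This is an illustrative toy example, not the actual X-MOD implementation in fairseq: the dimensions, module names, and placement of layer norms are assumptions.

```python
import torch
import torch.nn as nn


class LanguageAdapter(nn.Module):
    """Small bottleneck module owned by exactly one language."""

    def __init__(self, dim: int, bottleneck: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.down = nn.Linear(dim, bottleneck)
        self.up = nn.Linear(bottleneck, dim)

    def forward(self, x):
        # residual bottleneck: x + up(relu(down(norm(x))))
        return x + self.up(torch.relu(self.down(self.norm(x))))


class ModularFeedForward(nn.Module):
    """Shared feed-forward block followed by one adapter per language."""

    def __init__(self, dim: int, bottleneck: int, langs):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        self.adapters = nn.ModuleDict(
            {lang: LanguageAdapter(dim, bottleneck) for lang in langs}
        )

    def forward(self, x, lang: str):
        # Only the module of the requested language is active in the forward pass;
        # for cross-lingual transfer the adapters are frozen and `lang` is swapped.
        return self.adapters[lang](x + self.shared(x))


layer = ModularFeedForward(dim=768, bottleneck=96, langs=["en_XX", "de_DE", "es_XX"])
hidden = torch.randn(2, 16, 768)       # (batch, sequence, dim)
out = layer(hidden, lang="de_DE")      # routes through the German module only
```

In the pre-trained checkpoints listed below, the analogous per-language modules live inside every transformer layer, and the language passed at inference time (see the `lang_id` argument in step 4) selects which module is active.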
+ + +## Pre-trained models + +Model | Size | # train steps | # langs | Download +---|---|---|---|--- +`xmod.base.13.125k` | BERT-base | 125k | 13 | [xmod.base.13.125k.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.13.125k.tar.gz) +`xmod.base.30.125k` | BERT-base | 125k | 30 | [xmod.base.30.125k.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.30.125k.tar.gz) +`xmod.base.30.195k` | BERT-base | 195k | 30 | [xmod.base.30.195k.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.30.195k.tar.gz) +`xmod.base.60.125k` | BERT-base | 125k | 60 | [xmod.base.60.125k.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.60.125k.tar.gz) +`xmod.base.60.265k` | BERT-base | 265k | 60 | [xmod.base.60.265k.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.60.265k.tar.gz) +`xmod.base.75.125k` | BERT-base | 125k | 75 | [xmod.base.75.125k.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.75.125k.tar.gz) +`xmod.base.75.269k` | BERT-base | 269k | 75 | [xmod.base.75.269k.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.75.269k.tar.gz) +`xmod.base` | BERT-base | 1M | 81 | [xmod.base.81.1M.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.81.1M.tar.gz) +`xmod.large.prenorm` | BERT-large | 500k | 81 | [xmod.large.prenorm.81.500k.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.large.prenorm.81.500k.tar.gz) + + +## Fine-tuning on NLI + +We next provide an example of how to fine-tune the pre-trained models above on Natural Language Inference (NLI). We use MNLI for training in English, and show how to run inference in other languages. + +### 1) Download a pre-trained model + +```bash +MODEL=xmod.base.81.1M +wget https://dl.fbaipublicfiles.com/fairseq/models/xmod/$MODEL.tar.gz +tar -xzf $MODEL.tar.gz +``` + +### 2) Download and preprocess [MNLI](https://cims.nyu.edu/~sbowman/multinli/) +```bash +wget https://cims.nyu.edu/~sbowman/multinli/multinli_1.0.zip +unzip multinli_1.0.zip +python ./examples/xmod/preprocess_nli.py \ + --sentencepiece-model $MODEL/sentencepiece.bpe.model \ + --train multinli_1.0/multinli_1.0_train.jsonl \ + --valid multinli_1.0/multinli_1.0_dev_matched.jsonl \ + --destdir multinli_1.0/fairseq +``` + +### 3) Fine-tune on MNLI: + +```bash +MAX_EPOCH=5 +LR=1e-05 +BATCH_SIZE=32 +DATA_DIR=multinli_1.0/fairseq/bin + +CUDA_VISIBLE_DEVICES=0 fairseq-train $DATA_DIR \ + --restore-file $MODEL/model.pt \ + --save-dir $MODEL/nli \ + --reset-optimizer \ + --reset-dataloader \ + --reset-meters \ + --best-checkpoint-metric accuracy \ + --maximize-best-checkpoint-metric \ + --task sentence_prediction_adapters \ + --num-classes 3 \ + --init-token 0 \ + --separator-token 2 \ + --max-positions 512 \ + --shorten-method "truncate" \ + --arch xmod_base \ + --dropout 0.1 \ + --attention-dropout 0.1 \ + --weight-decay 0.01 \ + --criterion sentence_prediction_adapters \ + --optimizer adam \ + --adam-betas '(0.9, 0.98)' \ + --adam-eps 1e-06 \ + --clip-norm 0.0 \ + --lr-scheduler fixed \ + --lr $LR \ + --fp16 \ + --fp16-init-scale 4 \ + --threshold-loss-scale 1 \ + --fp16-scale-window 128 \ + --batch-size $BATCH_SIZE \ + --required-batch-size-multiple 1 \ + --update-freq 1 \ + --max-epoch $MAX_EPOCH +``` + +### 4) Run inference + +After training the model, we can load it and run inference in our target language. The default language is set to English, which is why we were not required to pass a language ID to the model during fine-tuning. 
To run inference in a non-English language, we need to tell the model that the module of the target language should be used instead: + +```python +from fairseq.models.xmod import XMODModel + +MODEL='xmod.base.81.1M/nli' +DATA='multinli_1.0/fairseq/bin' + +# Load model +model = XMODModel.from_pretrained( + model_name_or_path=MODEL, + checkpoint_file='checkpoint_best.pt', + data_name_or_path=DATA, + suffix='', + criterion='cross_entropy', + bpe='sentencepiece', + sentencepiece_model=DATA+'/input0/sentencepiece.bpe.model') +model = model.eval(); # disable dropout +model = model.half(); # use FP16 +model = model.cuda(); # move to GPU + +def predict(premise, hypothesis, lang): + tokens = model.encode(premise, hypothesis) + idx = model.predict('sentence_classification_head', tokens, lang_id=[lang]).argmax().item() + dictionary = model.task.label_dictionary + return dictionary[idx + dictionary.nspecial] + +predict( + premise='X-Mod hat spezifische Module die für jede Sprache existieren.', + hypothesis='X-Mod hat Module.', + lang='de_DE' +) # entailment + +predict( + premise='Londres es la capital del Reino Unido.', + hypothesis='Londres está en Francia.', + lang='es_XX', +) # contradiction + +predict( + premise='Patxik gogoko ditu babarrunak.', + hypothesis='Patxik babarrunak bazkaldu zituen.', + lang='eu_ES', +) # neutral +``` + + +## Citation + +```bibtex +@misc{pfeiffer2022xmod, + doi = {10.48550/ARXIV.2205.06266}, + url = {https://arxiv.org/abs/2205.06266}, + title = {Lifting the Curse of Multilinguality by Pre-training Modular Transformers}, + publisher = {arXiv}, + year = {2022}, +} +``` diff --git a/fairseq/examples/xmod/preprocess_nli.py b/fairseq/examples/xmod/preprocess_nli.py new file mode 100644 index 0000000000000000000000000000000000000000..e1fb91c5d3a18a7a5a964d77cc63ca4f4914f2f2 --- /dev/null +++ b/fairseq/examples/xmod/preprocess_nli.py @@ -0,0 +1,168 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
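# This script converts NLI data in jsonl format into fairseq training data:
# it drops examples without a gold label, writes raw input0/input1/label columns,
# tokenizes the two input columns with the given SentencePiece model, and
# binarizes everything via fairseq_cli.preprocess into <destdir>/bin (together
# with a copy of the SentencePiece model) for sentence-prediction fine-tuning.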
+ +import os +import json +import collections +import argparse +import shutil +import subprocess +import sys +import tempfile +from multiprocessing import Pool +import sentencepiece as spm + + +def preprocess(spm_model_path, train_path, valid_path, test_path, dest_dir, remove_empty=False, output_format='piece', workers=20): + with tempfile.TemporaryDirectory() as tmp: + # Tokenize with SentencePiece + for split, path in ('train', train_path), ('valid', valid_path), ('test', test_path): + if path is None: + continue + if path == '-': + path = sys.stdin.fileno() + with open(path, encoding='utf-8', errors='surrogateescape') as fin: + with open(f'{tmp}/{split}', mode='w', encoding='utf-8', errors='surrogateescape') as fout: + encoder = MultiprocessingEncoder(model=spm_model_path, remove_empty=remove_empty, output_format=output_format) + pool = Pool(workers, initializer=encoder.initializer) + encoded_lines = pool.imap(encoder.encode, fin, 10000) + for i, line in enumerate(encoded_lines, start=1): + if line is not None: + print(line, file=fout) + if i % 10000 == 0: + print("tokenized {} lines".format(i), file=sys.stderr) + + # Generate dictionary + sp = spm.SentencePieceProcessor(model_file=spm_model_path) + if output_format == 'piece': + vocab = [sp.id_to_piece(i) for i in range(3, sp.vocab_size())] + else: + vocab = map(str, range(sp.vocab_size())) + with open(f'{tmp}/dict.txt', mode='w', encoding='utf-8', errors='surrogateescape') as f: + for word in vocab: + print(word, 1, file=f) + + # Binarize + command = [ + 'python3', '-m', 'fairseq_cli.preprocess', + '--only-source', + '--thresholdsrc', '0', + '--destdir', dest_dir, + '--srcdict', f'{tmp}/dict.txt', + '--workers', '20', + ] + for split, path in ('train', train_path), ('valid', valid_path), ('test', test_path): + if path is not None: + command += [f'--{split}pref', f'{tmp}/{split}'] + subprocess.run(command) + + # Copy SentencePiece model + shutil.copyfile(spm_model_path, f'{dest_dir}/sentencepiece.bpe.model') + + +class MultiprocessingEncoder(object): + def __init__(self, model, remove_empty, output_format): + self.model = model + self.remove_empty = remove_empty + self.output_format = output_format + + def initializer(self): + global sp + sp = spm.SentencePieceProcessor(model_file=self.model) + + def encode(self, line): + global sp + line = line.strip() + if len(line) == 0 and self.remove_empty: + return None + + if self.output_format == 'piece': + return ' '.join(sp.encode_as_pieces(line)) + else: + return ' '.join(map(str, sp.encode(line))) + + +def write_lines(lines, path): + with open(path, mode='x', encoding='utf-8') as f: + for line in lines: + print(line, file=f) + + +def read_jsonl(path): + with open(path, encoding='utf-8') as f: + return [json.loads(line) for line in f.read().splitlines()] + + +def read_nli(path, langs=None): + data = read_jsonl(path) + + if langs is not None: + data = [sample for sample in data if sample.get('language') in langs] + + lang2count = collections.defaultdict(int) + for sample in data: + lang2count[sample.get('language')] += 1 + + if langs: + assert set(lang2count.keys()) == set(langs) + + nlangs = len(lang2count) + assert nlangs > 0 + lens = list(lang2count.values()) + assert all([lens[0] == length for length in lens]) + + print(f'Loaded {lens[0]} samples in {nlangs} languages from {path}', file=sys.stderr) + return data + + +def main(): + parser = argparse.ArgumentParser(description='Tokenize and binarize NLI data') + parser.add_argument('--sentencepiece-model', required=True) + 
parser.add_argument('--train', required=True, help='Training data in jsonl format') + parser.add_argument('--valid', required=True, help='Validation data in jsonl format') + parser.add_argument('--destdir', required=True) + + args = parser.parse_args() + + os.makedirs(args.destdir + '/raw',) + os.makedirs(args.destdir + '/bin', ) + + # Extract input/labels + for split, path in ('train', args.train), ('valid', args.valid): + data = read_nli(path, langs=None) + original_size = len(data) + data = [sample for sample in data if sample['gold_label'] != '-'] + assert all(sample['gold_label'] in ('contradiction', 'entailment', 'neutral') for sample in data) + filtered_size = len(data) + if filtered_size != original_size: + print(f'Filtered {filtered_size}/{original_size} samples from {path}', file=sys.stderr) + for name, field in ('input0', 'sentence1'), ('input1', 'sentence2'), ('label', 'gold_label'): + write_lines([sample[field] for sample in data], f'{args.destdir}/raw/{split}.{name}.txt') + + # Tokenize and binarize input + for field in 'input0', 'input1': + preprocess( + spm_model_path=args.sentencepiece_model, + train_path=f'{args.destdir}/raw/train.{field}.txt', + valid_path=f'{args.destdir}/raw/valid.{field}.txt', + test_path=None, + dest_dir=f'{args.destdir}/bin/{field}', + workers=20, + ) + + # Binarize labels + subprocess.run([ + 'python3', '-m', 'fairseq_cli.preprocess', + '--trainpref', f'{args.destdir}/raw/train.label.txt', + '--validpref', f'{args.destdir}/raw/valid.label.txt', + '--only-source', + '--thresholdsrc', '0', + '--destdir', f'{args.destdir}/bin/label', + '--workers', '20', + ]) + + +if __name__ == '__main__': + main() diff --git a/fairseq/fairseq/__init__.py b/fairseq/fairseq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..080c988b2da326c2fe356630d5641d367b37a546 --- /dev/null +++ b/fairseq/fairseq/__init__.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import os +import sys + +try: + from .version import __version__ # noqa +except ImportError: + version_txt = os.path.join(os.path.dirname(__file__), "version.txt") + with open(version_txt) as f: + __version__ = f.read().strip() + +__all__ = ["pdb"] + +# backwards compatibility to support `from fairseq.X import Y` +from fairseq.distributed import utils as distributed_utils +from fairseq.logging import meters, metrics, progress_bar # noqa + +sys.modules["fairseq.distributed_utils"] = distributed_utils +sys.modules["fairseq.meters"] = meters +sys.modules["fairseq.metrics"] = metrics +sys.modules["fairseq.progress_bar"] = progress_bar + +# initialize hydra +from fairseq.dataclass.initialize import hydra_init + +hydra_init() + +import fairseq.criterions # noqa +import fairseq.distributed # noqa +import fairseq.models # noqa +import fairseq.modules # noqa +import fairseq.optim # noqa +import fairseq.optim.lr_scheduler # noqa +import fairseq.pdb # noqa +import fairseq.scoring # noqa +import fairseq.tasks # noqa +import fairseq.token_generation_constraints # noqa + +import fairseq.benchmark # noqa +import fairseq.model_parallel # noqa diff --git a/fairseq/fairseq/binarizer.py b/fairseq/fairseq/binarizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6f03d7a2cbb16db6aa218713211c1323adbc7d45 --- /dev/null +++ b/fairseq/fairseq/binarizer.py @@ -0,0 +1,381 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import typing as tp +from abc import ABC, abstractmethod +from collections import Counter +from dataclasses import dataclass +from multiprocessing import Pool + +import torch + +from fairseq.data import Dictionary, indexed_dataset +from fairseq.file_chunker_utils import Chunker, find_offsets +from fairseq.file_io import PathManager +from fairseq.tokenizer import tokenize_line + +logger = logging.getLogger("binarizer") + + +@dataclass +class BinarizeSummary: + """ + Keep track of what's going on in the binarizer + """ + + num_seq: int = 0 + replaced: tp.Optional[Counter] = None + num_tok: int = 0 + + @property + def num_replaced(self) -> int: + if self.replaced is None: + return 0 + return sum(self.replaced.values()) + + @property + def replaced_percent(self) -> float: + return 100 * self.num_replaced / self.num_tok + + def __str__(self) -> str: + base = f"{self.num_seq} sents, {self.num_tok} tokens" + if self.replaced is None: + return base + + return f"{base}, {self.replaced_percent:.3}% replaced" + + def merge(self, other: "BinarizeSummary"): + replaced = None + if self.replaced is not None: + replaced = self.replaced + if other.replaced is not None: + if replaced is None: + replaced = other.replaced + else: + replaced += other.replaced + self.replaced = replaced + self.num_seq += other.num_seq + self.num_tok += other.num_tok + + +class Binarizer(ABC): + """ + a binarizer describes how to take a string and build a tensor out of it + """ + + @abstractmethod + def binarize_line( + self, + line: str, + summary: BinarizeSummary, + ) -> torch.IntTensor: + ... + + +def _worker_prefix(output_prefix: str, worker_id: int): + return f"{output_prefix}.pt{worker_id}" + + +class FileBinarizer: + """ + An file binarizer can take a file, tokenize it, and binarize each line to a tensor + """ + + @classmethod + def multiprocess_dataset( + cls, + input_file: str, + dataset_impl: str, + binarizer: Binarizer, + output_prefix: str, + vocab_size=None, + num_workers=1, + ) -> BinarizeSummary: + final_summary = BinarizeSummary() + + offsets = find_offsets(input_file, num_workers) + # find_offsets returns a list of position [pos1, pos2, pos3, pos4] but we would want pairs: + # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info + # we zip the list with itself shifted by one to get all the pairs. 
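        # The first pair is handled in the main process further below, so that its
        # dataset builder can absorb the worker outputs via merge_file_(); the
        # remaining pairs are dispatched to the worker pool right after this.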
+ (first_chunk, *more_chunks) = zip(offsets, offsets[1:]) + pool = None + if num_workers > 1: + pool = Pool(processes=num_workers - 1) + worker_results = [ + pool.apply_async( + cls._binarize_chunk_and_finalize, + args=( + binarizer, + input_file, + start_offset, + end_offset, + _worker_prefix( + output_prefix, + worker_id, + ), + dataset_impl, + ), + kwds={ + "vocab_size": vocab_size, + } + if vocab_size is not None + else {}, + ) + for worker_id, (start_offset, end_offset) in enumerate( + more_chunks, start=1 + ) + ] + + pool.close() + pool.join() + for r in worker_results: + summ = r.get() + final_summary.merge(summ) + + # do not close the bin file as we need to merge the worker results in + final_ds, summ = cls._binarize_file_chunk( + binarizer, + input_file, + offset_start=first_chunk[0], + offset_end=first_chunk[1], + output_prefix=output_prefix, + dataset_impl=dataset_impl, + vocab_size=vocab_size if vocab_size is not None else None, + ) + final_summary.merge(summ) + + if num_workers > 1: + for worker_id in range(1, num_workers): + # merge the worker outputs + worker_output_prefix = _worker_prefix( + output_prefix, + worker_id, + ) + final_ds.merge_file_(worker_output_prefix) + try: + os.remove(indexed_dataset.data_file_path(worker_output_prefix)) + os.remove(indexed_dataset.index_file_path(worker_output_prefix)) + except Exception as e: + logger.error( + f"couldn't remove {worker_output_prefix}.*", exc_info=e + ) + + # now we can close the file + idx_file = indexed_dataset.index_file_path(output_prefix) + final_ds.finalize(idx_file) + return final_summary + + @staticmethod + def _binarize_file_chunk( + binarizer: Binarizer, + filename: str, + offset_start: int, + offset_end: int, + output_prefix: str, + dataset_impl: str, + vocab_size=None, + ) -> tp.Tuple[tp.Any, BinarizeSummary]: # (dataset builder, BinarizeSummary) + """ + creates a dataset builder and append binarized items to it. This function does not + finalize the builder, this is useful if you want to do other things with your bin file + like appending/merging other files + """ + bin_file = indexed_dataset.data_file_path(output_prefix) + ds = indexed_dataset.make_builder( + bin_file, + impl=dataset_impl, + vocab_size=vocab_size, + ) + summary = BinarizeSummary() + + with Chunker( + PathManager.get_local_path(filename), offset_start, offset_end + ) as line_iterator: + for line in line_iterator: + ds.add_item(binarizer.binarize_line(line, summary)) + + return ds, summary + + @classmethod + def _binarize_chunk_and_finalize( + cls, + binarizer: Binarizer, + filename: str, + offset_start: int, + offset_end: int, + output_prefix: str, + dataset_impl: str, + vocab_size=None, + ): + """ + same as above, but also finalizes the builder + """ + ds, summ = cls._binarize_file_chunk( + binarizer, + filename, + offset_start, + offset_end, + output_prefix, + dataset_impl, + vocab_size=vocab_size, + ) + + idx_file = indexed_dataset.index_file_path(output_prefix) + ds.finalize(idx_file) + + return summ + + +class VocabularyDatasetBinarizer(Binarizer): + """ + Takes a Dictionary/Vocabulary, assign ids to each + token using the dictionary encode_line function. 
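    Out-of-vocabulary tokens replaced by the unknown symbol during encoding
    are tallied in the summary's `replaced` counter.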
+ """ + + def __init__( + self, + dict: Dictionary, + tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line, + append_eos: bool = True, + reverse_order: bool = False, + already_numberized: bool = False, + ) -> None: + self.dict = dict + self.tokenize = tokenize + self.append_eos = append_eos + self.reverse_order = reverse_order + self.already_numberized = already_numberized + super().__init__() + + def binarize_line( + self, + line: str, + summary: BinarizeSummary, + ): + if summary.replaced is None: + summary.replaced = Counter() + + def replaced_consumer(word, idx): + if idx == self.dict.unk_index and word != self.dict.unk_word: + summary.replaced.update([word]) + + if self.already_numberized: + id_strings = line.strip().split() + id_list = [int(id_string) for id_string in id_strings] + if self.reverse_order: + id_list.reverse() + if self.append_eos: + id_list.append(self.dict.eos()) + ids = torch.IntTensor(id_list) + else: + ids = self.dict.encode_line( + line=line, + line_tokenizer=self.tokenize, + add_if_not_exist=False, + consumer=replaced_consumer, + append_eos=self.append_eos, + reverse_order=self.reverse_order, + ) + + summary.num_seq += 1 + summary.num_tok += len(ids) + return ids + + +class AlignmentDatasetBinarizer(Binarizer): + """ + binarize by parsing a set of alignments and packing + them in a tensor (see utils.parse_alignment) + """ + + def __init__( + self, + alignment_parser: tp.Callable[[str], torch.IntTensor], + ) -> None: + super().__init__() + self.alignment_parser = alignment_parser + + def binarize_line( + self, + line: str, + summary: BinarizeSummary, + ): + ids = self.alignment_parser(line) + summary.num_seq += 1 + summary.num_tok += len(ids) + return ids + + +class LegacyBinarizer: + @classmethod + def binarize( + cls, + filename: str, + dico: Dictionary, + consumer: tp.Callable[[torch.IntTensor], None], + tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line, + append_eos: bool = True, + reverse_order: bool = False, + offset: int = 0, + end: int = -1, + already_numberized: bool = False, + ) -> tp.Dict[str, int]: + binarizer = VocabularyDatasetBinarizer( + dict=dico, + tokenize=tokenize, + append_eos=append_eos, + reverse_order=reverse_order, + already_numberized=already_numberized, + ) + return cls._consume_file( + filename, + binarizer, + consumer, + offset_start=offset, + offset_end=end, + ) + + @classmethod + def binarize_alignments( + cls, + filename: str, + alignment_parser: tp.Callable[[str], torch.IntTensor], + consumer: tp.Callable[[torch.IntTensor], None], + offset: int = 0, + end: int = -1, + ) -> tp.Dict[str, int]: + binarizer = AlignmentDatasetBinarizer(alignment_parser) + return cls._consume_file( + filename, + binarizer, + consumer, + offset_start=offset, + offset_end=end, + ) + + @staticmethod + def _consume_file( + filename: str, + binarizer: Binarizer, + consumer: tp.Callable[[torch.IntTensor], None], + offset_start: int, + offset_end: int, + ) -> tp.Dict[str, int]: + summary = BinarizeSummary() + + with Chunker( + PathManager.get_local_path(filename), offset_start, offset_end + ) as line_iterator: + for line in line_iterator: + consumer(binarizer.binarize_line(line, summary)) + + return { + "nseq": summary.num_seq, + "nunk": summary.num_replaced, + "ntok": summary.num_tok, + "replaced": summary.replaced, + } diff --git a/fairseq/fairseq/checkpoint_utils.py b/fairseq/fairseq/checkpoint_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f316b9e78d0e9591d86ca4b75eafccfced66a4 --- /dev/null +++ 
b/fairseq/fairseq/checkpoint_utils.py @@ -0,0 +1,936 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import collections +import contextlib +import inspect +import logging +import os +import re +import time +import traceback +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +from fairseq.data import data_utils +from fairseq.dataclass.configs import CheckpointConfig +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + overwrite_args_by_name, +) +from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP +from fairseq.file_io import PathManager +from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig, OmegaConf, open_dict + +logger = logging.getLogger(__name__) + + +def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): + from fairseq import meters + + # only one worker should attempt to create the required dir + if trainer.data_parallel_rank == 0: + os.makedirs(cfg.save_dir, exist_ok=True) + + prev_best = getattr(save_checkpoint, "best", val_loss) + if val_loss is not None: + best_function = max if cfg.maximize_best_checkpoint_metric else min + save_checkpoint.best = best_function(val_loss, prev_best) + + if cfg.no_save: + return None + + trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state + + if not trainer.should_save_checkpoint_on_current_rank: + if trainer.always_call_state_dict_during_save_checkpoint: + trainer.state_dict() + return None + + write_timer = meters.StopwatchMeter() + write_timer.start() + + epoch = epoch_itr.epoch + end_of_epoch = epoch_itr.end_of_epoch() + updates = trainer.get_num_updates() + + logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates") + + def is_better(a, b): + return a >= b if cfg.maximize_best_checkpoint_metric else a <= b + + suffix = trainer.checkpoint_suffix + checkpoint_conds = collections.OrderedDict() + checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( + end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 + ) + checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( + not end_of_epoch + and cfg.save_interval_updates > 0 + and updates % cfg.save_interval_updates == 0 + ) + checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( + not hasattr(save_checkpoint, "best") + or is_better(val_loss, save_checkpoint.best) + ) + if val_loss is not None and cfg.keep_best_checkpoints > 0: + worst_best = getattr(save_checkpoint, "best", None) + chkpts = checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix + ), + ) + if len(chkpts) > 0: + p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0] + worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), "")) + # add random digits to resolve ties + with data_utils.numpy_seed(epoch, updates, val_loss): + rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints) + + checkpoint_conds[ + "checkpoint.best_{}_{:.3f}{}{}.pt".format( + cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix + ) + ] = worst_best is None or is_better(val_loss, worst_best) + checkpoint_conds[ + "checkpoint_last{}.pt".format(suffix) + ] = not 
cfg.no_last_checkpoints + + extra_state = { + "train_iterator": epoch_itr.state_dict(), + "val_loss": val_loss, + } + + # Going forward, different tasks could expose an API like this to dump all + # the checkpoint worthy attributes in a dictionary which then will be + # merged with the parent dictionary to create the "extra_state". This + # allows for an extensible yet simple design to checkpoint task level + # attributes + if hasattr(trainer.task, "get_checkpoint_dict"): + extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()} + logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint") + + if hasattr(save_checkpoint, "best"): + extra_state.update({"best": save_checkpoint.best}) + + checkpoints = [ + os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + ] + saved_cp = None + if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank: + saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state) + for cp in checkpoints[1:]: + if cfg.write_checkpoints_asynchronously: + # TODO[ioPath]: Need to implement a delayed asynchronous + # file copying/moving feature. + logger.warning( + f"ioPath is not copying {checkpoints[0]} to {cp} " + "since async write mode is on." + ) + else: + assert PathManager.copy( + checkpoints[0], cp, overwrite=True + ), f"Failed to copy {checkpoints[0]} to {cp}" + + write_timer.stop() + logger.info( + "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format( + checkpoints[0], epoch, updates, val_loss, write_timer.sum + ) + ) + + if ( + not end_of_epoch + and cfg.keep_interval_updates > 0 + and trainer.should_save_checkpoint_on_current_rank + ): + # remove old checkpoints; checkpoints are sorted in descending order + if cfg.keep_interval_updates_pattern == -1: + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) + ) + else: + checkpoints = checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix), + keep_match=True, + ) + checkpoints = [ + x[0] + for x in checkpoints + if x[1] % cfg.keep_interval_updates_pattern != 0 + ] + + for old_chk in checkpoints[cfg.keep_interval_updates :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_epochs :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank: + # only keep the best N checkpoints according to validation metric + checkpoints = checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix + ), + ) + if not cfg.maximize_best_checkpoint_metric: + checkpoints = checkpoints[::-1] + for old_chk in checkpoints[cfg.keep_best_checkpoints :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + return saved_cp + + +def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): + """ + Load a checkpoint and restore the training iterator. 
+ + *passthrough_args* will be passed through to + ``trainer.get_train_iterator``. + """ + + reset_optimizer = cfg.reset_optimizer + reset_lr_scheduler = cfg.reset_lr_scheduler + optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides) + reset_meters = cfg.reset_meters + reset_dataloader = cfg.reset_dataloader + + if cfg.finetune_from_model is not None and ( + reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader + ): + raise ValueError( + "--finetune-from-model can not be set together with either --reset-optimizer" + " or reset_lr_scheduler or reset_meters or reset_dataloader" + ) + + suffix = trainer.checkpoint_suffix + if ( + cfg.restore_file == "checkpoint_last.pt" + ): # default value of restore_file is 'checkpoint_last.pt' + checkpoint_path = os.path.join( + cfg.save_dir, "checkpoint_last{}.pt".format(suffix) + ) + first_launch = not PathManager.exists(checkpoint_path) + if first_launch and getattr(cfg, "continue_once", None) is not None: + checkpoint_path = cfg.continue_once + elif cfg.finetune_from_model is not None and first_launch: + # if there is no last checkpoint to restore, start the finetune from pretrained model + # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. + if PathManager.exists(cfg.finetune_from_model): + checkpoint_path = cfg.finetune_from_model + reset_optimizer = True + reset_lr_scheduler = True + reset_meters = True + reset_dataloader = True + logger.info( + f"loading pretrained model from {checkpoint_path}: " + "optimizer, lr scheduler, meters, dataloader will be reset" + ) + else: + raise ValueError( + f"--finetune-from-model {cfg.finetune_from_model} does not exist" + ) + elif suffix is not None: + checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt") + else: + checkpoint_path = cfg.restore_file + + if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model: + raise ValueError( + "--finetune-from-model and --restore-file (non-default value) " + "can not be specified together: " + str(cfg) + ) + + extra_state = trainer.load_checkpoint( + checkpoint_path, + reset_optimizer, + reset_lr_scheduler, + optimizer_overrides, + reset_meters=reset_meters, + ) + + if ( + extra_state is not None + and "best" in extra_state + and not reset_optimizer + and not reset_meters + ): + save_checkpoint.best = extra_state["best"] + + if extra_state is not None and not reset_dataloader: + # restore iterator from checkpoint + itr_state = extra_state["train_iterator"] + epoch_itr = trainer.get_train_iterator( + epoch=itr_state["epoch"], load_dataset=True, **passthrough_args + ) + epoch_itr.load_state_dict(itr_state) + + # Preload the checkpoint for the task + task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {}) + if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"): + trainer.task.set_checkpoint_dict(task_cp_dict) + else: + epoch_itr = trainer.get_train_iterator( + epoch=1, load_dataset=True, **passthrough_args + ) + + trainer.lr_step(epoch_itr.epoch) + + return extra_state, epoch_itr + + +def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False): + """Loads a checkpoint to CPU (with upgrading for backward compatibility). 
+ + If doing single-GPU training or if the checkpoint is only being loaded by at + most one process on each node (current default behavior is for only rank 0 + to read the checkpoint from disk), load_on_all_ranks should be False to + avoid errors from torch.distributed not having been initialized or + torch.distributed.barrier() hanging. + + If all processes on each node may be loading the checkpoint + simultaneously, load_on_all_ranks should be set to True to avoid I/O + conflicts. + + There's currently no support for > 1 but < all processes loading the + checkpoint on each node. + """ + local_path = PathManager.get_local_path(path) + # The locally cached file returned by get_local_path() may be stale for + # remote files that are periodically updated/overwritten (ex: + # checkpoint_last.pt) - so we remove the local copy, sync across processes + # (if needed), and then download a fresh copy. + if local_path != path and PathManager.path_requires_pathmanager(path): + try: + os.remove(local_path) + except FileNotFoundError: + # With potentially multiple processes removing the same file, the + # file being missing is benign (missing_ok isn't available until + # Python 3.8). + pass + if load_on_all_ranks: + torch.distributed.barrier() + local_path = PathManager.get_local_path(path) + + with open(local_path, "rb") as f: + state = torch.load(f, map_location=torch.device("cpu")) + + if "args" in state and state["args"] is not None and arg_overrides is not None: + args = state["args"] + for arg_name, arg_val in arg_overrides.items(): + setattr(args, arg_name, arg_val) + + if "cfg" in state and state["cfg"] is not None: + + # hack to be able to set Namespace in dict config. this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import __version__ as oc_version + from omegaconf import _utils + + if oc_version < "2.2": + old_primitive = _utils.is_primitive_type + _utils.is_primitive_type = lambda _: True + + state["cfg"] = OmegaConf.create(state["cfg"]) + + _utils.is_primitive_type = old_primitive + OmegaConf.set_struct(state["cfg"], True) + else: + state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True}) + + if arg_overrides is not None: + overwrite_args_by_name(state["cfg"], arg_overrides) + + state = _upgrade_state_dict(state) + return state + + +def load_model_ensemble( + filenames, + arg_overrides: Optional[Dict[str, Any]] = None, + task=None, + strict=True, + suffix="", + num_shards=1, + state=None, +): + """Loads an ensemble of models. 
+ + Args: + filenames (List[str]): checkpoint files to load + arg_overrides (Dict[str,Any], optional): override model args that + were used during model training + task (fairseq.tasks.FairseqTask, optional): task to use for loading + """ + assert not ( + strict and num_shards > 1 + ), "Cannot load state dict with strict=True and checkpoint shards > 1" + ensemble, args, _task = load_model_ensemble_and_task( + filenames, + arg_overrides, + task, + strict, + suffix, + num_shards, + state, + ) + return ensemble, args + + +def get_maybe_sharded_checkpoint_filename( + filename: str, suffix: str, shard_idx: int, num_shards: int +) -> str: + orig_filename = filename + filename = filename.replace(".pt", suffix + ".pt") + fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt" + model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt" + if PathManager.exists(fsdp_filename): + return fsdp_filename + elif num_shards > 1: + return model_parallel_filename + else: + return filename + + +def load_model_ensemble_and_task( + filenames, + arg_overrides: Optional[Dict[str, Any]] = None, + task=None, + strict=True, + suffix="", + num_shards=1, + state=None, +): + assert state is None or len(filenames) == 1 + + from fairseq import tasks + + assert not ( + strict and num_shards > 1 + ), "Cannot load state dict with strict=True and checkpoint shards > 1" + ensemble = [] + cfg = None + for filename in filenames: + orig_filename = filename + model_shard_state = {"shard_weights": [], "shard_metadata": []} + assert num_shards > 0 + st = time.time() + for shard_idx in range(num_shards): + filename = get_maybe_sharded_checkpoint_filename( + orig_filename, suffix, shard_idx, num_shards + ) + + if not PathManager.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + if state is None: + state = load_checkpoint_to_cpu(filename, arg_overrides) + if "args" in state and state["args"] is not None: + cfg = convert_namespace_to_omegaconf(state["args"]) + elif "cfg" in state and state["cfg"] is not None: + cfg = state["cfg"] + else: + raise RuntimeError( + f"Neither args nor cfg exist in state keys = {state.keys()}" + ) + + if task is None: + task = tasks.setup_task(cfg.task, from_checkpoint=True) + + if "task_state" in state: + task.load_state_dict(state["task_state"]) + + argspec = inspect.getfullargspec(task.build_model) + + if "fsdp_metadata" in state and num_shards > 1: + model_shard_state["shard_weights"].append(state["model"]) + model_shard_state["shard_metadata"].append(state["fsdp_metadata"]) + # check FSDP import before the code goes too far + if not has_FSDP: + raise ImportError( + "Cannot find FullyShardedDataParallel. 
" + "Please install fairscale with: pip install fairscale" + ) + if shard_idx == num_shards - 1: + consolidated_model_state = FSDP.consolidate_shard_weights( + shard_weights=model_shard_state["shard_weights"], + shard_metadata=model_shard_state["shard_metadata"], + ) + if "from_checkpoint" in argspec.args: + model = task.build_model(cfg.model, from_checkpoint=True) + else: + model = task.build_model(cfg.model) + if ( + "optimizer_history" in state + and len(state["optimizer_history"]) > 0 + and "num_updates" in state["optimizer_history"][-1] + ): + model.set_num_updates( + state["optimizer_history"][-1]["num_updates"] + ) + model.load_state_dict( + consolidated_model_state, strict=strict, model_cfg=cfg.model + ) + else: + # model parallel checkpoint or unsharded checkpoint + # support old external tasks + + if "from_checkpoint" in argspec.args: + model = task.build_model(cfg.model, from_checkpoint=True) + else: + model = task.build_model(cfg.model) + if ( + "optimizer_history" in state + and len(state["optimizer_history"]) > 0 + and "num_updates" in state["optimizer_history"][-1] + ): + model.set_num_updates(state["optimizer_history"][-1]["num_updates"]) + model.load_state_dict( + state["model"], strict=strict, model_cfg=cfg.model + ) + + # reset state so it gets loaded for the next model in ensemble + state = None + if shard_idx % 10 == 0 and shard_idx > 0: + elapsed = time.time() - st + logger.info( + f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard" + ) + + # build model for ensemble + ensemble.append(model) + return ensemble, cfg, task + + +def load_model_ensemble_and_task_from_hf_hub( + model_id, + cache_dir: Optional[str] = None, + arg_overrides: Optional[Dict[str, Any]] = None, + **kwargs: Any, +): + try: + from huggingface_hub import snapshot_download + except ImportError: + raise ImportError( + "You need to install huggingface_hub to use `load_from_hf_hub`. " + "See https://pypi.org/project/huggingface-hub/ for installation." + ) + + library_name = "fairseq" + cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix() + cache_dir = snapshot_download( + model_id, cache_dir=cache_dir, library_name=library_name, **kwargs + ) + + _arg_overrides = arg_overrides or {} + _arg_overrides["data"] = cache_dir + return load_model_ensemble_and_task( + [p.as_posix() for p in Path(cache_dir).glob("*.pt")], + arg_overrides=_arg_overrides, + ) + + +def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False): + """Retrieves all checkpoints found in `path` directory. + + Checkpoints are identified by matching filename to the specified pattern. If + the pattern contains groups, the result will be sorted by the first group in + descending order. 
+ """ + pt_regexp = re.compile(pattern) + files = PathManager.ls(path) + + entries = [] + for i, f in enumerate(files): + m = pt_regexp.fullmatch(f) + if m is not None: + idx = float(m.group(1)) if len(m.groups()) > 0 else i + entries.append((idx, m.group(0))) + if keep_match: + return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)] + else: + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] + + +def torch_persistent_save(obj, filename, async_write: bool = False): + if async_write: + with PathManager.opena(filename, "wb") as f: + _torch_persistent_save(obj, f) + else: + if PathManager.supports_rename(filename): + # do atomic save + with PathManager.open(filename + ".tmp", "wb") as f: + _torch_persistent_save(obj, f) + PathManager.rename(filename + ".tmp", filename) + else: + # fallback to non-atomic save + with PathManager.open(filename, "wb") as f: + _torch_persistent_save(obj, f) + + +def _torch_persistent_save(obj, f): + if isinstance(f, str): + with PathManager.open(f, "wb") as h: + torch_persistent_save(obj, h) + return + for i in range(3): + try: + return torch.save(obj, f) + except Exception: + if i == 2: + logger.error(traceback.format_exc()) + raise + else: + time.sleep(2.5) + + +def _upgrade_state_dict(state): + """Helper for upgrading old model checkpoints.""" + + # add optimizer_history + if "optimizer_history" not in state: + state["optimizer_history"] = [ + {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]} + ] + state["last_optimizer_state"] = state["optimizer"] + del state["optimizer"] + del state["best_loss"] + # move extra_state into sub-dictionary + if "epoch" in state and "extra_state" not in state: + state["extra_state"] = { + "epoch": state["epoch"], + "batch_offset": state["batch_offset"], + "val_loss": state["val_loss"], + } + del state["epoch"] + del state["batch_offset"] + del state["val_loss"] + # reduce optimizer history's memory usage (only keep the last state) + if "optimizer" in state["optimizer_history"][-1]: + state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"] + for optim_hist in state["optimizer_history"]: + del optim_hist["optimizer"] + # record the optimizer class name + if "optimizer_name" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG" + # move best_loss into lr_scheduler_state + if "lr_scheduler_state" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["lr_scheduler_state"] = { + "best": state["optimizer_history"][-1]["best_loss"] + } + del state["optimizer_history"][-1]["best_loss"] + # keep track of number of updates + if "num_updates" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["num_updates"] = 0 + # use stateful training data iterator + if "train_iterator" not in state["extra_state"]: + state["extra_state"]["train_iterator"] = { + "epoch": state["extra_state"].get("epoch", 0), + "iterations_in_epoch": state["extra_state"].get("batch_offset", 0), + } + + # backward compatibility, cfg updates + if "args" in state and state["args"] is not None: + # old model checkpoints may not have separate source/target positions + if hasattr(state["args"], "max_positions") and not hasattr( + state["args"], "max_source_positions" + ): + state["args"].max_source_positions = state["args"].max_positions + state["args"].max_target_positions = state["args"].max_positions + # default to translation task + if not hasattr(state["args"], "task"): + state["args"].task = 
"translation" + # --raw-text and --lazy-load are deprecated + if getattr(state["args"], "raw_text", False): + state["args"].dataset_impl = "raw" + elif getattr(state["args"], "lazy_load", False): + state["args"].dataset_impl = "lazy" + # epochs start at 1 + if state["extra_state"]["train_iterator"] is not None: + state["extra_state"]["train_iterator"]["epoch"] = max( + state["extra_state"]["train_iterator"].get("epoch", 1), 1 + ) + # --remove-bpe ==> --postprocess + if hasattr(state["args"], "remove_bpe"): + state["args"].post_process = state["args"].remove_bpe + # --min-lr ==> --stop-min-lr + if hasattr(state["args"], "min_lr"): + state["args"].stop_min_lr = state["args"].min_lr + del state["args"].min_lr + # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion + if hasattr(state["args"], "criterion") and state["args"].criterion in [ + "binary_cross_entropy", + "kd_binary_cross_entropy", + ]: + state["args"].criterion = "wav2vec" + # remove log_keys if it's None (criteria will supply a default value of []) + if hasattr(state["args"], "log_keys") and state["args"].log_keys is None: + delattr(state["args"], "log_keys") + # speech_pretraining => audio pretraining + if ( + hasattr(state["args"], "task") + and state["args"].task == "speech_pretraining" + ): + state["args"].task = "audio_pretraining" + # audio_cpc => wav2vec + if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc": + state["args"].arch = "wav2vec" + # convert legacy float learning rate to List[float] + if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float): + state["args"].lr = [state["args"].lr] + # convert task data arg to a string instead of List[string] + if ( + hasattr(state["args"], "data") + and isinstance(state["args"].data, list) + and len(state["args"].data) > 0 + ): + state["args"].data = state["args"].data[0] + + state["cfg"] = convert_namespace_to_omegaconf(state["args"]) + + if "cfg" in state and state["cfg"] is not None: + cfg = state["cfg"] + with open_dict(cfg): + # any upgrades for Hydra-based configs + if ( + "task" in cfg + and "eval_wer_config" in cfg.task + and isinstance(cfg.task.eval_wer_config.print_alignment, bool) + ): + cfg.task.eval_wer_config.print_alignment = "hard" + if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool): + cfg.generation.print_alignment = ( + "hard" if cfg.generation.print_alignment else None + ) + if ( + "model" in cfg + and "w2v_args" in cfg.model + and cfg.model.w2v_args is not None + and ( + hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args + ) + and hasattr(cfg.model.w2v_args.task, "eval_wer_config") + and cfg.model.w2v_args.task.eval_wer_config is not None + and isinstance( + cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool + ) + ): + cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard" + + return state + + +def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]): + """Prune the given state_dict if desired for LayerDrop + (https://arxiv.org/abs/1909.11556). + + Training with LayerDrop allows models to be robust to pruning at inference + time. This function prunes state_dict to allow smaller models to be loaded + from a larger model and re-maps the existing state_dict for this to occur. + + It's called by functions that load models from checkpoints and does not + need to be called directly. 
+ """ + arch = None + if model_cfg is not None: + arch = ( + model_cfg._name + if isinstance(model_cfg, DictConfig) + else getattr(model_cfg, "arch", None) + ) + + if not model_cfg or arch is None or arch == "ptt_transformer": + # args should not be none, but don't crash if it is. + return state_dict + + encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None) + decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None) + + if not encoder_layers_to_keep and not decoder_layers_to_keep: + return state_dict + + # apply pruning + logger.info( + "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop" + ) + + def create_pruning_pass(layers_to_keep, layer_name): + keep_layers = sorted( + int(layer_string) for layer_string in layers_to_keep.split(",") + ) + mapping_dict = {} + for i in range(len(keep_layers)): + mapping_dict[str(keep_layers[i])] = str(i) + + regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name)) + return {"substitution_regex": regex, "mapping_dict": mapping_dict} + + pruning_passes = [] + if encoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder")) + if decoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder")) + + new_state_dict = {} + for layer_name in state_dict.keys(): + match = re.search(r"\.layers\.(\d+)\.", layer_name) + # if layer has no number in it, it is a supporting layer, such as an + # embedding + if not match: + new_state_dict[layer_name] = state_dict[layer_name] + continue + + # otherwise, layer should be pruned. + original_layer_number = match.group(1) + # figure out which mapping dict to replace from + for pruning_pass in pruning_passes: + if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[ + "substitution_regex" + ].search(layer_name): + new_layer_number = pruning_pass["mapping_dict"][original_layer_number] + substitution_match = pruning_pass["substitution_regex"].search( + layer_name + ) + new_state_key = ( + layer_name[: substitution_match.start(1)] + + new_layer_number + + layer_name[substitution_match.end(1) :] + ) + new_state_dict[new_state_key] = state_dict[layer_name] + + # Since layers are now pruned, *_layers_to_keep are no longer needed. + # This is more of "It would make it work fix" rather than a proper fix. + if isinstance(model_cfg, DictConfig): + context = open_dict(model_cfg) + else: + context = contextlib.ExitStack() + with context: + if hasattr(model_cfg, "encoder_layers_to_keep"): + model_cfg.encoder_layers_to_keep = None + if hasattr(model_cfg, "decoder_layers_to_keep"): + model_cfg.decoder_layers_to_keep = None + + return new_state_dict + + +def load_pretrained_component_from_model( + component: Union[FairseqEncoder, FairseqDecoder], + checkpoint: str, + strict: bool = True, +): + """ + Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the + provided `component` object. If state_dict fails to load, there may be a + mismatch in the architecture of the corresponding `component` found in the + `checkpoint` file. 
+ """ + if not PathManager.exists(checkpoint): + raise IOError("Model file not found: {}".format(checkpoint)) + state = load_checkpoint_to_cpu(checkpoint) + if isinstance(component, FairseqEncoder): + component_type = "encoder" + elif isinstance(component, FairseqDecoder): + component_type = "decoder" + else: + raise ValueError( + "component to load must be either a FairseqEncoder or " + "FairseqDecoder. Loading other component types are not supported." + ) + component_state_dict = OrderedDict() + for key in state["model"].keys(): + if key.startswith(component_type): + # encoder.input_layers.0.0.weight --> input_layers.0.0.weight + component_subkey = key[len(component_type) + 1 :] + component_state_dict[component_subkey] = state["model"][key] + component.load_state_dict(component_state_dict, strict=strict) + return component + + +def verify_checkpoint_directory(save_dir: str) -> None: + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + temp_file_path = os.path.join(save_dir, "dummy") + try: + with open(temp_file_path, "w"): + pass + except OSError as e: + logger.warning( + "Unable to access checkpoint save directory: {}".format(save_dir) + ) + raise e + else: + os.remove(temp_file_path) + + +def save_ema_as_checkpoint(src_path, dst_path): + state = load_ema_from_checkpoint(src_path) + torch_persistent_save(state, dst_path) + + +def load_ema_from_checkpoint(fpath): + """Loads exponential moving averaged (EMA) checkpoint from input and + returns a model with ema weights. + + Args: + fpath: A string path of checkpoint to load from. + + Returns: + A dict of string keys mapping to various values. The 'model' key + from the returned dict should correspond to an OrderedDict mapping + string parameter names to torch Tensors. + """ + params_dict = collections.OrderedDict() + new_state = None + + with PathManager.open(fpath, "rb") as f: + new_state = torch.load( + f, + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, "cpu") + ), + ) + + # EMA model is stored in a separate "extra state" + model_params = new_state["extra_state"]["ema"] + + for key in list(model_params.keys()): + p = model_params[key] + if isinstance(p, torch.HalfTensor): + p = p.float() + if key not in params_dict: + params_dict[key] = p.clone() + # NOTE: clone() is needed in case of p is a shared parameter + else: + raise ValueError("Key {} is repeated in EMA model params.".format(key)) + + if len(params_dict) == 0: + raise ValueError( + f"Input checkpoint path '{fpath}' does not contain " + "ema model weights, is this model trained with EMA?" + ) + + new_state["model"] = params_dict + return new_state diff --git a/fairseq/fairseq/file_chunker_utils.py b/fairseq/fairseq/file_chunker_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f275490993dbbcd05d990c050ae6c7b4c9568c9 --- /dev/null +++ b/fairseq/fairseq/file_chunker_utils.py @@ -0,0 +1,84 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import typing as tp + + +def _safe_readline(fd) -> str: + pos = fd.tell() + while True: + try: + return fd.readline() + except UnicodeDecodeError: + pos -= 1 + fd.seek(pos) # search where this character begins + + +def find_offsets(filename: str, num_chunks: int) -> tp.List[int]: + """ + given a file and a number of chuncks, find the offsets in the file + to be able to chunk around full lines. 
+ """ + with open(filename, "r", encoding="utf-8") as f: + size = os.fstat(f.fileno()).st_size + chunk_size = size // num_chunks + offsets = [0 for _ in range(num_chunks + 1)] + for i in range(1, num_chunks): + f.seek(chunk_size * i) + _safe_readline(f) + offsets[i] = f.tell() + offsets[-1] = size + return offsets + + +class ChunkLineIterator: + """ + Iterator to properly iterate over lines of a file chunck. + """ + + def __init__(self, fd, start_offset: int, end_offset: int): + self._fd = fd + self._start_offset = start_offset + self._end_offset = end_offset + + def __iter__(self) -> tp.Iterable[str]: + self._fd.seek(self._start_offset) + # next(f) breaks f.tell(), hence readline() must be used + line = _safe_readline(self._fd) + while line: + pos = self._fd.tell() + # f.tell() does not always give the byte position in the file + # sometimes it skips to a very large number + # it is unlikely that through a normal read we go from + # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely + # that the procedure breaks by the undeterministic behavior of + # f.tell() + if ( + self._end_offset > 0 + and pos > self._end_offset + and pos < self._end_offset + 2**32 + ): + break + yield line + line = self._fd.readline() + + +class Chunker: + """ + contextmanager to read a chunck of a file line by line. + """ + + def __init__(self, path: str, start_offset: int, end_offset: int): + self.path = path + self.start_offset = start_offset + self.end_offset = end_offset + + def __enter__(self) -> ChunkLineIterator: + self.fd = open(self.path, "r", encoding="utf-8") + return ChunkLineIterator(self.fd, self.start_offset, self.end_offset) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.fd.close() diff --git a/fairseq/fairseq/file_io.py b/fairseq/fairseq/file_io.py new file mode 100644 index 0000000000000000000000000000000000000000..8eca70a0668d09e211c06b5b432e4d0d2125ca72 --- /dev/null +++ b/fairseq/fairseq/file_io.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import shutil +from typing import List, Optional + + +logger = logging.getLogger(__file__) + + +try: + from iopath.common.file_io import g_pathmgr as IOPathManager + + try: + # [FB only - for now] AWS PathHandler for PathManager + from .fb_pathhandlers import S3PathHandler + + IOPathManager.register_handler(S3PathHandler()) + except KeyError: + logging.warning("S3PathHandler already registered.") + except ImportError: + logging.debug( + "S3PathHandler couldn't be imported. Either missing fb-only files, or boto3 module." + ) + +except ImportError: + IOPathManager = None + + +class PathManager: + """ + Wrapper for insulating OSS I/O (using Python builtin operations) from + iopath's PathManager abstraction (for transparently handling various + internal backends). 
+ """ + + @staticmethod + def open( + path: str, + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + ): + if IOPathManager: + return IOPathManager.open( + path=path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + return open( + path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + + @staticmethod + def copy(src_path: str, dst_path: str, overwrite: bool = False) -> bool: + if IOPathManager: + return IOPathManager.copy( + src_path=src_path, dst_path=dst_path, overwrite=overwrite + ) + return shutil.copyfile(src_path, dst_path) + + @staticmethod + def get_local_path(path: str, **kwargs) -> str: + if IOPathManager: + return IOPathManager.get_local_path(path, **kwargs) + return path + + @staticmethod + def exists(path: str) -> bool: + if IOPathManager: + return IOPathManager.exists(path) + return os.path.exists(path) + + @staticmethod + def isfile(path: str) -> bool: + if IOPathManager: + return IOPathManager.isfile(path) + return os.path.isfile(path) + + @staticmethod + def ls(path: str) -> List[str]: + if IOPathManager: + return IOPathManager.ls(path) + return os.listdir(path) + + @staticmethod + def mkdirs(path: str) -> None: + if IOPathManager: + return IOPathManager.mkdirs(path) + os.makedirs(path, exist_ok=True) + + @staticmethod + def rm(path: str) -> None: + if IOPathManager: + return IOPathManager.rm(path) + os.remove(path) + + @staticmethod + def chmod(path: str, mode: int) -> None: + if not PathManager.path_requires_pathmanager(path): + os.chmod(path, mode) + + @staticmethod + def register_handler(handler) -> None: + if IOPathManager: + return IOPathManager.register_handler(handler=handler) + + @staticmethod + def copy_from_local( + local_path: str, dst_path: str, overwrite: bool = False, **kwargs + ) -> None: + if IOPathManager: + return IOPathManager.copy_from_local( + local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs + ) + return shutil.copyfile(local_path, dst_path) + + @staticmethod + def path_requires_pathmanager(path: str) -> bool: + """Do we require PathManager to access given path?""" + if IOPathManager: + for p in IOPathManager._path_handlers.keys(): + if path.startswith(p): + return True + return False + + @staticmethod + def supports_rename(path: str) -> bool: + # PathManager doesn't yet support renames + return not PathManager.path_requires_pathmanager(path) + + @staticmethod + def rename(src: str, dst: str): + os.rename(src, dst) + + """ + ioPath async PathManager methods: + """ + + @staticmethod + def opena( + path: str, + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + ): + """ + Return file descriptor with asynchronous write operations. + """ + global IOPathManager + if not IOPathManager: + logging.info("ioPath is initializing PathManager.") + try: + from iopath.common.file_io import PathManager + + IOPathManager = PathManager() + except Exception: + logging.exception("Failed to initialize ioPath PathManager object.") + return IOPathManager.opena( + path=path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + + @staticmethod + def async_close() -> bool: + """ + Wait for files to be written and clean up asynchronous PathManager. 
+ NOTE: `PathManager.async_close()` must be called at the end of any + script that uses `PathManager.opena(...)`. + """ + global IOPathManager + if IOPathManager: + return IOPathManager.async_close() + return False diff --git a/fairseq/fairseq/hub_utils.py b/fairseq/fairseq/hub_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b0c2da15bf2484a1109871d36c7a16d60219c42d --- /dev/null +++ b/fairseq/fairseq/hub_utils.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import copy +import logging +import os +from typing import Any, Dict, Iterator, List + +import torch +from omegaconf import open_dict +from torch import nn + +from fairseq import utils +from fairseq.data import encoders + +logger = logging.getLogger(__name__) + + +def from_pretrained( + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + archive_map=None, + **kwargs +): + from fairseq import checkpoint_utils, file_utils + + if archive_map is not None: + if model_name_or_path in archive_map: + model_name_or_path = archive_map[model_name_or_path] + if data_name_or_path is not None and data_name_or_path in archive_map: + data_name_or_path = archive_map[data_name_or_path] + + # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe) + # for each model + if isinstance(model_name_or_path, dict): + for k, v in model_name_or_path.items(): + if k == "checkpoint_file": + checkpoint_file = v + elif ( + k != "path" + # only set kwargs that don't already have overrides + and k not in kwargs + ): + kwargs[k] = v + model_name_or_path = model_name_or_path["path"] + + model_path = file_utils.load_archive_file(model_name_or_path) + + # convenience hack for loading data and BPE codes from model archive + if data_name_or_path.startswith("."): + kwargs["data"] = os.path.abspath(os.path.join(model_path, data_name_or_path)) + else: + kwargs["data"] = file_utils.load_archive_file(data_name_or_path) + for file, arg in { + "code": "bpe_codes", + "bpecodes": "bpe_codes", + "sentencepiece.bpe.model": "sentencepiece_model", + "merges.txt": "bpe_merges", + "vocab.json": "bpe_vocab", + }.items(): + path = os.path.join(model_path, file) + if os.path.exists(path): + kwargs[arg] = path + + if "user_dir" in kwargs: + utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"])) + + model_path = [ + os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep) + ] + + if "is_vocoder" in kwargs: + args = {"data": kwargs["data"], "model_path": model_path} + task = None + models = None + else: + models, args, task = checkpoint_utils.load_model_ensemble_and_task( + model_path, + arg_overrides=kwargs, + ) + if "generation_args" in kwargs and kwargs["generation_args"]: + for key in kwargs["generation_args"]: + setattr(args["generation"], key, kwargs["generation_args"][key]) + + return { + "args": args, + "task": task, + "models": models, + } + + +class GeneratorHubInterface(nn.Module): + """ + PyTorch Hub interface for generating sequences from a pre-trained + translation or language model. 
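+
+    Example (an illustrative sketch; the model directory and checkpoint name
+    are hypothetical):
+
+        from fairseq import hub_utils
+
+        x = hub_utils.from_pretrained("/path/to/model_dir", "checkpoint.pt")
+        hub = GeneratorHubInterface(x["args"], x["task"], x["models"])
+        print(hub.translate("Hello world!"))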
+ """ + + def __init__(self, cfg, task, models): + super().__init__() + self.cfg = cfg + self.task = task + self.models = nn.ModuleList(models) + self.src_dict = task.source_dictionary + self.tgt_dict = task.target_dictionary + + # optimize model for generation + for model in self.models: + model.prepare_for_inference_(cfg) + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + self.align_dict = utils.load_align_dict(cfg.generation.replace_unk) + + self.tokenizer = encoders.build_tokenizer(cfg.tokenizer) + self.bpe = encoders.build_bpe(cfg.bpe) + + self.max_positions = utils.resolve_max_positions( + self.task.max_positions(), *[model.max_positions() for model in models] + ) + + # this is useful for determining the device + self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) + + @property + def device(self): + return self._float_tensor.device + + def translate( + self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs + ) -> List[str]: + return self.sample(sentences, beam, verbose, **kwargs) + + def sample( + self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs + ) -> List[str]: + if isinstance(sentences, str): + return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0] + tokenized_sentences = [self.encode(sentence) for sentence in sentences] + batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs) + return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos] + + def score( + self, sentences: List[str], replace_newline_with_eos: bool = False, **kwargs + ): + if isinstance(sentences, str): + return self.score( + [sentences], replace_newline_with_eos=replace_newline_with_eos, **kwargs + )[0] + + def encode(sentence): + if replace_newline_with_eos: + return torch.cat([self.encode(line) for line in sentence.splitlines()]) + else: + return self.encode(sentence) + + # NOTE: this doesn't support translation tasks currently + tokenized_sentences = [encode(sentence) for sentence in sentences] + return [ + hypos[0] + for hypos in self.generate( + tokenized_sentences, score_reference=True, **kwargs + ) + ] + + def generate( + self, + tokenized_sentences: List[torch.LongTensor], + beam: int = 5, + verbose: bool = False, + skip_invalid_size_inputs=False, + inference_step_args=None, + prefix_allowed_tokens_fn=None, + **kwargs + ) -> List[List[Dict[str, torch.Tensor]]]: + if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1: + return self.generate( + tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs + )[0] + + # build generator using current args as well as any kwargs + gen_args = copy.deepcopy(self.cfg.generation) + with open_dict(gen_args): + gen_args.beam = beam + for k, v in kwargs.items(): + setattr(gen_args, k, v) + generator = self.task.build_generator( + self.models, + gen_args, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + ) + + inference_step_args = inference_step_args or {} + results = [] + for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): + batch = utils.apply_to_sample(lambda t: t.to(self.device), batch) + translations = self.task.inference_step( + generator, self.models, batch, **inference_step_args + ) + for id, hypos in zip(batch["id"].tolist(), translations): + results.append((id, hypos)) + + # sort output to match input order + outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])] + + if verbose: 
+ + def getarg(name, default): + return getattr(gen_args, name, getattr(self.cfg, name, default)) + + for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs): + src_str_with_unk = self.string(source_tokens) + logger.info("S\t{}".format(src_str_with_unk)) + for hypo in target_hypotheses: + hypo_str = self.decode(hypo["tokens"]) + logger.info("H\t{}\t{}".format(hypo["score"], hypo_str)) + logger.info( + "P\t{}".format( + " ".join( + map( + lambda x: "{:.4f}".format(x), + hypo["positional_scores"].tolist(), + ) + ) + ) + ) + if hypo["alignment"] is not None and getarg( + "print_alignment", False + ): + logger.info( + "A\t{}".format( + " ".join( + [ + "{}-{}".format(src_idx, tgt_idx) + for src_idx, tgt_idx in hypo["alignment"] + ] + ) + ) + ) + return outputs + + def encode(self, sentence: str) -> torch.LongTensor: + sentence = self.tokenize(sentence) + sentence = self.apply_bpe(sentence) + return self.binarize(sentence) + + def decode(self, tokens: torch.LongTensor) -> str: + sentence = self.string(tokens) + sentence = self.remove_bpe(sentence) + return self.detokenize(sentence) + + def tokenize(self, sentence: str) -> str: + if self.tokenizer is not None: + sentence = self.tokenizer.encode(sentence) + return sentence + + def detokenize(self, sentence: str) -> str: + if self.tokenizer is not None: + sentence = self.tokenizer.decode(sentence) + return sentence + + def apply_bpe(self, sentence: str) -> str: + if self.bpe is not None: + sentence = self.bpe.encode(sentence) + return sentence + + def remove_bpe(self, sentence: str) -> str: + if self.bpe is not None: + sentence = self.bpe.decode(sentence) + return sentence + + def binarize(self, sentence: str) -> torch.LongTensor: + return self.src_dict.encode_line(sentence, add_if_not_exist=False).long() + + def string(self, tokens: torch.LongTensor) -> str: + return self.tgt_dict.string(tokens) + + def _build_batches( + self, tokens: List[List[int]], skip_invalid_size_inputs: bool + ) -> Iterator[Dict[str, Any]]: + lengths = torch.LongTensor([t.numel() for t in tokens]) + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.build_dataset_for_inference(tokens, lengths), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, + max_positions=self.max_positions, + ignore_invalid_inputs=skip_invalid_size_inputs, + disable_iterator_cache=True, + ).next_epoch_itr(shuffle=False) + return batch_iterator + + +class BPEHubInterface(object): + """PyTorch Hub interface for Byte-Pair Encoding (BPE).""" + + def __init__(self, bpe, **kwargs): + super().__init__() + args = argparse.Namespace(bpe=bpe, **kwargs) + self.bpe = encoders.build_bpe(args) + assert self.bpe is not None + + def encode(self, sentence: str) -> str: + return self.bpe.encode(sentence) + + def decode(self, sentence: str) -> str: + return self.bpe.decode(sentence) + + +class TokenizerHubInterface(object): + """PyTorch Hub interface for tokenization.""" + + def __init__(self, tokenizer, **kwargs): + super().__init__() + args = argparse.Namespace(tokenizer=tokenizer, **kwargs) + self.tokenizer = encoders.build_tokenizer(args) + assert self.tokenizer is not None + + def encode(self, sentence: str) -> str: + return self.tokenizer.encode(sentence) + + def decode(self, sentence: str) -> str: + return self.tokenizer.decode(sentence) diff --git a/fairseq/fairseq/incremental_decoding_utils.py b/fairseq/fairseq/incremental_decoding_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..b26e6cd01cd4cbdffa23d88b354eb4a55a94189b --- /dev/null +++ b/fairseq/fairseq/incremental_decoding_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import uuid +from typing import Dict, Optional + +from torch import Tensor + + +class FairseqIncrementalState(object): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_incremental_state() + + def init_incremental_state(self): + self._incremental_state_id = str(uuid.uuid4()) + + def _get_full_incremental_state_key(self, key: str) -> str: + return "{}.{}".format(self._incremental_state_id, key) + + def get_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + ) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + full_key = self._get_full_incremental_state_key(key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + def set_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], + ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = self._get_full_incremental_state_key(key) + incremental_state[full_key] = value + return incremental_state + + +def with_incremental_state(cls): + cls.__bases__ = (FairseqIncrementalState,) + tuple( + b for b in cls.__bases__ if b != FairseqIncrementalState + ) + return cls diff --git a/fairseq/fairseq/iterative_refinement_generator.py b/fairseq/fairseq/iterative_refinement_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..3d32c6bf4dcaacde7f7834da0d1f58d59c8345a9 --- /dev/null +++ b/fairseq/fairseq/iterative_refinement_generator.py @@ -0,0 +1,359 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import namedtuple + +import numpy as np +import torch +from fairseq import utils + + +DecoderOut = namedtuple( + "IterativeRefinementDecoderOut", + ["output_tokens", "output_scores", "attn", "step", "max_step", "history"], +) + + +class IterativeRefinementGenerator(object): + def __init__( + self, + tgt_dict, + models=None, + eos_penalty=0.0, + max_iter=10, + max_ratio=2, + beam_size=1, + decoding_format=None, + retain_dropout=False, + adaptive=True, + retain_history=False, + reranking=False, + ): + """ + Generates translations based on iterative refinement. 
+ + Args: + tgt_dict: target dictionary + eos_penalty: if > 0.0, penalizes early stopping during decoding + max_iter: maximum number of refinement iterations + max_ratio: generate sequences of maximum length ax, where x is the source length + decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'} + retain_dropout: retain dropout during inference + adaptive: stop refining a hypothesis early once its output stops changing + """ + self.bos = tgt_dict.bos() + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() + self.vocab_size = len(tgt_dict) + self.eos_penalty = eos_penalty + self.max_iter = max_iter + self.max_ratio = max_ratio + self.beam_size = beam_size + self.reranking = reranking + self.decoding_format = decoding_format + self.retain_dropout = retain_dropout + self.retain_history = retain_history + self.adaptive = adaptive + self.models = models + + def generate_batched_itr( + self, + data_itr, + maxlen_a=None, + maxlen_b=None, + cuda=False, + timer=None, + prefix_size=0, + ): + """Iterate over a batched dataset and yield individual translations. + + Args: + maxlen_a/b: generate sequences of maximum length ax + b, + where x is the source sentence length. + cuda: use GPU for generation + timer: StopwatchMeter for timing generations. + """ + + for sample in data_itr: + if "net_input" not in sample: + continue + if timer is not None: + timer.start() + with torch.no_grad(): + hypos = self.generate( + self.models, + sample, + prefix_tokens=sample["target"][:, :prefix_size] + if prefix_size > 0 + else None, + ) + if timer is not None: + timer.stop(sample["ntokens"]) + for i, id in enumerate(sample["id"]): + # remove padding + src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad) + ref = utils.strip_pad(sample["target"][i, :], self.pad) + yield id, src, ref, hypos[i] + + @torch.no_grad() + def generate(self, models, sample, prefix_tokens=None, constraints=None): + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the IterativeRefinementGenerator is not supported" + ) + + # TODO: iterative refinement generator does not support ensemble for now. + if not self.retain_dropout: + for model in models: + model.eval() + + model, reranker = models[0], None + if self.reranking: + assert len(models) > 1, "Assuming the last checkpoint is the reranker" + assert ( + self.beam_size > 1 + ), "Reranking requires multiple translations for each example" + + reranker = models[-1] + models = models[:-1] + + if len(models) > 1 and hasattr(model, "enable_ensemble"): + assert model.allow_ensemble, "{} does not support ensembling".format( + model.__class__.__name__ + ) + model.enable_ensemble(models) + + # TODO: better encoder inputs?
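+        # Encode the source once; the decoder output is then refined for up to
+        # max_iter iterations below, and finished hypotheses are collected as
+        # each sentence's refinement loop terminates.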
+ src_tokens = sample["net_input"]["src_tokens"] + src_lengths = sample["net_input"]["src_lengths"] + bsz, src_len = src_tokens.size() + + # initialize + encoder_out = model.forward_encoder([src_tokens, src_lengths]) + prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens) + + if self.beam_size > 1: + assert ( + model.allow_length_beam + ), "{} does not support decoding with length beam.".format( + model.__class__.__name__ + ) + + # regenerate data based on length-beam + length_beam_order = ( + utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1) + ) + encoder_out = model.encoder.reorder_encoder_out( + encoder_out, length_beam_order + ) + prev_decoder_out = model.regenerate_length_beam( + prev_decoder_out, self.beam_size + ) + bsz = bsz * self.beam_size + + sent_idxs = torch.arange(bsz) + prev_output_tokens = prev_decoder_out.output_tokens.clone() + + if self.retain_history: + prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens]) + + finalized = [[] for _ in range(bsz)] + + def is_a_loop(x, y, s, a): + b, l_x, l_y = x.size(0), x.size(1), y.size(1) + if l_x > l_y: + y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1) + s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1) + if a is not None: + a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1) + elif l_x < l_y: + x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1) + return (x == y).all(1), y, s, a + + def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): + cutoff = prev_out_token.ne(self.pad) + tokens = prev_out_token[cutoff] + if prev_out_score is None: + scores, score = None, None + else: + scores = prev_out_score[cutoff] + score = scores.mean() + + if prev_out_attn is None: + hypo_attn, alignment = None, None + else: + hypo_attn = prev_out_attn[cutoff] + alignment = hypo_attn.max(dim=1)[1] + return { + "steps": step, + "tokens": tokens, + "positional_scores": scores, + "score": score, + "hypo_attn": hypo_attn, + "alignment": alignment, + } + + for step in range(self.max_iter + 1): + + decoder_options = { + "eos_penalty": self.eos_penalty, + "max_ratio": self.max_ratio, + "decoding_format": self.decoding_format, + } + prev_decoder_out = prev_decoder_out._replace( + step=step, + max_step=self.max_iter + 1, + ) + + decoder_out = model.forward_decoder( + prev_decoder_out, encoder_out, **decoder_options + ) + + if self.adaptive: + # terminate if there is a loop + terminated, out_tokens, out_scores, out_attn = is_a_loop( + prev_output_tokens, + decoder_out.output_tokens, + decoder_out.output_scores, + decoder_out.attn, + ) + decoder_out = decoder_out._replace( + output_tokens=out_tokens, + output_scores=out_scores, + attn=out_attn, + ) + + else: + terminated = decoder_out.output_tokens.new_zeros( + decoder_out.output_tokens.size(0) + ).bool() + + if step == self.max_iter: # reach last iteration, terminate + terminated.fill_(1) + + # collect finalized sentences + finalized_idxs = sent_idxs[terminated.to(sent_idxs.device)] + finalized_tokens = decoder_out.output_tokens[terminated] + finalized_scores = decoder_out.output_scores[terminated] + finalized_attn = ( + None + if (decoder_out.attn is None or decoder_out.attn.size(0) == 0) + else decoder_out.attn[terminated] + ) + + if self.retain_history: + finalized_history_tokens = [h[terminated] for h in decoder_out.history] + + for i in range(finalized_idxs.size(0)): + finalized[finalized_idxs[i]] = [ + finalized_hypos( + step, + finalized_tokens[i], + finalized_scores[i], + None if finalized_attn 
is None else finalized_attn[i], + ) + ] + + if self.retain_history: + finalized[finalized_idxs[i]][0]["history"] = [] + for j in range(len(finalized_history_tokens)): + finalized[finalized_idxs[i]][0]["history"].append( + finalized_hypos( + step, finalized_history_tokens[j][i], None, None + ) + ) + + # check if all terminated + if terminated.sum() == terminated.size(0): + break + + # for next step + not_terminated = ~terminated + prev_decoder_out = decoder_out._replace( + output_tokens=decoder_out.output_tokens[not_terminated], + output_scores=decoder_out.output_scores[not_terminated], + attn=decoder_out.attn[not_terminated] + if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0) + else None, + history=[h[not_terminated] for h in decoder_out.history] + if decoder_out.history is not None + else None, + ) + encoder_out = model.encoder.reorder_encoder_out( + encoder_out, not_terminated.nonzero(as_tuple=False).squeeze() + ) + sent_idxs = sent_idxs[not_terminated.to(sent_idxs.device)] + prev_output_tokens = prev_decoder_out.output_tokens.clone() + + if self.beam_size > 1: + if reranker is not None: + finalized = self.rerank( + reranker, finalized, [src_tokens, src_lengths], self.beam_size + ) + + # aggregate information from length beam + finalized = [ + finalized[ + np.argmax( + [ + finalized[self.beam_size * i + j][0]["score"] + for j in range(self.beam_size) + ] + ) + + self.beam_size * i + ] + for i in range(len(finalized) // self.beam_size) + ] + + return finalized + + def rerank(self, reranker, finalized, encoder_input, beam_size): + def rebuild_batch(finalized): + finalized_tokens = [f[0]["tokens"] for f in finalized] + finalized_maxlen = max(f.size(0) for f in finalized_tokens) + final_output_tokens = ( + finalized_tokens[0] + .new_zeros(len(finalized_tokens), finalized_maxlen) + .fill_(self.pad) + ) + for i, f in enumerate(finalized_tokens): + final_output_tokens[i, : f.size(0)] = f + return final_output_tokens + + final_output_tokens = rebuild_batch(finalized) + final_output_tokens[ + :, 0 + ] = self.eos # autoregressive model assumes starting with EOS + + reranker_encoder_out = reranker.encoder(*encoder_input) + length_beam_order = ( + utils.new_arange( + final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1) + ) + .t() + .reshape(-1) + ) + reranker_encoder_out = reranker.encoder.reorder_encoder_out( + reranker_encoder_out, length_beam_order + ) + reranking_scores = reranker.get_normalized_probs( + reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out), + True, + None, + ) + reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None]) + reranking_masks = final_output_tokens[:, 1:].ne(self.pad) + reranking_scores = ( + reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1) + ) + reranking_scores = reranking_scores / reranking_masks.sum(1).type_as( + reranking_scores + ) + + for i in range(len(finalized)): + finalized[i][0]["score"] = reranking_scores[i] + + return finalized diff --git a/fairseq/fairseq/options.py b/fairseq/fairseq/options.py new file mode 100644 index 0000000000000000000000000000000000000000..920591635a05aa4aca728321a47fed6f3c28e504 --- /dev/null +++ b/fairseq/fairseq/options.py @@ -0,0 +1,413 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +from pathlib import Path +from typing import Callable, List, Optional, Union + +import torch +from fairseq import utils +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.dataclass.configs import ( + CheckpointConfig, + CommonConfig, + CommonEvalConfig, + DatasetConfig, + DistributedTrainingConfig, + EvalLMConfig, + GenerationConfig, + InteractiveConfig, + OptimizationConfig, + EMAConfig, +) +from fairseq.dataclass.utils import gen_parser_from_dataclass + +# this import is for backward compatibility +from fairseq.utils import csv_str_list, eval_bool, eval_str_dict, eval_str_list # noqa + + +def get_preprocessing_parser(default_task="translation"): + parser = get_parser("Preprocessing", default_task) + add_preprocess_args(parser) + return parser + + +def get_training_parser(default_task="translation"): + parser = get_parser("Trainer", default_task) + add_dataset_args(parser, train=True) + add_distributed_training_args(parser) + add_model_args(parser) + add_optimization_args(parser) + add_checkpoint_args(parser) + add_ema_args(parser) + return parser + + +def get_generation_parser(interactive=False, default_task="translation"): + parser = get_parser("Generation", default_task) + add_dataset_args(parser, gen=True) + add_distributed_training_args(parser, default_world_size=1) + add_generation_args(parser) + add_checkpoint_args(parser) + if interactive: + add_interactive_args(parser) + return parser + + +def get_speech_generation_parser(default_task="text_to_speech"): + parser = get_parser("Speech Generation", default_task) + add_dataset_args(parser, gen=True) + add_distributed_training_args(parser, default_world_size=1) + add_speech_generation_args(parser) + return parser + + +def get_interactive_generation_parser(default_task="translation"): + return get_generation_parser(interactive=True, default_task=default_task) + + +def get_eval_lm_parser(default_task="language_modeling"): + parser = get_parser("Evaluate Language Model", default_task) + add_dataset_args(parser, gen=True) + add_distributed_training_args(parser, default_world_size=1) + add_eval_lm_args(parser) + return parser + + +def get_validation_parser(default_task=None): + parser = get_parser("Validation", default_task) + add_dataset_args(parser, train=True) + add_distributed_training_args(parser, default_world_size=1) + group = parser.add_argument_group("Evaluation") + gen_parser_from_dataclass(group, CommonEvalConfig()) + return parser + + +def parse_args_and_arch( + parser: argparse.ArgumentParser, + input_args: List[str] = None, + parse_known: bool = False, + suppress_defaults: bool = False, + modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None, +): + """ + Args: + parser (ArgumentParser): the parser + input_args (List[str]): strings to parse, defaults to sys.argv + parse_known (bool): only parse known arguments, similar to + `ArgumentParser.parse_known_args` + suppress_defaults (bool): parse while ignoring all default values + modify_parser (Optional[Callable[[ArgumentParser], None]]): + function to modify the parser, e.g., to set default values + """ + if suppress_defaults: + # Parse args without any default values. This requires us to parse + # twice, once to identify all the necessary task/model args, and a second + # time with all defaults set to None. 
+ args = parse_args_and_arch( + parser, + input_args=input_args, + parse_known=parse_known, + suppress_defaults=False, + ) + suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser]) + suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()}) + args = suppressed_parser.parse_args(input_args) + return argparse.Namespace( + **{k: v for k, v in vars(args).items() if v is not None} + ) + + from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY + + # Before creating the true parser, we need to import optional user module + # in order to eagerly import custom tasks, optimizers, architectures, etc. + usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) + usr_parser.add_argument("--user-dir", default=None) + usr_args, _ = usr_parser.parse_known_args(input_args) + utils.import_user_module(usr_args) + + if modify_parser is not None: + modify_parser(parser) + + # The parser doesn't know about model/criterion/optimizer-specific args, so + # we parse twice. First we parse the model/criterion/optimizer, then we + # parse a second time after adding the *-specific arguments. + # If input_args is given, we will parse those args instead of sys.argv. + args, _ = parser.parse_known_args(input_args) + + # Add model-specific args to parser. + if hasattr(args, "arch"): + model_specific_group = parser.add_argument_group( + "Model-specific configuration", + # Only include attributes which are explicitly given as command-line + # arguments or which have default values. + argument_default=argparse.SUPPRESS, + ) + if args.arch in ARCH_MODEL_REGISTRY: + ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group) + elif args.arch in MODEL_REGISTRY: + MODEL_REGISTRY[args.arch].add_args(model_specific_group) + else: + raise RuntimeError() + + if hasattr(args, "task"): + from fairseq.tasks import TASK_REGISTRY + + TASK_REGISTRY[args.task].add_args(parser) + if getattr(args, "use_bmuf", False): + # hack to support extra args for block distributed data parallelism + from fairseq.optim.bmuf import FairseqBMUF + + FairseqBMUF.add_args(parser) + + # Add *-specific args to parser. + from fairseq.registry import REGISTRIES + + for registry_name, REGISTRY in REGISTRIES.items(): + choice = getattr(args, registry_name, None) + if choice is not None: + cls = REGISTRY["registry"][choice] + if hasattr(cls, "add_args"): + cls.add_args(parser) + elif hasattr(cls, "__dataclass"): + gen_parser_from_dataclass(parser, cls.__dataclass()) + + # Modify the parser a second time, since defaults may have been reset + if modify_parser is not None: + modify_parser(parser) + + # Parse a second time. + if parse_known: + args, extra = parser.parse_known_args(input_args) + else: + args = parser.parse_args(input_args) + extra = None + # Post-process args. 
+ if ( + hasattr(args, "batch_size_valid") and args.batch_size_valid is None + ) or not hasattr(args, "batch_size_valid"): + args.batch_size_valid = args.batch_size + if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None: + args.max_tokens_valid = args.max_tokens + if getattr(args, "memory_efficient_fp16", False): + args.fp16 = True + if getattr(args, "memory_efficient_bf16", False): + args.bf16 = True + args.tpu = getattr(args, "tpu", False) + args.bf16 = getattr(args, "bf16", False) + if args.bf16: + args.tpu = True + if args.tpu and args.fp16: + raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs") + + if getattr(args, "seed", None) is None: + args.seed = 1 # default seed for training + args.no_seed_provided = True + else: + args.no_seed_provided = False + + if getattr(args, "update_epoch_batch_itr", None) is None: + if hasattr(args, "grouped_shuffling"): + args.update_epoch_batch_itr = args.grouped_shuffling + else: + args.grouped_shuffling = False + args.update_epoch_batch_itr = False + + # Apply architecture configuration. + if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY: + ARCH_CONFIG_REGISTRY[args.arch](args) + + if parse_known: + return args, extra + else: + return args + + +def get_parser(desc, default_task="translation"): + # Before creating the true parser, we need to import optional user module + # in order to eagerly import custom tasks, optimizers, architectures, etc. + usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) + usr_parser.add_argument("--user-dir", default=None) + usr_args, _ = usr_parser.parse_known_args() + utils.import_user_module(usr_args) + + parser = argparse.ArgumentParser(allow_abbrev=False) + gen_parser_from_dataclass(parser, CommonConfig()) + + from fairseq.registry import REGISTRIES + + for registry_name, REGISTRY in REGISTRIES.items(): + parser.add_argument( + "--" + registry_name.replace("_", "-"), + default=REGISTRY["default"], + choices=REGISTRY["registry"].keys(), + ) + + # Task definitions can be found under fairseq/tasks/ + from fairseq.tasks import TASK_REGISTRY + + parser.add_argument( + "--task", + metavar="TASK", + default=default_task, + choices=TASK_REGISTRY.keys(), + help="task", + ) + # fmt: on + return parser + + +def add_preprocess_args(parser): + group = parser.add_argument_group("Preprocessing") + # fmt: off + group.add_argument("-s", "--source-lang", default=None, metavar="SRC", + help="source language") + group.add_argument("-t", "--target-lang", default=None, metavar="TARGET", + help="target language") + group.add_argument("--trainpref", metavar="FP", default=None, + help="train file prefix (also used to build dictionaries)") + group.add_argument("--validpref", metavar="FP", default=None, + help="comma separated, valid file prefixes " + "(words missing from train set are replaced with )") + group.add_argument("--testpref", metavar="FP", default=None, + help="comma separated, test file prefixes " + "(words missing from train set are replaced with )") + group.add_argument("--align-suffix", metavar="FP", default=None, + help="alignment file suffix") + group.add_argument("--destdir", metavar="DIR", default="data-bin", + help="destination dir") + group.add_argument("--thresholdtgt", metavar="N", default=0, type=int, + help="map words appearing less than threshold times to unknown") + group.add_argument("--thresholdsrc", metavar="N", default=0, type=int, + help="map words appearing less than threshold times to unknown") + group.add_argument("--tgtdict", 
metavar="FP", + help="reuse given target dictionary") + group.add_argument("--srcdict", metavar="FP", + help="reuse given source dictionary") + group.add_argument("--nwordstgt", metavar="N", default=-1, type=int, + help="number of target words to retain") + group.add_argument("--nwordssrc", metavar="N", default=-1, type=int, + help="number of source words to retain") + group.add_argument("--alignfile", metavar="ALIGN", default=None, + help="an alignment file (optional)") + parser.add_argument('--dataset-impl', metavar='FORMAT', default='mmap', + choices=get_available_dataset_impl(), + help='output dataset implementation') + group.add_argument("--joined-dictionary", action="store_true", + help="Generate joined dictionary") + group.add_argument("--only-source", action="store_true", + help="Only process the source language") + group.add_argument("--padding-factor", metavar="N", default=8, type=int, + help="Pad dictionary size to be multiple of N") + group.add_argument("--workers", metavar="N", default=1, type=int, + help="number of parallel workers") + group.add_argument("--dict-only", action='store_true', + help="if true, only builds a dictionary and then exits") + # fmt: on + return parser + + +def add_dataset_args(parser, train=False, gen=False): + group = parser.add_argument_group("dataset_data_loading") + gen_parser_from_dataclass(group, DatasetConfig()) + # fmt: on + return group + + +def add_distributed_training_args(parser, default_world_size=None): + group = parser.add_argument_group("distributed_training") + if default_world_size is None: + default_world_size = max(1, torch.cuda.device_count()) + gen_parser_from_dataclass( + group, DistributedTrainingConfig(distributed_world_size=default_world_size) + ) + return group + + +def add_optimization_args(parser): + group = parser.add_argument_group("optimization") + # fmt: off + gen_parser_from_dataclass(group, OptimizationConfig()) + # fmt: on + return group + + +def add_checkpoint_args(parser): + group = parser.add_argument_group("checkpoint") + # fmt: off + gen_parser_from_dataclass(group, CheckpointConfig()) + # fmt: on + return group + + +def add_common_eval_args(group): + gen_parser_from_dataclass(group, CommonEvalConfig()) + + +def add_eval_lm_args(parser): + group = parser.add_argument_group("LM Evaluation") + add_common_eval_args(group) + gen_parser_from_dataclass(group, EvalLMConfig()) + + +def add_generation_args(parser): + group = parser.add_argument_group("Generation") + add_common_eval_args(group) + gen_parser_from_dataclass(group, GenerationConfig()) + return group + + +def add_speech_generation_args(parser): + group = parser.add_argument_group("Speech Generation") + add_common_eval_args(group) # NOTE: remove_bpe is not needed + # fmt: off + group.add_argument('--eos_prob_threshold', default=0.5, type=float, + help='terminate when eos probability exceeds this') + # fmt: on + return group + + +def add_interactive_args(parser): + group = parser.add_argument_group("Interactive") + gen_parser_from_dataclass(group, InteractiveConfig()) + + +def add_model_args(parser): + group = parser.add_argument_group("Model configuration") + # fmt: off + + # Model definitions can be found under fairseq/models/ + # + # The model architecture can be specified in several ways. 
+ # In increasing order of priority: + # 1) model defaults (lowest priority) + # 2) --arch argument + # 3) --encoder/decoder-* arguments (highest priority) + from fairseq.models import ARCH_MODEL_REGISTRY + group.add_argument('--arch', '-a', metavar='ARCH', + choices=ARCH_MODEL_REGISTRY.keys(), + help='model architecture') + # fmt: on + return group + + +def get_args( + data: Union[str, Path], + task: str = "translation", + arch: str = "transformer", + **overrides +): + parser = get_training_parser(task) + args = parse_args_and_arch(parser, [str(data), "--task", task, "--arch", arch]) + + for k, v in overrides.items(): + setattr(args, k, v) + + return args + + +def add_ema_args(parser): + group = parser.add_argument_group("EMA configuration") + gen_parser_from_dataclass(group, EMAConfig()) diff --git a/fairseq/fairseq/pdb.py b/fairseq/fairseq/pdb.py new file mode 100644 index 0000000000000000000000000000000000000000..1ba6ef0d336b30717cfdde94e1b838cfe2bfeb20 --- /dev/null +++ b/fairseq/fairseq/pdb.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import multiprocessing +import os +import pdb +import sys + + +__all__ = ["set_trace"] + + +_stdin = [None] +_stdin_lock = multiprocessing.Lock() +try: + _stdin_fd = sys.stdin.fileno() +except Exception: + _stdin_fd = None + + +class MultiprocessingPdb(pdb.Pdb): + """A Pdb wrapper that works in a multiprocessing environment. + + Usage: `from fairseq import pdb; pdb.set_trace()` + """ + + def __init__(self): + pdb.Pdb.__init__(self, nosigint=True) + + def _cmdloop(self): + stdin_bak = sys.stdin + with _stdin_lock: + try: + if _stdin_fd is not None: + if not _stdin[0]: + _stdin[0] = os.fdopen(_stdin_fd) + sys.stdin = _stdin[0] + self.cmdloop() + finally: + sys.stdin = stdin_bak + + +def set_trace(): + pdb = MultiprocessingPdb() + pdb.set_trace(sys._getframe().f_back) diff --git a/fairseq/fairseq/token_generation_constraints.py b/fairseq/fairseq/token_generation_constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..e708dc51bcb0ffb7b411496239c74d5e6f3c2448 --- /dev/null +++ b/fairseq/fairseq/token_generation_constraints.py @@ -0,0 +1,506 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Implements tracking of constraints for a beam item. + +A list of constraints is given as a list of one or more token +sequences, each of length at least one token. For example, for an input sentence + +> Die maschinelle Übersetzung ist schwer zu kontrollieren. + +We could have the constraints: +* to influence +* hard + +There are two implementations: +* OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. +* UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. + +The difference is that in the first, the constraints are assumed to be +in order; the algorithm will permit zero or more tokens between them. +In the second, the constraints are not ordered, so many orderings will +be explored. + +The same sequence can be present any number of times, and will appear +that many times in the output. 
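+
+Constraints for a batch are given as one list of token-id tensors per input
+sentence and can be packed into a single tensor with pack_constraints (defined
+below), e.g. (an illustrative sketch with made-up token ids):
+
+    constraints = [
+        [torch.tensor([3, 1, 2]), torch.tensor([3])],  # two constraints
+        [],                                            # no constraints
+    ]
+    packed = pack_constraints(constraints)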
+""" + +from collections import Counter +from typing import List, Optional, Set, Tuple + +import torch + + +class ConstraintState: + def __init__(self): + pass + + +def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor: + """Takes a list of list of constraints in tensor form (a list of + tensor constraints for each sentence) and transforms it into a + packed Tensor. For example, here is a batch of size 3 with 3, 0, + and 1 constraints: + + [ [ [3 1 2], [3], [4 5 6 7], ] + [], + [ [1 8 9 10 1 4 11 12], ] + ] + + Its corresponding packed structure is: + + [ [ 3 3 1 2 0 3 0 4 5 6 7 0], + [ 0 0 0 0 0 0 0 0 0 0 0 0], + [ 1 1 8 9 10 1 4 11 12 0 0 0] ] + + The packed tensor has shape (batch size, maxlen), where + maxlen is defined below. Each row contains concatenated + constraint tokens for that sentence, with 0 appended after + each constraint. The first item in each row is the number + of constraints for that sentence. So maxlen is the maximum + of + + (number of constraints) + (sum length of constraints) + 1. + + across all sentences in the batch. + """ + # The maximum word length of concatenated constraints for any sentence + max_constraints_len = 1 + for sentence_constraints in batch_constraints: + if len(sentence_constraints): + # number of constraints, plus sum of constrain lens, plus a zero after each + constraints_len = ( + 1 + + sum([c.size(0) for c in sentence_constraints]) + + len(sentence_constraints) + ) + max_constraints_len = max(max_constraints_len, constraints_len) + + batch_size = len(batch_constraints) + constraints_tensor = torch.zeros((batch_size, max_constraints_len)).long() + for i, sentence_constraints in enumerate(batch_constraints): + constraints_tensor[i, 0] = len(sentence_constraints) + offset = 1 + for j, constraint in enumerate(sentence_constraints): + this_len = constraint.size(0) + constraints_tensor[i, offset : offset + this_len] = constraint + offset += this_len + 1 + + return constraints_tensor.long() + + +def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]: + """ + Transforms *one row* of a packed constraint tensor (e.g., for one + sentence in the batch) into a list of constraint tensors. + """ + constraint_list = [] + num_constraints = constraint_tensor[0] + constraints = constraint_tensor.tolist() + offset = 1 + for i in range(num_constraints): + where = constraints.index(0, offset) + constraint_list.append(constraint_tensor[offset:where]) + offset = where + 1 + + return constraint_list + + +class ConstraintNode: + """ + Represents a node in a trie managing unordered constraints. 
+ """ + + def __init__(self, token: int = None, parent=None): + # The token associate with this node (None for the root) + self.token = int(token) if token is not None else None + # The parent (None at the root) + self.parent = parent + # Whether this node is a completed constraint + self.terminal = 0 + # List of child nodes + self.children = {} + + # The cumulative number of constraints from this point in the + # trie forward + self.num_constraints = 0 + + @property + def id(self): + return self.token + + def __str__(self): + term = self.terminal != 0 + return f"[{self.token}].{term}#{self.num_constraints}" + + def __getitem__(self, key: int): + return self.children.get(key, None) + + def next_tokens(self) -> Set[int]: + """The set of child labels.""" + return set(self.children.keys()) + + @staticmethod + def create(constraints: List[List[int]]): + root = ConstraintNode() + for sequence in constraints: + root.add_sequence(sequence) + + return root + + @staticmethod + def print_graph(node: "ConstraintNode"): + if len(node.children) == 0: + return str(node) + else: + s = f"({node}" + for child in node.children.values(): + s += " " + ConstraintNode.print_graph(child) + s += ")" + return s + + def token_counts(self) -> Counter: + """Returns a counter of the number of times each token is used + in a constraint. + """ + token_counts = Counter() + kids = list(self.children.values()) + while len(kids) > 0: + kid = kids.pop() + token_counts[kid.id] += kid.num_constraints + kids += list(kid.children.values()) + + return token_counts + + def tokens(self) -> Set[int]: + """Returns the set of tokens in constraints.""" + return set(self.token_counts().keys()) + + def add_sequence(self, sequence: List[int]): + """Adds a constraint, represented as a list of integers, to + the trie.""" + assert len(sequence) > 0 + + token = int(sequence[0]) + if token not in self.children: + self.children[token] = ConstraintNode(token, parent=self) + + node = self.children[token] + if len(sequence) == 1: + node.terminal += 1 + node.num_constraints += 1 + parent = node.parent + while parent is not None: + parent.num_constraints += 1 + parent = parent.parent + else: + node.add_sequence(sequence[1:]) + + +class UnorderedConstraintState(ConstraintState): + """ + Records progress through the set of constraints for each item in the beam + using a trie. + """ + + def __init__(self, node: ConstraintNode, copy_from: "ConstraintState" = None): + self.node = node + + if copy_from is None: + # The root node + self.root = node + # The set of states in the graph that have been completed + self.completed = Counter() + # The... 
+ self.generated = Counter() + # The list of tokens we need to generate + self.needed_tokens = self.root.tokens() + else: + self.completed = Counter(copy_from.completed) + self.generated = Counter(copy_from.generated) + self.root = copy_from.root + + # Mark the node as generated + if self.node != self.root: + self.generated[node] += 1 + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + constraint_trie_root = ConstraintNode.create(constraint_list) + return UnorderedConstraintState(constraint_trie_root) + + def __str__(self): + gen_str = ",".join([str(node) for node in self.generated]) + return f"{self.name}/{self.bank}({gen_str})x{self.num_completed}" + + def __copy__(self): + copied_state = UnorderedConstraintState(self.node, copy_from=self) + return copied_state + + def copy(self): + return self.__copy__() + + @property + def name(self): + if self.node.id is None: + return "ROOT" + else: + return str(self.node.id) + + @property + def is_root(self): + return self.node == self.root + + @property + def bank(self): + return sum(self.generated.values()) + + @property + def num_completed(self): + """The number of constraints (not constraint tokens) that are completed. + In addition to the already-completed states, we need to account for the + current state, which might get marked as completed when another token + is generated. + """ + in_final = self.node.terminal and self.completed[self.node] < self.node.terminal + return sum(self.completed.values()) + in_final + + @property + def finished(self): + return self.root.num_constraints - self.num_completed == 0 + + @property + def token_counts(self): + return self.root.token_counts() + + @property + def tokens(self): + return self.root.tokens() + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + if self.node != self.root: + return self.root.next_tokens().union(self.node.next_tokens()) + else: + return self.root.next_tokens() + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + + next_state = None + child = self.node[token] + if child is not None and self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + + def rewind(): + """If we're mid-trie and an "illegal" token is chosen next, we need + to reset our state to the root state. However, along the way, we need + to check whether a prefix of the current trie state represents a state + we could mark as completed. 
+ """ + node = self.node + while node != self.root: + if node.terminal and self.completed[node] < node.terminal: + next_state.completed[node] += 1 + return + + next_state.generated[node] -= 1 + node = node.parent + + # Fall off the graph, check the root + if next_state is None and token in self.root.next_tokens(): + child = self.root[token] + # We can only traverse this edge if it's not saturated + if self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + else: + next_state = UnorderedConstraintState(self.root, copy_from=self) + + # Rewind + rewind() + + elif next_state is None: + next_state = UnorderedConstraintState(self.root, copy_from=self) + # Rewind + rewind() + + return next_state + + +class ConstraintSequence: + def __init__(self, sequences: List[List[int]]): + """Represents a set of possibly multitoken constraints by + concatenating them and internally recording the end points. + """ + self.sequences = [] + self.endpoints = [] + self.num_tokens = 0 + self.tokens = set() + for sequence in sequences: + for token in sequence: + self.tokens.add(token) + self.num_tokens += len(sequence) + self.endpoints += [False for x in range(len(sequence) - 1)] + [True] + self.sequences += sequence + + def __getitem__(self, key: int): + return self.sequences[key] + + def __len__(self): + return len(self.sequences) + + def __str__(self): + return str(self.sequences) + + +class OrderedConstraintState(ConstraintState): + """ + Records progress through the set of linear nonbranching constraints with gaps. + """ + + def __init__(self, sequence: ConstraintSequence, state: int = -1): + self.sequence = sequence + self.state = state + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + return OrderedConstraintState(ConstraintSequence(constraint_list), -1) + + def __str__(self): + return f"{self.state}/{self.bank}x{self.num_completed}" + + def __copy__(self): + return OrderedConstraintState(self.sequence, self.state) + + def copy(self): + return self.__copy__() + + @property + def num_completed(self): + if self.state == -1: + return 0 + count = len( + list(filter(lambda x: x, self.sequence.endpoints[0 : self.state + 1])) + ) + return count + + @property + def is_root(self): + return self.state == -1 + + @property + def name(self): + if self.state == -1: + return "ROOT" + else: + return str(self.sequence[self.state]) + + @property + def bank(self) -> int: + return self.state + 1 + + @property + def finished(self): + return self.state + 1 == len(self.sequence) + + @property + def token_counts(self): + return self.sequence.token_counts() + + @property + def tokens(self): + return self.sequence.tokens + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + tokens = set() + if self.state > 0: + tokens.add(self.sequence[0]) + if not self.finished: + tokens.add(self.sequence[self.state + 1]) + return tokens + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. 
+
+
+class OrderedConstraintState(ConstraintState):
+    """
+    Records progress through the set of linear nonbranching constraints with gaps.
+    """
+
+    def __init__(self, sequence: ConstraintSequence, state: int = -1):
+        self.sequence = sequence
+        self.state = state
+
+    @staticmethod
+    def create(constraint_tensor: torch.Tensor):
+        constraint_list = unpack_constraints(constraint_tensor)
+        return OrderedConstraintState(ConstraintSequence(constraint_list), -1)
+
+    def __str__(self):
+        return f"{self.state}/{self.bank}x{self.num_completed}"
+
+    def __copy__(self):
+        return OrderedConstraintState(self.sequence, self.state)
+
+    def copy(self):
+        return self.__copy__()
+
+    @property
+    def num_completed(self):
+        if self.state == -1:
+            return 0
+        count = len(
+            list(filter(lambda x: x, self.sequence.endpoints[0 : self.state + 1]))
+        )
+        return count
+
+    @property
+    def is_root(self):
+        return self.state == -1
+
+    @property
+    def name(self):
+        if self.state == -1:
+            return "ROOT"
+        else:
+            return str(self.sequence[self.state])
+
+    @property
+    def bank(self) -> int:
+        return self.state + 1
+
+    @property
+    def finished(self):
+        return self.state + 1 == len(self.sequence)
+
+    @property
+    def token_counts(self):
+        return self.sequence.token_counts()
+
+    @property
+    def tokens(self):
+        return self.sequence.tokens
+
+    @property
+    def num_constraint_tokens(self):
+        return sum(self.token_counts.values())
+
+    def next_tokens(self) -> Set[int]:
+        """Returns the set of tokens that could come next.
+        These are (a) all tokens extending the root state and, for
+        non-root states, additionally all tokens extending the current
+        state."""
+
+        tokens = set()
+        if self.state > 0:
+            tokens.add(self.sequence[0])
+        if not self.finished:
+            tokens.add(self.sequence[self.state + 1])
+        return tokens
+
+    def advance(self, token: int):
+        """Reads in a token and advances the state.
+
+        Unlike the trie-based state above, progress here is a single index
+        into the flattened constraint sequence:
+        - if all constraints are finished, any token is accepted;
+        - if the token matches the next position in the sequence, advance
+          by one position;
+        - otherwise, if the current position is the root or the end of a
+          constraint, the token is accepted without losing progress;
+        - otherwise progress is lost: restart at position 0 if the token
+          equals the very first constraint token, else return to the root.
+        """
+        token = int(token)
+        # print(f"{self} ADVANCE({token}) {self.sequence} -> ", end="")
+
+        if self.finished:
+            # Accept anything
+            next_state = self.copy()
+
+        elif self.sequence[self.state + 1] == token:
+            # Advance to the next token
+            next_state = OrderedConstraintState(self.sequence, self.state + 1)
+
+        elif self.sequence.endpoints[self.state]:
+            # Accept anything between constraints (*)
+            next_state = self.copy()
+
+        elif token == self.sequence[0]:
+            # Start over having generated the first token
+            next_state = OrderedConstraintState(self.sequence, 0)
+        else:
+            # Start over from the root
+            next_state = OrderedConstraintState(self.sequence, -1)
+
+        return next_state
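# A short sketch of the ordered state's behaviour, with made-up token IDs and
# the classes above assumed to be in scope:
seq = ConstraintSequence([[3, 4], [5]])
state = OrderedConstraintState(seq)       # starts at the root (state == -1)
print(sorted(state.next_tokens()))        # [3]
state = state.advance(3).advance(4)       # consumes the first constraint
print(state.bank, state.num_completed)    # 2 1
state = state.advance(9)                  # unrelated token at an endpoint: progress is kept
print(state.bank)                         # 2
state = state.advance(5)                  # completes the second constraint
print(state.finished)                     # True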