Hi, thank you for the excellent work and for releasing the model!
I am trying to reproduce training the small 75FPS model from scratch, and I am seeing an unexpectedly large number of dead codewords when running:
```
python train.py fit --config configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml
```
This seems to be caused by the ordering in `EuclideanCodebook.forward()`: `expire_codes_` is called to replace the dead codewords, but it does not recompute the assignment counts. So, afterwards in `forward`, the EMA update to `self.cluster_size` uses the old `embed_onehot.sum(0)`, which is 0 for the respawned codewords. Similarly, `self.embed_avg` is updated with the old `embed_sum = x.t() @ embed_onehot`, whose columns are zero vectors for the respawned codewords.
As a result, at the end of `forward()`, `self.embed.data.copy_(embed_normalized)` overwrites the codebook with statistics that do not take into account the codewords respawned earlier in the same call.
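For concreteness, here is the training branch as I read it in the EnCodec-style `core_vq.py` that this repo appears to derive from (lightly abridged, with my annotations; helper names such as `ema_inplace` and `laplace_smoothing` come from that file, so apologies if this repo's copy differs):

```python
# Abridged from the EnCodec-style EuclideanCodebook.forward(); comments mine.
embed_ind = self.quantize(x)                       # assignments vs. the OLD self.embed
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
quantize = self.dequantize(embed_ind)

if self.training:
    # (1) Dead codewords are respawned here, overwriting rows of self.embed ...
    self.expire_codes_(x)

    # (2) ... but the EMA statistics below still use the stale embed_onehot,
    # so a respawned codeword contributes a zero count and a zero column.
    ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
    embed_sum = x.t() @ embed_onehot
    ema_inplace(self.embed_avg, embed_sum.t(), self.decay)

    # (3) The final copy rebuilds self.embed from the stale running averages,
    # which largely undoes the respawn performed in step (1).
    cluster_size = (
        laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
        * self.cluster_size.sum()
    )
    embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
    self.embed.data.copy_(embed_normalized)
```

If this reading is right, a respawned codeword immediately gets its `cluster_size` and `embed_avg` decayed toward zero and is then overwritten by the normalization, which would explain the large number of dead codewords I observe.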
Is this the same logic that was used for the original training of WavTokenizer, or is this perhaps not the most up-to-date version of the code? Possibly I'm missing something; any help understanding how this should work would be appreciated :)
