Question about EuclideanCodebook expire_codes_ and EMA update behavior #89

@xyz-zy

Description

Hi, thank you for the excellent work and the released model!

I am trying to reproduce training the small 75FPS model from scratch, and I am seeing an unexpectedly large number of dead codewords when running:

python train.py fit --config configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml

This seems to be caused by the logic in EuclideanCodebook.forward(): expire_codes_ is called to replace the dead codewords, but it does not recompute the assignment counts. Afterwards in forward(), the EMA update to self.cluster_size still uses the old embed_onehot.sum(), which is 0 for the respawned codewords. Similarly, self.embed_avg is updated with the stale embed_sum = x.t() @ embed_onehot, whose columns are 0-vectors for the respawned codewords.

So, at the end of forward(), self.embed.data.copy_(embed_normalized) overwrites the codebook with values derived from these stale statistics, which does not take into account the codewords that were respawned earlier in the same function call.
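The order of operations described above can be sketched as follows. This is a simplified, self-contained reproduction of the pattern (NumPy instead of PyTorch, hypothetical function names), not the repository's exact code; the point is that the assignment statistics from step 1 are reused in step 3 even after step 2 has replaced dead codewords:

```python
import numpy as np

def ema_inplace(moving_avg, new, decay):
    # EMA update: avg <- decay * avg + (1 - decay) * new
    moving_avg *= decay
    moving_avg += (1.0 - decay) * new

def codebook_forward_step(embed, cluster_size, embed_avg, x,
                          decay=0.99, threshold=2.0, rng=None):
    """One simplified EMA k-means codebook step (hypothetical sketch)."""
    rng = rng or np.random.default_rng(0)
    codebook_size = embed.shape[0]

    # 1) Assign each input vector to its nearest codeword.
    dists = ((x[:, None, :] - embed[None, :, :]) ** 2).sum(-1)
    embed_ind = dists.argmin(1)
    embed_onehot = np.eye(codebook_size)[embed_ind]      # (N, K)

    # 2) expire_codes_: respawn under-used codes from batch samples.
    #    Note: embed_onehot from step 1 is NOT recomputed afterwards.
    expired = cluster_size < threshold
    if expired.any():
        samples = x[rng.integers(0, len(x), size=codebook_size)]
        embed[expired] = samples[expired]

    # 3) EMA updates still use the pre-respawn assignments, so a
    #    respawned code contributes a zero count and a zero embed_sum.
    ema_inplace(cluster_size, embed_onehot.sum(0), decay)
    embed_sum = x.T @ embed_onehot                       # (D, K)
    ema_inplace(embed_avg, embed_sum.T, decay)

    # 4) Normalize and overwrite embed, clobbering the respawned rows
    #    with values derived from the stale statistics.
    n = cluster_size.sum()
    smoothed = (cluster_size + 1e-5) / (n + codebook_size * 1e-5) * n
    embed[:] = embed_avg / smoothed[:, None]
    return embed_ind
```

Running this with a codeword whose cluster_size is already 0 shows that the sample written into its row in step 2 is immediately overwritten in step 4.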

Is this the same logic that was used in the original training of WavTokenizer, or is this perhaps not the most up-to-date version of the code? Possibly I'm missing something; any help understanding how this is supposed to work would be appreciated :)
