might be an error #15

@supertx

Description

```python
# it is possible that two elements of the noise are the same, so do a while loop to avoid it
while True:
    noise = torch.rand(bsz, seq_len, device=rep.device)  # noise in [0, 1]
    sorted_noise, _ = torch.sort(noise, dim=1)  # ascend: small is remove, large is keep
    cutoff_drop = sorted_noise[:, num_dropped_tokens - 1:num_dropped_tokens]
    cutoff_mask = sorted_noise[:, num_masked_tokens - 1:num_masked_tokens]
    token_drop_mask = (noise <= cutoff_drop).float()
    token_all_mask = (noise <= cutoff_mask).float()
    if token_drop_mask.sum() == bsz * num_dropped_tokens and \
            token_all_mask.sum() == bsz * num_masked_tokens:
        break
    else:
        print("Rerandom the noise!")
```

The code you wrote in the frame encoder to generate the mask and drop indices seems to have unavoidable overlap: since `cutoff_drop <= cutoff_mask` whenever `num_dropped_tokens <= num_masked_tokens`, `token_drop_mask` ends up a subset of `token_all_mask`, which appears inconsistent with your comment above. Is this a mistake or an intentional experimental setting?
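For reference, the subset relation can be checked with a small standalone sketch, a plain-Python stand-in for the torch logic above for a single sequence (`seq_len`, `num_dropped`, and `num_masked` here are hypothetical sizes, not values from the repo):

```python
import random

def make_masks(seq_len, num_dropped, num_masked, rng):
    # One sequence's noise values, mirroring torch.rand(bsz, seq_len)
    noise = [rng.random() for _ in range(seq_len)]
    sorted_noise = sorted(noise)  # ascend: small is drop, large is keep
    # Same cutoff construction as sorted_noise[:, k-1:k] in the torch code
    cutoff_drop = sorted_noise[num_dropped - 1]
    cutoff_mask = sorted_noise[num_masked - 1]
    drop_mask = [n <= cutoff_drop for n in noise]
    all_mask = [n <= cutoff_mask for n in noise]
    return drop_mask, all_mask

rng = random.Random(0)
drop_mask, all_mask = make_masks(seq_len=16, num_dropped=4, num_masked=12, rng=rng)
# Whenever num_dropped <= num_masked, cutoff_drop <= cutoff_mask,
# so every dropped token is also masked.
assert all(a for d, a in zip(drop_mask, all_mask) if d)
```

Because `cutoff_drop` is taken at an earlier position of the same sorted noise than `cutoff_mask`, any token below the drop cutoff is necessarily below the mask cutoff as well, so the overlap is structural rather than a rare random collision.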
