# Sample a random per-sample ordering of the seq_len tokens and derive two
# masks from it (1.0 = selected, 0.0 = not selected):
#   token_drop_mask -- the num_dropped_tokens lowest-noise positions per row
#   token_all_mask  -- the num_masked_tokens lowest-noise positions per row
# Selecting by rank (instead of "noise <= k-th smallest value") always yields
# exactly k ones per row even if two noise values tie, so no resampling loop
# is needed. It also handles a count of 0; the old cutoff slice
# sorted_noise[:, -1:0] was empty there and broke broadcasting.
# Because both masks threshold the same ranking, token_drop_mask is a subset
# of token_all_mask whenever num_dropped_tokens <= num_masked_tokens.
# NOTE(review): confirm that this nesting is the intended behavior.
noise = torch.rand(bsz, seq_len, device=rep.device)  # noise in [0, 1)
# ascending sort: small noise is removed, large noise is kept
sorted_noise, ids_sorted = torch.sort(noise, dim=1)
# ids_sorted is a permutation per row; its argsort is the inverse permutation,
# i.e. ranks[b, j] is the 0-based ascending position of token j in row b.
ranks = torch.argsort(ids_sorted, dim=1)
token_drop_mask = (ranks < num_dropped_tokens).float()
token_all_mask = (ranks < num_masked_tokens).float()
The code you wrote in the frame encoder to generate the mask and drop indices seems to produce unavoidable overlap: token_drop_mask appears to be a subset of token_all_mask, which seems inconsistent with your comment above. Is this a mistake, or an intentional experimental setting?