@@ -167,8 +167,10 @@ def __init__(
167167 stride ,
168168 padding ,
169169 frame_stride = 1 ,
170+ frame_padding = None ,
170171 frame_pooling_stride = 1 ,
171172 frame_pooling_kernel_size = 1 ,
173+ frame_pooling_padding = None ,
172174 pooling_kernel_size = 3 ,
173175 pooling_stride = 2 ,
174176 pooling_padding = 1 ,
@@ -188,16 +190,22 @@ def __init__(
188190
189191 n_filter_list_pairs = zip (n_filter_list [:- 1 ], n_filter_list [1 :])
190192
193+ if frame_padding is None :
194+ frame_padding = frame_kernel_size // 2
195+
196+ if frame_pooling_padding is None :
197+ frame_pooling_padding = frame_pooling_kernel_size // 2
198+
191199 self .conv_layers = nn .Sequential (
192200 * [nn .Sequential (
193201 nn .Conv3d (chan_in , chan_out ,
194202 kernel_size = (frame_kernel_size , kernel_size , kernel_size ),
195203 stride = (frame_stride , stride , stride ),
196- padding = (frame_kernel_size // 2 , padding , padding ), bias = conv_bias ),
204+ padding = (frame_padding , padding , padding ), bias = conv_bias ),
197205 nn .Identity () if not exists (activation ) else activation (),
198206 nn .MaxPool3d (kernel_size = (frame_pooling_kernel_size , pooling_kernel_size , pooling_kernel_size ),
199207 stride = (frame_pooling_stride , pooling_stride , pooling_stride ),
200- padding = (frame_pooling_kernel_size // 2 , pooling_padding , pooling_padding )) if max_pool else nn .Identity ()
208+ padding = (frame_pooling_padding , pooling_padding , pooling_padding )) if max_pool else nn .Identity ()
201209 )
202210 for chan_in , chan_out in n_filter_list_pairs
203211 ])
@@ -324,8 +332,10 @@ def __init__(
324332 n_conv_layers = 1 ,
325333 frame_stride = 1 ,
326334 frame_kernel_size = 3 ,
335+ frame_padding = None ,
327336 frame_pooling_kernel_size = 1 ,
328337 frame_pooling_stride = 1 ,
338+ frame_pooling_padding = None ,
329339 kernel_size = 7 ,
330340 stride = 2 ,
331341 padding = 3 ,
@@ -342,8 +352,10 @@ def __init__(
342352 n_output_channels = embedding_dim ,
343353 frame_stride = frame_stride ,
344354 frame_kernel_size = frame_kernel_size ,
355+ frame_padding = frame_padding ,
345356 frame_pooling_stride = frame_pooling_stride ,
346357 frame_pooling_kernel_size = frame_pooling_kernel_size ,
358+ frame_pooling_padding = frame_pooling_padding ,
347359 kernel_size = kernel_size ,
348360 stride = stride ,
349361 padding = padding ,
0 commit comments