Skip to content

Commit 2b35f84

Browse files
mgouicemdzarukin
authored andcommitted
api: extend dropout attribute with offset host_scalars and s64
1 parent a8eb12b commit 2b35f84

33 files changed

+485
-143
lines changed

doc/programming_model/attributes_dropout.md

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,51 @@ on a deterministic PRNG (current implementation uses a variation of Philox
1616
algorithm) and transforms the values as follows:
1717

1818
\f[
19-
\mathrm{mask}[:] = (\mathrm{PRNG}(S, ...) > P) \\
19+
\mathrm{mask}[:] = (\mathrm{PRNG}(S, \mathrm{off}, :) > P) \\
2020
\mathrm{dst}[:] = \mathrm{mask}[:] \cdot {{\mathrm{dst}[:]} \over {1 - P}}
2121
\f]
2222

2323
where:
2424

25-
* \f$\mathrm{mask}\f$ is the output buffer (always of the same dimensions and
26-
usually of the same layout as \f$\mathrm{dst}\f$, but potentially differing from
27-
it in type that can only be `u8`) whose values may be either 0 if the
28-
corresponding value in \f$\mathrm{dst}\f$ got zeroed (a.k.a. dropped out) or 1
29-
otherwise
30-
* \f$S\f$ is the integer seed for the PRNG algorithm
25+
* \f$\mathrm{mask}\f$ values may be either 0 if the corresponding value in
26+
\f$\mathrm{dst}\f$ got zeroed (a.k.a. dropped out) or 1, otherwise.
27+
* \f$S, off\f$ are the seed and the offset for the PRNG algorithm.
3128
* \f$P\f$ is the probability for any given value to get dropped out,
32-
\f$0 \leq P \leq 1\f$
29+
\f$0 \leq P \leq 1\f$
3330

3431
## API
3532

3633
- C: @ref dnnl_primitive_attr_get_dropout, @ref dnnl_primitive_attr_set_dropout
3734
- C++: @ref dnnl::primitive_attr::get_dropout, @ref dnnl::primitive_attr::set_dropout
3835

39-
If the dropout operation gets specified in the primitive's attributes, the user
40-
must provide three additional buffers to it on execution:
36+
The dropout primitive attribute has the following parameters:
4137

42-
* `DNNL_ARG_ATTR_DROPOUT_MASK`: through this ID the user has to pass the
43-
\f$\mathrm{mask}\f$ output buffer
44-
* `DNNL_ARG_ATTR_DROPOUT_PROBABILITY`: this is a single-value `f32` input buffer
45-
that holds \f$P\f$
46-
* `DNNL_ARG_ATTR_DROPOUT_SEED`: this is a single-value `s32` input buffer that
47-
holds \f$S\f$
38+
* `mask_desc`: when set to a zero (or empty) memory descriptor, mask values are
39+
not written to the memory. Otherwise, it should have the same dimensions and
40+
the same layout as \f$\mathrm{dst}\f$, as well as `u8` data type.
41+
* `seed_dt`: data type of the seed argument \f$S\f$, `s64` is recommended, `s32`
42+
is supported as a backward compatibility option.
43+
* `use_offset`: boolean to express if an offset argument will be provided
44+
by the user at the execution time. When false, an offset of 0 is assumed.
45+
* `use_host_scalars`: boolean specifying if probability, seed, and offset memory
46+
arguments will be passed as host_scalar memory objects when `true`, or
47+
as device memory objects, otherwise.
48+
49+
When the dropout primitive attribute is set, the user must provide two
50+
additional memory arguments to the primitive execution:
51+
52+
* `DNNL_ARG_ATTR_DROPOUT_PROBABILITY`: this is a single-value `f32` input memory
53+
argument that holds \f$P\f$.
54+
* `DNNL_ARG_ATTR_DROPOUT_SEED`: this is a single-value input memory argument
55+
that holds \f$S\f$. Its data type is specified by the `seed_dt` primitive
56+
attribute parameter and can be either `s32` or `s64`.
57+
58+
Additionally, the following arguments conditionally need to be passed
59+
at the execution time as well:
60+
61+
* `DNNL_ARG_ATTR_DROPOUT_MASK`: if the `mask_desc` primitive attribute parameter
62+
is not a zero memory descriptor, the user must pass the \f$\mathrm{mask}\f$
63+
through this output memory argument.
64+
* `DNNL_ARG_ATTR_DROPOUT_OFFSET`: if the `use_offset` primitive attribute
65+
parameter is set, the user must pass the \f$\mathrm{off}\f$ through this
66+
input memory argument. This is a single-value `s64` memory argument.

include/oneapi/dnnl/dnnl.h

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -271,24 +271,61 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_clone(
271271
/// otherwise.
272272
dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr);
273273

274-
/// Returns probability for output dropout primitive attribute.
274+
/// Gets dropout primitive attribute.
275275
///
276276
/// @param attr Primitive attributes.
277-
/// @param dropout_desc Output dropout memory descriptor
277+
/// @param mask_desc Output memory descriptor for dropout masks. If a default
278+
/// memory descriptor is returned, the mask values will not be written to
279+
/// the output memory buffer during the primitive execution.
278280
/// @returns #dnnl_success on success and a status describing the error
279281
/// otherwise.
280282
dnnl_status_t DNNL_API dnnl_primitive_attr_get_dropout(
281-
const_dnnl_primitive_attr_t attr,
282-
const_dnnl_memory_desc_t *dropout_desc);
283+
const_dnnl_primitive_attr_t attr, const_dnnl_memory_desc_t *mask_desc);
283284

284-
/// Sets probability for output dropout primitive attribute.
285+
/// Sets dropout primitive attribute.
285286
///
286287
/// @param attr Primitive attributes.
287-
/// @param dropout_desc Output dropout memory descriptor
288+
/// @param mask_desc Memory descriptor for dropout masks. If a default memory
289+
/// descriptor is passed, the mask values will not be written to the output
290+
/// memory buffer during the primitive execution.
288291
/// @returns #dnnl_success on success and a status describing the error
289292
/// otherwise.
290293
dnnl_status_t DNNL_API dnnl_primitive_attr_set_dropout(
291-
dnnl_primitive_attr_t attr, const_dnnl_memory_desc_t dropout_desc);
294+
dnnl_primitive_attr_t attr, const_dnnl_memory_desc_t mask_desc);
295+
296+
/// Gets dropout primitive attribute parameters.
297+
///
298+
/// @param attr Primitive attributes.
299+
/// @param mask_desc Output memory descriptor for dropout masks. If a default
300+
/// memory descriptor is returned, the mask values will not be written to
301+
/// the output memory buffer during the primitive execution.
302+
/// @param seed_dt Output datatype for seed argument.
303+
/// @param use_offset Output boolean. If true, an offset argument must be passed
304+
/// at the execution and will be used in random number generation.
305+
/// @param use_host_scalars Output boolean. If true, probability, seed and
306+
/// offset arguments are passed as host_scalar memory objects.
307+
/// @returns #dnnl_success on success and a status describing the error
308+
/// otherwise.
309+
dnnl_status_t DNNL_API dnnl_primitive_attr_get_dropout_v2(
310+
const_dnnl_primitive_attr_t attr, const_dnnl_memory_desc_t *mask_desc,
311+
dnnl_data_type_t *seed_dt, int *use_offset, int *use_host_scalars);
312+
313+
/// Sets dropout primitive attribute parameters.
314+
///
315+
/// @param attr Primitive attributes.
316+
/// @param mask_desc Memory descriptor for dropout masks. If a default memory
317+
/// descriptor is passed, the mask values will not be written to the output
318+
/// memory buffer during the primitive execution.
319+
/// @param seed_dt Datatype for seed argument.
320+
/// @param use_offset If true, an offset argument must be passed at the
321+
/// execution and will be used in random number generation.
322+
/// @param use_host_scalars If true, probability, seed and offset arguments are
323+
/// passed as host_scalar memory objects.
324+
/// @returns #dnnl_success on success and a status describing the error
325+
/// otherwise.
326+
dnnl_status_t DNNL_API dnnl_primitive_attr_set_dropout_v2(
327+
dnnl_primitive_attr_t attr, const_dnnl_memory_desc_t mask_desc,
328+
dnnl_data_type_t seed_dt, int use_offset, int use_host_scalars);
292329

293330
/// Returns the floating-point math mode primitive attribute.
294331
///

include/oneapi/dnnl/dnnl.hpp

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4061,7 +4061,9 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
40614061

40624062
/// Returns the parameters of a dropout attribute.
40634063
///
4064-
/// @param mask_desc Output memory descriptor of a dropout mask.
4064+
/// @param mask_desc Output memory descriptor for dropout masks. If a
4065+
/// default memory descriptor is returned, the mask values will not be
4066+
/// written to the output memory buffer during the primitive execution.
40654067
void get_dropout(memory::desc &mask_desc) const {
40664068
const_dnnl_memory_desc_t cdesc;
40674069
error::wrap_c_api(dnnl_primitive_attr_get_dropout(get(), &cdesc),
@@ -4072,15 +4074,66 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
40724074
mask_desc = memory::desc(cloned_md);
40734075
}
40744076

4077+
/// Returns the parameters of a dropout attribute.
4078+
///
4079+
/// @param mask_desc Output memory descriptor for dropout masks. If a
4080+
/// default memory descriptor is returned, the mask values will not be
4081+
/// written to the output memory buffer during the primitive execution.
4082+
/// @param seed_dt Output datatype for seed argument.
4083+
/// @param use_offset Output boolean. If true, an offset argument must be
4084+
/// passed at the execution and will be used in random number
4085+
/// generation.
4086+
/// @param use_host_scalars Output boolean. If true, probability, seed and
4087+
/// offset arguments are passed as host_scalar memory objects.
4088+
void get_dropout(memory::desc &mask_desc, memory::data_type &seed_dt,
4089+
bool &use_offset, bool &use_host_scalars) const {
4090+
const_dnnl_memory_desc_t cdesc;
4091+
dnnl_data_type_t c_seed_dt;
4092+
int c_use_offset;
4093+
int c_use_host_scalars;
4094+
error::wrap_c_api(
4095+
dnnl_primitive_attr_get_dropout_v2(get(), &cdesc, &c_seed_dt,
4096+
&c_use_offset, &c_use_host_scalars),
4097+
"could not get parameters of a dropout attribute");
4098+
dnnl_memory_desc_t cloned_md = nullptr;
4099+
error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
4100+
"could not clone a memory descriptor");
4101+
mask_desc = memory::desc(cloned_md);
4102+
seed_dt = static_cast<memory::data_type>(c_seed_dt);
4103+
use_offset = c_use_offset;
4104+
use_host_scalars = c_use_host_scalars;
4105+
}
4106+
40754107
/// Sets dropout probability.
40764108
///
4077-
/// @param mask_desc Output memory descriptor of a dropout mask.
4109+
/// @param mask_desc Memory descriptor for dropout masks. If a default
4110+
/// memory descriptor is passed, the mask values will not be written to
4111+
/// the output memory buffer during the primitive execution.
40784112
void set_dropout(const memory::desc &mask_desc) {
40794113
error::wrap_c_api(
40804114
dnnl_primitive_attr_set_dropout(get(), mask_desc.get()),
40814115
"could not set dropout primitive attribute");
40824116
}
40834117

4118+
/// Sets dropout probability.
4119+
///
4120+
/// @param mask_desc Memory descriptor for dropout masks. If a default
4121+
/// memory descriptor is passed, the mask values will not be written to
4122+
/// the output memory buffer during the primitive execution.
4123+
/// @param seed_dt Datatype for seed argument.
4124+
/// @param use_offset If true, an offset argument must be passed at the
4125+
/// execution and will be used in random number generation.
4126+
/// @param use_host_scalars If true, probability, seed and offset arguments
4127+
/// are passed as host_scalar memory objects.
4128+
void set_dropout(const memory::desc &mask_desc, memory::data_type seed_dt,
4129+
bool use_offset, bool use_host_scalars) {
4130+
error::wrap_c_api(
4131+
dnnl_primitive_attr_set_dropout_v2(get(), mask_desc.get(),
4132+
memory::convert_to_c(seed_dt), use_offset,
4133+
use_host_scalars),
4134+
"could not set dropout primitive attribute");
4135+
}
4136+
40844137
/// Returns the fpmath mode
40854138
fpmath_mode get_fpmath_mode() const {
40864139
dnnl_fpmath_mode_t result;

include/oneapi/dnnl/dnnl_types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2694,6 +2694,9 @@ typedef const struct dnnl_primitive *const_dnnl_primitive_t;
26942694
/// A special mnemonic for shift argument of normalization primitives.
26952695
#define DNNL_ARG_DIFF_SHIFT 256
26962696

2697+
/// Dropout offset value passed via a buffer
2698+
#define DNNL_ARG_ATTR_DROPOUT_OFFSET 507
2699+
26972700
/// Rounding mode seed for stochastic rounding
26982701
/// Single seed needed independently of how many arguments need stochastic rounding
26992702
#define DNNL_ARG_ATTR_ROUNDING_SEED 508

scripts/verbose_converter/src/benchdnn_generator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,13 @@ def dropout(self):
250250
result = "0.5:12345"
251251
if dropout.tag:
252252
result += f":{dropout.tag}"
253+
# Seed dt is always s64 in benchdnn and is not passed to driver
254+
if dropout.use_offset == "1":
255+
result += ":987654321"
256+
else:
257+
result += ":0"
258+
if dropout.use_host_scalars == "1":
259+
result += f":{dropout.use_host_scalars}"
253260
return f"--attr-dropout={result}"
254261

255262
deterministic = attribute_flag("deterministic")

scripts/verbose_converter/src/ir.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,14 @@ def __hash_str__(self):
158158
@dataclass(eq=False)
159159
class Dropout(Mapping):
160160
tag: Optional[str] = None
161+
seed_dt: Optional[str] = None
162+
use_offset: Optional[bool] = None
163+
use_host_scalars: Optional[bool] = None
161164

162165
def __str__(self):
163-
return self.tag or ""
166+
return ":".join(
167+
[self.tag, self.seed_dt, self.use_offset, self.use_host_scalars]
168+
)
164169

165170

166171
class FormattedMapping(Mapping, ABC):

scripts/verbose_converter/src/parse.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,9 @@ def parse_binary_post_op(spec, alg) -> ir.BinaryPostOp:
390390

391391
@staticmethod
392392
def parse_dropout(args: str) -> ir.Dropout:
393-
return ir.Dropout(tag=args if args else None)
393+
spec = ParseSpec(args)
394+
fields = args.split(":")
395+
return ir.Dropout(fields[0], fields[1], fields[2], fields[3])
394396

395397
@staticmethod
396398
def parse_per_argument(attr, name, parse):

src/common/math_utils.hpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -455,15 +455,20 @@ inline bool is_eltwise_ok(
455455
return eltwise_use_src || eltwise_use_dst;
456456
}
457457

458-
inline uint32_t philox4x32(uint32_t idx, uint32_t seed) {
458+
inline uint32_t philox4x32(uint64_t idx, uint64_t seed, uint64_t offset) {
459+
// This impl is aligned with PyTorch at
460+
// https://github.com/pytorch/pytorch/blob/09c950c/aten/src/ATen/core/PhiloxRNGEngine.h
461+
// - both offset and idx are used to fill ctr
462+
// - seed/offset/idx are uint64_t
463+
459464
// Note 1: This impl computes 4 different int32_t rand
460465
// values. Even though this is redundundant for sequential ref,
461466
// keeping vector version to guide optimized implementations.
462-
// Note 2: this can be used for 8x16 as well by changing indexing.
463467

464-
uint32_t x = (idx & ~3L);
465-
uint32_t ctr[4] = {x + 0, x + 1, x + 2, x + 3};
466-
uint32_t key[2] = {uint32_t(seed), uint32_t(seed)};
468+
uint64_t x = (idx & ~3L);
469+
uint32_t ctr[4] = {uint32_t(offset), uint32_t(offset >> 32), uint32_t(x),
470+
uint32_t(x >> 32)};
471+
uint32_t key[2] = {uint32_t(seed), uint32_t(seed >> 32)};
467472

468473
auto mulhilo32 = [&](uint32_t a, uint32_t b, uint32_t &hi, uint32_t &lo) {
469474
const uint64_t product = static_cast<uint64_t>(a) * b;
@@ -490,30 +495,25 @@ inline uint32_t philox4x32(uint32_t idx, uint32_t seed) {
490495
key[0] += PHILOX_W4x32_0;
491496
key[1] += PHILOX_W4x32_1;
492497
};
493-
494-
philox4x32round();
495-
philox4x32bumpkey();
496-
philox4x32round();
497-
philox4x32bumpkey();
498-
philox4x32round();
499-
philox4x32bumpkey();
500-
philox4x32round();
501-
philox4x32bumpkey();
502-
philox4x32round();
503-
philox4x32bumpkey();
504-
philox4x32round();
505-
philox4x32bumpkey();
506-
philox4x32round();
507-
philox4x32bumpkey();
508-
philox4x32round();
509-
philox4x32bumpkey();
510-
philox4x32round();
511-
philox4x32bumpkey();
498+
constexpr int nrounds = 10;
499+
for (int i = 0; i < (nrounds - 1); ++i) {
500+
philox4x32round();
501+
philox4x32bumpkey();
502+
}
512503
philox4x32round();
513504

514505
return ctr[idx & 3L];
515506
}
516507

508+
inline uint32_t philox4x32(uint32_t idx, uint32_t seed) {
509+
// Note: this is for compatibility with impls that don't support s64 rand
510+
uint64_t x = idx & ~3L;
511+
uint64_t idx_64 = ((x + 3) << 32) + (x + 2);
512+
uint64_t offset_64 = ((x + 1) << 32) + x;
513+
uint64_t seed_64 = (uint64_t(seed) << 32) + seed;
514+
return philox4x32(idx_64, seed_64, offset_64);
515+
}
516+
517517
inline uint16_t philox8x16(uint32_t idx, uint32_t seed) {
518518
// we split the index in two parts:
519519
// - 31 msb are used to generate 32 random bits

0 commit comments

Comments
 (0)