@@ -4061,7 +4061,9 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
40614061
40624062 /// Returns the parameters of a dropout attribute.
40634063 ///
4064- /// @param mask_desc Output memory descriptor of a dropout mask.
4064+ /// @param mask_desc Output memory descriptor for dropout masks. If a
4065+ /// default memory descriptor is returned, the mask values will not be
4066+ /// written to the output memory buffer during the primitive execution.
40654067 void get_dropout(memory::desc &mask_desc) const {
40664068 const_dnnl_memory_desc_t cdesc;
40674069 error::wrap_c_api(dnnl_primitive_attr_get_dropout(get(), &cdesc),
@@ -4072,15 +4074,66 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
40724074 mask_desc = memory::desc(cloned_md);
40734075 }
40744076
4077+ /// Returns the parameters of a dropout attribute.
4078+ ///
4079+ /// @param mask_desc Output memory descriptor for dropout masks. If a
4080+ /// default memory descriptor is returned, the mask values will not be
4081+ /// written to the output memory buffer during the primitive execution.
4082+ /// @param seed_dt Output datatype for seed argument.
4083+ /// @param use_offset Output boolean. If true, an offset argument must be
4084+ /// passed at the execution and will be used in random number
4085+ /// generation.
4086+ /// @param use_host_scalars Output boolean. If true, probability, seed and
4087+ /// offset arguments are passed as host_scalar memory objects.
4088+ void get_dropout(memory::desc &mask_desc, memory::data_type &seed_dt,
4089+ bool &use_offset, bool &use_host_scalars) const {
4090+ const_dnnl_memory_desc_t cdesc;
4091+ dnnl_data_type_t c_seed_dt;
4092+ int c_use_offset;
4093+ int c_use_host_scalars;
4094+ error::wrap_c_api(
4095+ dnnl_primitive_attr_get_dropout_v2(get(), &cdesc, &c_seed_dt,
4096+ &c_use_offset, &c_use_host_scalars),
4097+ "could not get parameters of a dropout attribute");
4098+ dnnl_memory_desc_t cloned_md = nullptr;
4099+ error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
4100+ "could not clone a memory descriptor");
4101+ mask_desc = memory::desc(cloned_md);
4102+ seed_dt = static_cast<memory::data_type>(c_seed_dt);
4103+ use_offset = c_use_offset;
4104+ use_host_scalars = c_use_host_scalars;
4105+ }
4106+
40754107 /// Sets dropout probability.
40764108 ///
4077- /// @param mask_desc Output memory descriptor of a dropout mask.
4109+ /// @param mask_desc Memory descriptor for dropout masks. If a default
4110+ /// memory descriptor is passed, the mask values will not be written to
4111+ /// the output memory buffer during the primitive execution.
40784112 void set_dropout(const memory::desc &mask_desc) {
40794113 error::wrap_c_api(
40804114 dnnl_primitive_attr_set_dropout(get(), mask_desc.get()),
40814115 "could not set dropout primitive attribute");
40824116 }
40834117
4118+ /// Sets dropout probability.
4119+ ///
4120+ /// @param mask_desc Memory descriptor for dropout masks. If a default
4121+ /// memory descriptor is passed, the mask values will not be written to
4122+ /// the output memory buffer during the primitive execution.
4123+ /// @param seed_dt Datatype for seed argument.
4124+ /// @param use_offset If true, an offset argument must be passed at the
4125+ /// execution and will be used in random number generation.
4126+ /// @param use_host_scalars If true, probability, seed and offset arguments
4127+ /// are passed as host_scalar memory objects.
4128+ void set_dropout(const memory::desc &mask_desc, memory::data_type seed_dt,
4129+ bool use_offset, bool use_host_scalars) {
4130+ error::wrap_c_api(
4131+ dnnl_primitive_attr_set_dropout_v2(get(), mask_desc.get(),
4132+ memory::convert_to_c(seed_dt), use_offset,
4133+ use_host_scalars),
4134+ "could not set dropout primitive attribute");
4135+ }
4136+
40844137 /// Returns the fpmath mode
40854138 fpmath_mode get_fpmath_mode() const {
40864139 dnnl_fpmath_mode_t result;
0 commit comments