|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -use crate::ptr_eq::{arc_ptr_eq, arc_ptr_hash}; |
19 | 18 | use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; |
20 | 19 | use arrow::datatypes::{DataType, FieldRef}; |
21 | 20 | use async_trait::async_trait; |
@@ -62,17 +61,18 @@ pub struct AsyncScalarUDF { |
62 | 61 |
|
63 | 62 | impl PartialEq for AsyncScalarUDF { |
64 | 63 | fn eq(&self, other: &Self) -> bool { |
| 64 | + // Deconstruct to catch any new fields added in future |
65 | 65 | let Self { inner } = self; |
66 | | - // TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting. |
67 | | - arc_ptr_eq(inner, &other.inner) |
| 66 | + inner.dyn_eq(other.inner.as_any()) |
68 | 67 | } |
69 | 68 | } |
70 | 69 | impl Eq for AsyncScalarUDF {} |
71 | 70 |
|
72 | 71 | impl Hash for AsyncScalarUDF { |
73 | 72 | fn hash<H: Hasher>(&self, state: &mut H) { |
| 73 | + // Deconstruct to catch any new fields added in future |
74 | 74 | let Self { inner } = self; |
75 | | - arc_ptr_hash(inner, state); |
| 75 | + inner.dyn_hash(state); |
76 | 76 | } |
77 | 77 | } |
78 | 78 |
|
@@ -132,3 +132,129 @@ impl Display for AsyncScalarUDF { |
132 | 132 | write!(f, "AsyncScalarUDF: {}", self.inner.name()) |
133 | 133 | } |
134 | 134 | } |
| 135 | + |
| 136 | +#[cfg(test)] |
| 137 | +mod tests { |
| 138 | + use std::{ |
| 139 | + hash::{DefaultHasher, Hash, Hasher}, |
| 140 | + sync::Arc, |
| 141 | + }; |
| 142 | + |
| 143 | + use arrow::datatypes::DataType; |
| 144 | + use async_trait::async_trait; |
| 145 | + use datafusion_common::error::Result; |
| 146 | + use datafusion_expr_common::{columnar_value::ColumnarValue, signature::Signature}; |
| 147 | + |
| 148 | + use crate::{ |
| 149 | + async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}, |
| 150 | + ScalarFunctionArgs, ScalarUDFImpl, |
| 151 | + }; |
| 152 | + |
| 153 | + #[derive(Debug, PartialEq, Eq, Hash, Clone)] |
| 154 | + struct TestAsyncUDFImpl1 { |
| 155 | + a: i32, |
| 156 | + } |
| 157 | + |
| 158 | + impl ScalarUDFImpl for TestAsyncUDFImpl1 { |
| 159 | + fn as_any(&self) -> &dyn std::any::Any { |
| 160 | + self |
| 161 | + } |
| 162 | + |
| 163 | + fn name(&self) -> &str { |
| 164 | + todo!() |
| 165 | + } |
| 166 | + |
| 167 | + fn signature(&self) -> &Signature { |
| 168 | + todo!() |
| 169 | + } |
| 170 | + |
| 171 | + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
| 172 | + todo!() |
| 173 | + } |
| 174 | + |
| 175 | + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { |
| 176 | + todo!() |
| 177 | + } |
| 178 | + } |
| 179 | + |
| 180 | + #[async_trait] |
| 181 | + impl AsyncScalarUDFImpl for TestAsyncUDFImpl1 { |
| 182 | + async fn invoke_async_with_args( |
| 183 | + &self, |
| 184 | + _args: ScalarFunctionArgs, |
| 185 | + ) -> Result<ColumnarValue> { |
| 186 | + todo!() |
| 187 | + } |
| 188 | + } |
| 189 | + |
| 190 | + #[derive(Debug, PartialEq, Eq, Hash, Clone)] |
| 191 | + struct TestAsyncUDFImpl2 { |
| 192 | + a: i32, |
| 193 | + } |
| 194 | + |
| 195 | + impl ScalarUDFImpl for TestAsyncUDFImpl2 { |
| 196 | + fn as_any(&self) -> &dyn std::any::Any { |
| 197 | + self |
| 198 | + } |
| 199 | + |
| 200 | + fn name(&self) -> &str { |
| 201 | + todo!() |
| 202 | + } |
| 203 | + |
| 204 | + fn signature(&self) -> &Signature { |
| 205 | + todo!() |
| 206 | + } |
| 207 | + |
| 208 | + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
| 209 | + todo!() |
| 210 | + } |
| 211 | + |
| 212 | + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { |
| 213 | + todo!() |
| 214 | + } |
| 215 | + } |
| 216 | + |
| 217 | + #[async_trait] |
| 218 | + impl AsyncScalarUDFImpl for TestAsyncUDFImpl2 { |
| 219 | + async fn invoke_async_with_args( |
| 220 | + &self, |
| 221 | + _args: ScalarFunctionArgs, |
| 222 | + ) -> Result<ColumnarValue> { |
| 223 | + todo!() |
| 224 | + } |
| 225 | + } |
| 226 | + |
| 227 | + fn hash<T: Hash>(value: &T) -> u64 { |
| 228 | + let hasher = &mut DefaultHasher::new(); |
| 229 | + value.hash(hasher); |
| 230 | + hasher.finish() |
| 231 | + } |
| 232 | + |
| 233 | + #[test] |
| 234 | + fn test_async_udf_partial_eq_and_hash() { |
| 235 | + // Inner is same cloned arc -> equal |
| 236 | + let inner = Arc::new(TestAsyncUDFImpl1 { a: 1 }); |
| 237 | + let a = AsyncScalarUDF::new(Arc::clone(&inner) as Arc<dyn AsyncScalarUDFImpl>); |
| 238 | + let b = AsyncScalarUDF::new(inner); |
| 239 | + assert_eq!(a, b); |
| 240 | + assert_eq!(hash(&a), hash(&b)); |
| 241 | + |
| 242 | + // Inner is distinct arc -> still equal |
| 243 | + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); |
| 244 | + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); |
| 245 | + assert_eq!(a, b); |
| 246 | + assert_eq!(hash(&a), hash(&b)); |
| 247 | + |
| 248 | + // Negative case: inner is different value -> not equal |
| 249 | + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); |
| 250 | + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 2 })); |
| 251 | + assert_ne!(a, b); |
| 252 | + assert_ne!(hash(&a), hash(&b)); |
| 253 | + |
| 254 | + // Negative case: different functions -> not equal |
| 255 | + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); |
| 256 | + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl2 { a: 1 })); |
| 257 | + assert_ne!(a, b); |
| 258 | + assert_ne!(hash(&a), hash(&b)); |
| 259 | + } |
| 260 | +} |
0 commit comments