Skip to content

Commit bcbc632

Browse files
authored
chore: utilize trait upcasting for AsyncScalarUDF PartialEq & Hash (apache#17872)
* chore: utilize trait upcasting for AsyncScalarUDF PartialEq & Hash
* review comments
1 parent: 9e8ec54 · commit: bcbc632

File tree

1 file changed

+130
-4
lines changed

1 file changed

+130
-4
lines changed

datafusion/expr/src/async_udf.rs

Lines changed: 130 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::ptr_eq::{arc_ptr_eq, arc_ptr_hash};
1918
use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
2019
use arrow::datatypes::{DataType, FieldRef};
2120
use async_trait::async_trait;
@@ -62,17 +61,18 @@ pub struct AsyncScalarUDF {
6261

6362
impl PartialEq for AsyncScalarUDF {
6463
fn eq(&self, other: &Self) -> bool {
64+
// Deconstruct to catch any new fields added in future
6565
let Self { inner } = self;
66-
// TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting.
67-
arc_ptr_eq(inner, &other.inner)
66+
inner.dyn_eq(other.inner.as_any())
6867
}
6968
}
7069
impl Eq for AsyncScalarUDF {}
7170

7271
impl Hash for AsyncScalarUDF {
7372
fn hash<H: Hasher>(&self, state: &mut H) {
73+
// Deconstruct to catch any new fields added in future
7474
let Self { inner } = self;
75-
arc_ptr_hash(inner, state);
75+
inner.dyn_hash(state);
7676
}
7777
}
7878

@@ -132,3 +132,129 @@ impl Display for AsyncScalarUDF {
132132
write!(f, "AsyncScalarUDF: {}", self.inner.name())
133133
}
134134
}
135+
136+
#[cfg(test)]
137+
mod tests {
138+
use std::{
139+
hash::{DefaultHasher, Hash, Hasher},
140+
sync::Arc,
141+
};
142+
143+
use arrow::datatypes::DataType;
144+
use async_trait::async_trait;
145+
use datafusion_common::error::Result;
146+
use datafusion_expr_common::{columnar_value::ColumnarValue, signature::Signature};
147+
148+
use crate::{
149+
async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl},
150+
ScalarFunctionArgs, ScalarUDFImpl,
151+
};
152+
153+
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
154+
struct TestAsyncUDFImpl1 {
155+
a: i32,
156+
}
157+
158+
impl ScalarUDFImpl for TestAsyncUDFImpl1 {
159+
fn as_any(&self) -> &dyn std::any::Any {
160+
self
161+
}
162+
163+
fn name(&self) -> &str {
164+
todo!()
165+
}
166+
167+
fn signature(&self) -> &Signature {
168+
todo!()
169+
}
170+
171+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
172+
todo!()
173+
}
174+
175+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
176+
todo!()
177+
}
178+
}
179+
180+
#[async_trait]
181+
impl AsyncScalarUDFImpl for TestAsyncUDFImpl1 {
182+
async fn invoke_async_with_args(
183+
&self,
184+
_args: ScalarFunctionArgs,
185+
) -> Result<ColumnarValue> {
186+
todo!()
187+
}
188+
}
189+
190+
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
191+
struct TestAsyncUDFImpl2 {
192+
a: i32,
193+
}
194+
195+
impl ScalarUDFImpl for TestAsyncUDFImpl2 {
196+
fn as_any(&self) -> &dyn std::any::Any {
197+
self
198+
}
199+
200+
fn name(&self) -> &str {
201+
todo!()
202+
}
203+
204+
fn signature(&self) -> &Signature {
205+
todo!()
206+
}
207+
208+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
209+
todo!()
210+
}
211+
212+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
213+
todo!()
214+
}
215+
}
216+
217+
#[async_trait]
218+
impl AsyncScalarUDFImpl for TestAsyncUDFImpl2 {
219+
async fn invoke_async_with_args(
220+
&self,
221+
_args: ScalarFunctionArgs,
222+
) -> Result<ColumnarValue> {
223+
todo!()
224+
}
225+
}
226+
227+
fn hash<T: Hash>(value: &T) -> u64 {
228+
let hasher = &mut DefaultHasher::new();
229+
value.hash(hasher);
230+
hasher.finish()
231+
}
232+
233+
#[test]
234+
fn test_async_udf_partial_eq_and_hash() {
235+
// Inner is same cloned arc -> equal
236+
let inner = Arc::new(TestAsyncUDFImpl1 { a: 1 });
237+
let a = AsyncScalarUDF::new(Arc::clone(&inner) as Arc<dyn AsyncScalarUDFImpl>);
238+
let b = AsyncScalarUDF::new(inner);
239+
assert_eq!(a, b);
240+
assert_eq!(hash(&a), hash(&b));
241+
242+
// Inner is distinct arc -> still equal
243+
let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 }));
244+
let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 }));
245+
assert_eq!(a, b);
246+
assert_eq!(hash(&a), hash(&b));
247+
248+
// Negative case: inner is different value -> not equal
249+
let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 }));
250+
let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 2 }));
251+
assert_ne!(a, b);
252+
assert_ne!(hash(&a), hash(&b));
253+
254+
// Negative case: different functions -> not equal
255+
let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 }));
256+
let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl2 { a: 1 }));
257+
assert_ne!(a, b);
258+
assert_ne!(hash(&a), hash(&b));
259+
}
260+
}

0 commit comments

Comments
 (0)