@@ -802,9 +802,32 @@ public static void arraysCompareRange(
802802 TraceDataFlowNativeCallbacks .traceMemcmp (first , second , returnValue , hookId );
803803 }
804804
805- // The maximal number of elements of a non-TreeMap Map that will be sorted and searched for the
806- // key closest to the current lookup key in the mapGet hook.
807- private static final int MAX_NUM_KEYS_TO_ENUMERATE = 100 ;
805+ // The maximal number of elements of collections we enumerate to find values close to a lookup.
806+ private static final int MAX_NUM_ELEMENTS_TO_ENUMERATE = 100 ;
807+
808+ @ MethodHook (
809+ type = HookType .AFTER ,
810+ targetClassName = "java.util.Set" ,
811+ targetMethod = "contains" ,
812+ targetMethodDescriptor = "(Ljava/lang/Object;)Z" )
813+ public static void setContains (
814+ MethodHandle method , Object thisObject , Object [] arguments , int hookId , Boolean isContained ) {
815+ if (!isContained ) {
816+ setHookInternal ((Set ) thisObject , arguments [0 ], hookId );
817+ }
818+ }
819+
820+ @ MethodHook (
821+ type = HookType .AFTER ,
822+ targetClassName = "java.util.Set" ,
823+ targetMethod = "remove" ,
824+ targetMethodDescriptor = "(Ljava/lang/Object;)Z" )
825+ public static void setRemove (
826+ MethodHandle method , Object thisObject , Object [] arguments , int hookId , Boolean wasRemoved ) {
827+ if (!wasRemoved ) {
828+ setHookInternal ((Set ) thisObject , arguments [0 ], hookId );
829+ }
830+ }
808831
809832 @ MethodHook (
810833 type = HookType .AFTER ,
@@ -843,6 +866,43 @@ public static void mapGetOrDefault(
843866 }
844867 }
845868
869+ static class LowerUpperBounds {
870+ public Object lowerBound ;
871+ public Object upperBound ;
872+ }
873+
874+ private static <E > LowerUpperBounds getLowerUpperBounds (Set <E > elements , E currentElement ) {
875+ int enumeratedKeys = 0 ;
876+ Comparable comparableElement = (Comparable ) currentElement ;
877+
878+ LowerUpperBounds bounds = new LowerUpperBounds ();
879+ for (Object validElement : elements ) {
880+ if (!(validElement instanceof Comparable )) continue ;
881+ final Comparable comparableValidElement = (Comparable ) validElement ;
882+ // If the element sorts lower than the non-existing elements, but higher than the current
883+ // lower
884+ // bound, update the lower bound and vice versa for the upper bound.
885+ try {
886+ if (comparableValidElement .compareTo (comparableElement ) < 0
887+ && (bounds .lowerBound == null
888+ || comparableValidElement .compareTo (bounds .lowerBound ) > 0 )) {
889+ bounds .lowerBound = validElement ;
890+ }
891+ if (comparableValidElement .compareTo (comparableElement ) > 0
892+ && (bounds .upperBound == null
893+ || comparableValidElement .compareTo (bounds .upperBound ) < 0 )) {
894+ bounds .upperBound = validElement ;
895+ }
896+ } catch (ClassCastException ignored ) {
897+ // Can be thrown by Comparable.compareTo if comparableElement is of a type that can't be
898+ // compared to the elements set.
899+ }
900+ if (enumeratedKeys ++ > MAX_NUM_ELEMENTS_TO_ENUMERATE ) break ;
901+ }
902+
903+ return bounds ;
904+ }
905+
846906 @ SuppressWarnings ({"rawtypes" , "unchecked" })
847907 private static <K , V > void mapHookInternal (Map <K , V > map , K currentKey , int hookId ) {
848908 if (map == null || map .isEmpty ()) return ;
@@ -853,8 +913,8 @@ private static <K, V> void mapHookInternal(Map<K, V> map, K currentKey, int hook
853913 Object lowerBoundKey = null ;
854914 Object upperBoundKey = null ;
855915 try {
856- if (map instanceof TreeMap ) {
857- final TreeMap <K , V > treeMap = (TreeMap <K , V >) map ;
916+ if (map instanceof NavigableMap ) {
917+ final NavigableMap <K , V > treeMap = (NavigableMap <K , V >) map ;
858918 try {
859919 lowerBoundKey = treeMap .floorKey (currentKey );
860920 upperBoundKey = treeMap .ceilingKey (currentKey );
@@ -863,30 +923,9 @@ private static <K, V> void mapHookInternal(Map<K, V> map, K currentKey, int hook
863923 // compared to the maps keys.
864924 }
865925 } else if (currentKey instanceof Comparable ) {
866- final Comparable comparableCurrentKey = (Comparable ) currentKey ;
867- // Find two keys that bracket currentKey.
868- // Note: This is not deterministic if map.size() > MAX_NUM_KEYS_TO_ENUMERATE.
869- int enumeratedKeys = 0 ;
870- for (Object validKey : map .keySet ()) {
871- if (!(validKey instanceof Comparable )) continue ;
872- final Comparable comparableValidKey = (Comparable ) validKey ;
873- // If the key sorts lower than the non-existing key, but higher than the current lower
874- // bound, update the lower bound and vice versa for the upper bound.
875- try {
876- if (comparableValidKey .compareTo (comparableCurrentKey ) < 0
877- && (lowerBoundKey == null || comparableValidKey .compareTo (lowerBoundKey ) > 0 )) {
878- lowerBoundKey = validKey ;
879- }
880- if (comparableValidKey .compareTo (comparableCurrentKey ) > 0
881- && (upperBoundKey == null || comparableValidKey .compareTo (upperBoundKey ) < 0 )) {
882- upperBoundKey = validKey ;
883- }
884- } catch (ClassCastException ignored ) {
885- // Can be thrown by floorKey and ceilingKey if currentKey is of a type that can't be
886- // compared to the maps keys.
887- }
888- if (enumeratedKeys ++ > MAX_NUM_KEYS_TO_ENUMERATE ) break ;
889- }
926+ LowerUpperBounds bounds = getLowerUpperBounds (map .keySet (), currentKey );
927+ lowerBoundKey = bounds .lowerBound ;
928+ upperBoundKey = bounds .upperBound ;
890929 }
891930 } catch (ConcurrentModificationException ignored ) {
892931 // map was modified by another thread, skip this invocation
@@ -901,6 +940,45 @@ private static <K, V> void mapHookInternal(Map<K, V> map, K currentKey, int hook
901940 }
902941 }
903942
943+ @ SuppressWarnings ({"rawtypes" , "unchecked" })
944+ private static <E > void setHookInternal (Set <E > set , E currentElement , int hookId ) {
945+ if (set == null || set .isEmpty ()) return ;
946+ if (currentElement == null ) return ;
947+
948+ Object lowerBoundElement = null ;
949+ Object upperBoundElement = null ;
950+
951+ try {
952+ if (set instanceof NavigableSet ) {
953+ final NavigableSet <E > navigableSet = (NavigableSet <E >) set ;
954+ try {
955+ lowerBoundElement = navigableSet .floor (currentElement );
956+ upperBoundElement = navigableSet .ceiling (currentElement );
957+ } catch (ClassCastException ignored ) {
958+ // Can be thrown by NavigableSet.floor and NavigableSet.ceiling if the element cannot be
959+ // compared is of a type that can't be
960+ // compared to the maps keys.
961+ }
962+
963+ } else if (currentElement instanceof Comparable ) {
964+ LowerUpperBounds bounds = getLowerUpperBounds (set , currentElement );
965+ lowerBoundElement = bounds .lowerBound ;
966+ upperBoundElement = bounds .upperBound ;
967+ }
968+ } catch (ConcurrentModificationException ignored ) {
969+ // set was modified by another thread, skip this invocation
970+ return ;
971+ }
972+
973+ if (lowerBoundElement != null ) {
974+ TraceDataFlowNativeCallbacks .traceGenericCmp (currentElement , lowerBoundElement , hookId );
975+ }
976+ if (upperBoundElement != null ) {
977+ TraceDataFlowNativeCallbacks .traceGenericCmp (
978+ currentElement , upperBoundElement , 31 * hookId + 11 );
979+ }
980+ }
981+
904982 @ MethodHook (
905983 type = HookType .AFTER ,
906984 targetClassName = "org.junit.jupiter.api.Assertions" ,
0 commit comments