@@ -802,9 +802,32 @@ public static void arraysCompareRange(
802802 TraceDataFlowNativeCallbacks .traceMemcmp (first , second , returnValue , hookId );
803803 }
804804
805- // The maximal number of elements of a non-TreeMap Map that will be sorted and searched for the
806- // key closest to the current lookup key in the mapGet hook.
807- private static final int MAX_NUM_KEYS_TO_ENUMERATE = 100 ;
805+ // The maximal number of elements of collections we enumerate to find values close to a lookup.
806+ private static final int MAX_NUM_ELEMENTS_TO_ENUMERATE = 100 ;
807+
808+ @ MethodHook (
809+ type = HookType .AFTER ,
810+ targetClassName = "java.util.Set" ,
811+ targetMethod = "contains" ,
812+ targetMethodDescriptor = "(Ljava/lang/Object;)Z" )
813+ public static void setContains (
814+ MethodHandle method , Object thisObject , Object [] arguments , int hookId , Boolean isContained ) {
815+ if (!isContained ) {
816+ setHookInternal ((Set ) thisObject , arguments [0 ], hookId );
817+ }
818+ }
819+
820+ @ MethodHook (
821+ type = HookType .AFTER ,
822+ targetClassName = "java.util.Set" ,
823+ targetMethod = "remove" ,
824+ targetMethodDescriptor = "(Ljava/lang/Object;)Z" )
825+ public static void setRemove (
826+ MethodHandle method , Object thisObject , Object [] arguments , int hookId , Boolean wasRemoved ) {
827+ if (!wasRemoved ) {
828+ setHookInternal ((Set ) thisObject , arguments [0 ], hookId );
829+ }
830+ }
808831
809832 @ MethodHook (
810833 type = HookType .AFTER ,
@@ -843,6 +866,54 @@ public static void mapGetOrDefault(
843866 }
844867 }
845868
869+ private static final class Bounds {
870+ private final Object lower ;
871+ private final Object upper ;
872+
873+ private Bounds (Object lower , Object upper ) {
874+ this .lower = lower ;
875+ this .upper = upper ;
876+ }
877+
878+ private Object getLower () {
879+ return lower ;
880+ }
881+
882+ private Object getUpper () {
883+ return upper ;
884+ }
885+ }
886+
887+ private static <E > Bounds getLowerUpperBounds (Set <E > elements , E currentElement ) {
888+ int enumeratedElements = 0 ;
889+ Comparable comparableElement = (Comparable ) currentElement ;
890+
891+ Object lowerBound = null ;
892+ Object upperBound = null ;
893+ for (Object validElement : elements ) {
894+ if (!(validElement instanceof Comparable )) continue ;
895+ final Comparable comparableValidElement = (Comparable ) validElement ;
896+ // If the element sorts lower than the non-existing elements, but higher than the current
897+ // lower bound, update the lower bound and vice versa for the upper bound.
898+ try {
899+ if (comparableValidElement .compareTo (comparableElement ) < 0
900+ && (lowerBound == null || comparableValidElement .compareTo (lowerBound ) > 0 )) {
901+ lowerBound = validElement ;
902+ }
903+ if (comparableValidElement .compareTo (comparableElement ) > 0
904+ && (upperBound == null || comparableValidElement .compareTo (upperBound ) < 0 )) {
905+ upperBound = validElement ;
906+ }
907+ } catch (ClassCastException ignored ) {
908+ // Can be thrown by Comparable.compareTo if comparableElement is of a type that can't be
909+ // compared to the elements set.
910+ }
911+ if (enumeratedElements ++ > MAX_NUM_ELEMENTS_TO_ENUMERATE ) break ;
912+ }
913+
914+ return new Bounds (lowerBound , upperBound );
915+ }
916+
846917 @ SuppressWarnings ({"rawtypes" , "unchecked" })
847918 private static <K , V > void mapHookInternal (Map <K , V > map , K currentKey , int hookId ) {
848919 if (map == null || map .isEmpty ()) return ;
@@ -853,8 +924,8 @@ private static <K, V> void mapHookInternal(Map<K, V> map, K currentKey, int hook
853924 Object lowerBoundKey = null ;
854925 Object upperBoundKey = null ;
855926 try {
856- if (map instanceof TreeMap ) {
857- final TreeMap <K , V > treeMap = (TreeMap <K , V >) map ;
927+ if (map instanceof NavigableMap ) {
928+ final NavigableMap <K , V > treeMap = (NavigableMap <K , V >) map ;
858929 try {
859930 lowerBoundKey = treeMap .floorKey (currentKey );
860931 upperBoundKey = treeMap .ceilingKey (currentKey );
@@ -863,30 +934,9 @@ private static <K, V> void mapHookInternal(Map<K, V> map, K currentKey, int hook
863934 // compared to the maps keys.
864935 }
865936 } else if (currentKey instanceof Comparable ) {
866- final Comparable comparableCurrentKey = (Comparable ) currentKey ;
867- // Find two keys that bracket currentKey.
868- // Note: This is not deterministic if map.size() > MAX_NUM_KEYS_TO_ENUMERATE.
869- int enumeratedKeys = 0 ;
870- for (Object validKey : map .keySet ()) {
871- if (!(validKey instanceof Comparable )) continue ;
872- final Comparable comparableValidKey = (Comparable ) validKey ;
873- // If the key sorts lower than the non-existing key, but higher than the current lower
874- // bound, update the lower bound and vice versa for the upper bound.
875- try {
876- if (comparableValidKey .compareTo (comparableCurrentKey ) < 0
877- && (lowerBoundKey == null || comparableValidKey .compareTo (lowerBoundKey ) > 0 )) {
878- lowerBoundKey = validKey ;
879- }
880- if (comparableValidKey .compareTo (comparableCurrentKey ) > 0
881- && (upperBoundKey == null || comparableValidKey .compareTo (upperBoundKey ) < 0 )) {
882- upperBoundKey = validKey ;
883- }
884- } catch (ClassCastException ignored ) {
885- // Can be thrown by floorKey and ceilingKey if currentKey is of a type that can't be
886- // compared to the maps keys.
887- }
888- if (enumeratedKeys ++ > MAX_NUM_KEYS_TO_ENUMERATE ) break ;
889- }
937+ Bounds bounds = getLowerUpperBounds (map .keySet (), currentKey );
938+ lowerBoundKey = bounds .getLower ();
939+ upperBoundKey = bounds .getUpper ();
890940 }
891941 } catch (ConcurrentModificationException ignored ) {
892942 // map was modified by another thread, skip this invocation
@@ -901,6 +951,44 @@ private static <K, V> void mapHookInternal(Map<K, V> map, K currentKey, int hook
901951 }
902952 }
903953
954+ @ SuppressWarnings ({"rawtypes" , "unchecked" })
955+ private static <E > void setHookInternal (Set <E > set , E currentElement , int hookId ) {
956+ if (set == null || set .isEmpty ()) return ;
957+ if (currentElement == null ) return ;
958+
959+ Object lowerBoundElement = null ;
960+ Object upperBoundElement = null ;
961+
962+ try {
963+ if (set instanceof NavigableSet ) {
964+ final NavigableSet <E > navigableSet = (NavigableSet <E >) set ;
965+ try {
966+ lowerBoundElement = navigableSet .floor (currentElement );
967+ upperBoundElement = navigableSet .ceiling (currentElement );
968+ } catch (ClassCastException ignored ) {
969+ // Can be thrown by NavigableSet.floor and NavigableSet.ceiling if the element cannot be
970+ // compared to elements in the set.
971+ }
972+
973+ } else if (currentElement instanceof Comparable ) {
974+ Bounds bounds = getLowerUpperBounds (set , currentElement );
975+ lowerBoundElement = bounds .getLower ();
976+ upperBoundElement = bounds .getUpper ();
977+ }
978+ } catch (ConcurrentModificationException ignored ) {
979+ // set was modified by another thread, skip this invocation
980+ return ;
981+ }
982+
983+ if (lowerBoundElement != null ) {
984+ TraceDataFlowNativeCallbacks .traceGenericCmp (currentElement , lowerBoundElement , hookId );
985+ }
986+ if (upperBoundElement != null ) {
987+ TraceDataFlowNativeCallbacks .traceGenericCmp (
988+ currentElement , upperBoundElement , 31 * hookId + 11 );
989+ }
990+ }
991+
904992 @ MethodHook (
905993 type = HookType .AFTER ,
906994 targetClassName = "org.junit.jupiter.api.Assertions" ,
0 commit comments