Skip to content

Commit 6b99b02

Browse files
committed
feat: add hooks for Set.contains & Set.remove
1 parent d10428a commit 6b99b02

File tree

1 file changed

+107
-29
lines changed

1 file changed

+107
-29
lines changed

src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java

Lines changed: 107 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -802,9 +802,32 @@ public static void arraysCompareRange(
802802
TraceDataFlowNativeCallbacks.traceMemcmp(first, second, returnValue, hookId);
803803
}
804804

805-
// The maximal number of elements of a non-TreeMap Map that will be sorted and searched for the
806-
// key closest to the current lookup key in the mapGet hook.
807-
private static final int MAX_NUM_KEYS_TO_ENUMERATE = 100;
805+
// The maximal number of elements of collections we enumerate to find values close to a lookup.
806+
private static final int MAX_NUM_ELEMENTS_TO_ENUMERATE = 100;
807+
808+
@MethodHook(
809+
type = HookType.AFTER,
810+
targetClassName = "java.util.Set",
811+
targetMethod = "contains",
812+
targetMethodDescriptor = "(Ljava/lang/Object;)Z")
813+
public static void setContains(
814+
MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean isContained) {
815+
if (!isContained) {
816+
setHookInternal((Set) thisObject, arguments[0], hookId);
817+
}
818+
}
819+
820+
@MethodHook(
821+
type = HookType.AFTER,
822+
targetClassName = "java.util.Set",
823+
targetMethod = "remove",
824+
targetMethodDescriptor = "(Ljava/lang/Object;)Z")
825+
public static void setRemove(
826+
MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean wasRemoved) {
827+
if (!wasRemoved) {
828+
setHookInternal((Set) thisObject, arguments[0], hookId);
829+
}
830+
}
808831

809832
@MethodHook(
810833
type = HookType.AFTER,
@@ -843,6 +866,43 @@ public static void mapGetOrDefault(
843866
}
844867
}
845868

869+
static class LowerUpperBounds {
870+
public Object lowerBound;
871+
public Object upperBound;
872+
}
873+
874+
private static <E> LowerUpperBounds getLowerUpperBounds(Set<E> elements, E currentElement) {
875+
int enumeratedKeys = 0;
876+
Comparable comparableElement = (Comparable) currentElement;
877+
878+
LowerUpperBounds bounds = new LowerUpperBounds();
879+
for (Object validElement : elements) {
880+
if (!(validElement instanceof Comparable)) continue;
881+
final Comparable comparableValidElement = (Comparable) validElement;
882+
// If the element sorts lower than the non-existing elements, but higher than the current
883+
// lower
884+
// bound, update the lower bound and vice versa for the upper bound.
885+
try {
886+
if (comparableValidElement.compareTo(comparableElement) < 0
887+
&& (bounds.lowerBound == null
888+
|| comparableValidElement.compareTo(bounds.lowerBound) > 0)) {
889+
bounds.lowerBound = validElement;
890+
}
891+
if (comparableValidElement.compareTo(comparableElement) > 0
892+
&& (bounds.upperBound == null
893+
|| comparableValidElement.compareTo(bounds.upperBound) < 0)) {
894+
bounds.upperBound = validElement;
895+
}
896+
} catch (ClassCastException ignored) {
897+
// Can be thrown by Comparable.compareTo if comparableElement is of a type that can't be
898+
// compared to the elements set.
899+
}
900+
if (enumeratedKeys++ > MAX_NUM_ELEMENTS_TO_ENUMERATE) break;
901+
}
902+
903+
return bounds;
904+
}
905+
846906
@SuppressWarnings({"rawtypes", "unchecked"})
847907
private static <K, V> void mapHookInternal(Map<K, V> map, K currentKey, int hookId) {
848908
if (map == null || map.isEmpty()) return;
@@ -853,8 +913,8 @@ private static <K, V> void mapHookInternal(Map<K, V> map, K currentKey, int hook
853913
Object lowerBoundKey = null;
854914
Object upperBoundKey = null;
855915
try {
856-
if (map instanceof TreeMap) {
857-
final TreeMap<K, V> treeMap = (TreeMap<K, V>) map;
916+
if (map instanceof NavigableMap) {
917+
final NavigableMap<K, V> treeMap = (NavigableMap<K, V>) map;
858918
try {
859919
lowerBoundKey = treeMap.floorKey(currentKey);
860920
upperBoundKey = treeMap.ceilingKey(currentKey);
@@ -863,30 +923,9 @@ private static <K, V> void mapHookInternal(Map<K, V> map, K currentKey, int hook
863923
// compared to the maps keys.
864924
}
865925
} else if (currentKey instanceof Comparable) {
866-
final Comparable comparableCurrentKey = (Comparable) currentKey;
867-
// Find two keys that bracket currentKey.
868-
// Note: This is not deterministic if map.size() > MAX_NUM_KEYS_TO_ENUMERATE.
869-
int enumeratedKeys = 0;
870-
for (Object validKey : map.keySet()) {
871-
if (!(validKey instanceof Comparable)) continue;
872-
final Comparable comparableValidKey = (Comparable) validKey;
873-
// If the key sorts lower than the non-existing key, but higher than the current lower
874-
// bound, update the lower bound and vice versa for the upper bound.
875-
try {
876-
if (comparableValidKey.compareTo(comparableCurrentKey) < 0
877-
&& (lowerBoundKey == null || comparableValidKey.compareTo(lowerBoundKey) > 0)) {
878-
lowerBoundKey = validKey;
879-
}
880-
if (comparableValidKey.compareTo(comparableCurrentKey) > 0
881-
&& (upperBoundKey == null || comparableValidKey.compareTo(upperBoundKey) < 0)) {
882-
upperBoundKey = validKey;
883-
}
884-
} catch (ClassCastException ignored) {
885-
// Can be thrown by floorKey and ceilingKey if currentKey is of a type that can't be
886-
// compared to the maps keys.
887-
}
888-
if (enumeratedKeys++ > MAX_NUM_KEYS_TO_ENUMERATE) break;
889-
}
926+
LowerUpperBounds bounds = getLowerUpperBounds(map.keySet(), currentKey);
927+
lowerBoundKey = bounds.lowerBound;
928+
upperBoundKey = bounds.upperBound;
890929
}
891930
} catch (ConcurrentModificationException ignored) {
892931
// map was modified by another thread, skip this invocation
@@ -901,6 +940,45 @@ private static <K, V> void mapHookInternal(Map<K, V> map, K currentKey, int hook
901940
}
902941
}
903942

943+
@SuppressWarnings({"rawtypes", "unchecked"})
944+
private static <E> void setHookInternal(Set<E> set, E currentElement, int hookId) {
945+
if (set == null || set.isEmpty()) return;
946+
if (currentElement == null) return;
947+
948+
Object lowerBoundElement = null;
949+
Object upperBoundElement = null;
950+
951+
try {
952+
if (set instanceof NavigableSet) {
953+
final NavigableSet<E> navigableSet = (NavigableSet<E>) set;
954+
try {
955+
lowerBoundElement = navigableSet.floor(currentElement);
956+
upperBoundElement = navigableSet.ceiling(currentElement);
957+
} catch (ClassCastException ignored) {
958+
// Can be thrown by NavigableSet.floor and NavigableSet.ceiling if the element cannot be
959+
// compared is of a type that can't be
960+
// compared to the maps keys.
961+
}
962+
963+
} else if (currentElement instanceof Comparable) {
964+
LowerUpperBounds bounds = getLowerUpperBounds(set, currentElement);
965+
lowerBoundElement = bounds.lowerBound;
966+
upperBoundElement = bounds.upperBound;
967+
}
968+
} catch (ConcurrentModificationException ignored) {
969+
// set was modified by another thread, skip this invocation
970+
return;
971+
}
972+
973+
if (lowerBoundElement != null) {
974+
TraceDataFlowNativeCallbacks.traceGenericCmp(currentElement, lowerBoundElement, hookId);
975+
}
976+
if (upperBoundElement != null) {
977+
TraceDataFlowNativeCallbacks.traceGenericCmp(
978+
currentElement, upperBoundElement, 31 * hookId + 11);
979+
}
980+
}
981+
904982
@MethodHook(
905983
type = HookType.AFTER,
906984
targetClassName = "org.junit.jupiter.api.Assertions",

0 commit comments

Comments
 (0)