diff --git a/evm/src/main/java/org/hyperledger/besu/evm/UInt256.java b/evm/src/main/java/org/hyperledger/besu/evm/UInt256.java index 290ef16cffb..8fe19ed9207 100644 --- a/evm/src/main/java/org/hyperledger/besu/evm/UInt256.java +++ b/evm/src/main/java/org/hyperledger/besu/evm/UInt256.java @@ -15,23 +15,25 @@ package org.hyperledger.besu.evm; import java.math.BigInteger; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -import com.google.common.annotations.VisibleForTesting; /** * 256-bits wide unsigned integer class. * *

This class is an optimised version of BigInteger for fixed width 256-bits integers. + * + * @param u3 4th digit + * @param u2 3rd digit + * @param u1 2nd digit + * @param u0 1st digit */ -public final class UInt256 { - // region Internals +public record UInt256(long u3, long u2, long u1, long u0) { + + // region Values + // -------------------------------------------------------------------------- + // UInt256 represents a big-endian 256-bits integer. + // As opposed to Java int, operations are by default unsigned, + // and signed version are interpreted in two-complements as usual. // -------------------------------------------------------------------------- - // UInt256 is a big-endian up to 256-bits integer. - // Internally, it is represented with fixed-size int/long limbs in little-endian order. - // Length is used to optimise algorithms, skipping leading zeroes. - // Nonetheless, 256bits are always allocated and initialised to zeroes. /** Fixed size in bytes. */ public static final int BYTESIZE = 32; @@ -39,54 +41,34 @@ public final class UInt256 { /** Fixed size in bits. */ public static final int BITSIZE = 256; - // Fixed number of limbs or digits - private static final int N_LIMBS = 8; - // Fixed number of bits per limb. - private static final int N_BITS_PER_LIMB = 32; - // Mask for long values - private static final long MASK_L = 0xFFFFFFFFL; - - private final int[] limbs; - private final int length; + /** The constant 0. */ + public static final UInt256 ZERO = new UInt256(0, 0, 0, 0); - @VisibleForTesting - int[] limbs() { - return limbs; - } + /** The constant All ones */ + public static final UInt256 MAX = + new UInt256( + 0xFFFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL); // -------------------------------------------------------------------------- // endregion - /** The constant 0. */ - public static final UInt256 ZERO = new UInt256(new int[] {0, 0, 0, 0, 0, 0, 0, 0}, 0); - - /** The constant All ones */ - public static final UInt256 ALL_ONES = - new UInt256( - new int[] { - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF - }, - N_LIMBS); - - // region Constructors + // region (private) Internal Values // -------------------------------------------------------------------------- - UInt256(final int[] limbs, final int length) { - // Unchecked length: assumes limbs have length == N_LIMBS - this.limbs = limbs; - this.length = length; - } + // Fixed number of limbs or digits + // private static final int N_LIMBS = 4; + // Fixed number of bits per limb. + private static final int N_BITS_PER_LIMB = 64; - UInt256(final int[] limbs) { - this(limbs, N_LIMBS); - } + // Arrays of zeros. + // We accomodate up to a result of a multiplication + // private static final long[] ZERO_LONGS = new long[9]; + + // -------------------------------------------------------------------------- + // endregion + + // region Alternative Constructors + // -------------------------------------------------------------------------- /** * Instantiates a new UInt256 from byte array. @@ -95,48 +77,25 @@ int[] limbs() { * @return Big-endian UInt256 represented by the bytes. */ public static UInt256 fromBytesBE(final byte[] bytes) { - int byteLen = bytes.length; - if (byteLen == 0) return ZERO; - - int[] limbs = new int[N_LIMBS]; - - // Fast path for exactly 32 bytes - if (byteLen == 32) { - limbs[7] = getIntBE(bytes, 0); - limbs[6] = getIntBE(bytes, 4); - limbs[5] = getIntBE(bytes, 8); - limbs[4] = getIntBE(bytes, 12); - limbs[3] = getIntBE(bytes, 16); - limbs[2] = getIntBE(bytes, 20); - limbs[1] = getIntBE(bytes, 24); - limbs[0] = getIntBE(bytes, 28); - return new UInt256(limbs, N_LIMBS); - } - - // General path for variable length - int limbIndex = 0; - int byteIndex = byteLen - 1; - - while (byteIndex >= 0 && limbIndex < N_LIMBS) { - int limb = 0; - int shift = 0; - - for (int j = 0; j < 4 && byteIndex >= 0; j++, byteIndex--, shift += 8) { - limb |= (bytes[byteIndex] & 0xFF) << shift; - } - - limbs[limbIndex++] = limb; + if (bytes.length == 0) return ZERO; + long u3 = 0; + long u2 = 0; + long u1 = 0; + long u0 = 0; + int b = bytes.length - 1; // Index in bytes array + for (int shift = 0; shift < 64 && b >= 0; b--, shift += 8) { + u0 |= ((bytes[b] & 0xFFL) << shift); } - - return new UInt256(limbs, limbIndex); - } - - // Helper method to read 4 bytes as big-endian int - private static int getIntBE(final byte[] bytes, final int offset) { - return ((bytes[offset] & 0xFF) << 24) - | ((bytes[offset + 1] & 0xFF) << 16) - | ((bytes[offset + 2] & 0xFF) << 8) - | (bytes[offset + 3] & 0xFF); + for (int shift = 0; shift < 64 && b >= 0; b--, shift += 8) { + u1 |= ((bytes[b] & 0xFFL) << shift); + } + for (int shift = 0; shift < 64 && b >= 0; b--, shift += 8) { + u2 |= ((bytes[b] & 0xFFL) << shift); + } + for (int shift = 0; shift < 64 && b >= 0; b--, shift += 8) { + u3 |= ((bytes[b] & 0xFFL) << shift); + } + return new UInt256(u3, u2, u1, u0); } /** @@ -146,10 +105,7 @@ private static int getIntBE(final byte[] bytes, final int offset) { * @return The UInt256 equivalent of value. */ public static UInt256 fromInt(final int value) { - if (value == 0) return ZERO; - int[] limbs = new int[N_LIMBS]; - limbs[0] = value; - return new UInt256(limbs, 1); + return new UInt256(0, 0, 0, value & 0xFFFFFFFFL); } /** @@ -159,27 +115,25 @@ public static UInt256 fromInt(final int value) { * @return The UInt256 equivalent of value. */ public static UInt256 fromLong(final long value) { - if (value == 0) return ZERO; - int[] limbs = new int[N_LIMBS]; - limbs[0] = (int) value; - limbs[1] = (int) (value >>> 32); - return new UInt256(limbs, 2); + return new UInt256(0, 0, 0, value); } /** - * Instantiates a new UInt256 from an int array. + * Instantiates a new UInt256 from an array. * - *

The array is interpreted in little-endian order. It is either padded with 0s or truncated if - * necessary. + *

Read digits from an array starting from the end. The array must have at least N_LIMBS + * elements. * - * @param arr int array of limbs. - * @return The UInt256 equivalent of value. + * @param limbs The array holding the digits. + * @return The UInt256 from the array */ - public static UInt256 fromArray(final int[] arr) { - int[] limbs = new int[N_LIMBS]; - int len = Math.min(N_LIMBS, arr.length); - System.arraycopy(arr, 0, limbs, 0, len); - return new UInt256(limbs, len); + public static UInt256 fromArray(final long[] limbs) { + int i = limbs.length; + long z0 = limbs[--i]; + long z1 = limbs[--i]; + long z2 = limbs[--i]; + long z3 = limbs[--i]; + return new UInt256(z3, z2, z1, z0); } // -------------------------------------------------------------------------- @@ -193,7 +147,7 @@ public static UInt256 fromArray(final int[] arr) { * @return Value truncated to an int, possibly lossy. */ public int intValue() { - return limbs[0]; + return (int) u0; } /** @@ -202,7 +156,7 @@ public int intValue() { * @return Value truncated to a long, possibly lossy. */ public long longValue() { - return (limbs[0] & MASK_L) | ((limbs[1] & MASK_L) << 32); + return u0; } /** @@ -211,11 +165,24 @@ public long longValue() { * @return Big-endian ordered bytes for this UInt256 value. */ public byte[] toBytesBE() { - ByteBuffer buf = ByteBuffer.allocate(BYTESIZE).order(ByteOrder.BIG_ENDIAN); - for (int i = N_LIMBS - 1; i >= 0; i--) { - buf.putInt(limbs[i]); - } - return buf.array(); + byte[] result = new byte[BYTESIZE]; + longIntoBytes(result, 0, u3); + longIntoBytes(result, 8, u2); + longIntoBytes(result, 16, u1); + longIntoBytes(result, 24, u0); + return result; + } + + // Helper method to write 8 bytes from big-endian int + private static void longIntoBytes(final byte[] bytes, final int offset, final long value) { + bytes[offset] = (byte) (value >>> 56); + bytes[offset + 1] = (byte) (value >>> 48); + bytes[offset + 2] = (byte) (value >>> 40); + bytes[offset + 3] = (byte) (value >>> 32); + bytes[offset + 4] = (byte) (value >>> 24); + bytes[offset + 5] = (byte) (value >>> 16); + bytes[offset + 6] = (byte) (value >>> 8); + bytes[offset + 7] = (byte) value; } /** @@ -227,8 +194,14 @@ public BigInteger toBigInteger() { return new BigInteger(1, toBytesBE()); } - @Override - public String toString() { + /** + * Convert to hexstring. + * + *

Convert this integer into big-endian hexstring representation. + * + * @return The hexstring representing the integer. + */ + public String toHexString() { StringBuilder sb = new StringBuilder("0x"); for (byte b : toBytesBE()) { sb.append(String.format("%02x", b)); @@ -236,6 +209,42 @@ public String toString() { return sb.toString(); } + /** + * Fills an array with digits. + * + *

Fills an array with the integer's digits starting from the end. The array must have at least + * N_LIMBS elements. + * + * @param limbs The array to fill + */ + public void intoArray(final long[] limbs) { + int len = limbs.length; + limbs[len--] = u0; + limbs[len--] = u1; + limbs[len--] = u2; + limbs[len--] = u3; + } + + private UInt320 UInt320Value() { + return new UInt320(0, u3, u2, u1, u0); + } + + private Modulus64 asModulus64() { + return new Modulus64(u0); + } + + private Modulus128 asModulus128() { + return new Modulus128(u1, u0); + } + + private Modulus192 asModulus192() { + return new Modulus192(u2, u1, u0); + } + + private Modulus256 asModulus256() { + return new Modulus256(u3, u2, u1, u0); + } + // -------------------------------------------------------------------------- // endregion @@ -248,121 +257,83 @@ public String toString() { * @return true if this UInt256 value is 0. */ public boolean isZero() { - return (limbs[0] | limbs[1] | limbs[2] | limbs[3] | limbs[4] | limbs[5] | limbs[6] | limbs[7]) - == 0; + return (u0 | u1 | u2 | u3) == 0; } /** - * Compares two UInt256. + * Is the value 1 ? * - * @param a left UInt256 - * @param b right UInt256 - * @return 0 if a == b, negative if a < b and positive if a > b. + * @return true if this UInt256 value is 1. */ - public static int compare(final UInt256 a, final UInt256 b) { - int comp; - for (int i = N_LIMBS - 1; i >= 0; i--) { - comp = Integer.compareUnsigned(a.limbs[i], b.limbs[i]); - if (comp != 0) return comp; - } - return 0; + public boolean isOne() { + return ((u0 ^ 1L) | u1 | u2 | u3) == 0; } - @Override - public boolean equals(final Object obj) { - if (this == obj) return true; - if (!(obj instanceof UInt256)) return false; - UInt256 other = (UInt256) obj; - - int xor = - (this.limbs[0] ^ other.limbs[0]) - | (this.limbs[1] ^ other.limbs[1]) - | (this.limbs[2] ^ other.limbs[2]) - | (this.limbs[3] ^ other.limbs[3]) - | (this.limbs[4] ^ other.limbs[4]) - | (this.limbs[5] ^ other.limbs[5]) - | (this.limbs[6] ^ other.limbs[6]) - | (this.limbs[7] ^ other.limbs[7]); - return xor == 0; + /** + * Is the value 0 or 1 ? + * + * @return true if this UInt256 value is 1. + */ + public boolean isZeroOrOne() { + return ((u0 & -2L) | u1 | u2 | u3) == 0; } - @Override - public int hashCode() { - int h = 1; - for (int i = 0; i < N_LIMBS; i++) { - h = 31 * h + limbs[i]; - } - return h; + /** + * Is the two complements signed representation of this integer negative. + * + * @return True if the two complements representation of this integer is negative. + */ + public boolean isNegative() { + return u3 < 0; } - // -------------------------------------------------------------------------- - // endregion - - // region Arithmetic Operations - // -------------------------------------------------------------------------- - /** - * Unsigned modulo reduction. + * Does the value fit a long. * - * @param modulus The modulus of the reduction - * @return The remainder modulo {@code modulus}. + * @return true if it has at most 1 effective digit. */ - public UInt256 mod(final UInt256 modulus) { - if (this.isZero() || modulus.isZero()) return ZERO; - return new UInt256(knuthRemainder(this.limbs, modulus.limbs), modulus.length); + public boolean isUInt64() { + return (u1 | u2 | u3) == 0; } /** - * Signed modulo reduction. + * Does the value fit 2 longs. * - *

In signed modulo reduction, integers are interpretated as fixed 256 bits width two's - * complement signed integers. - * - * @param modulus The modulus of the reduction - * @return The remainder modulo {@code modulus}. + * @return true if it has at most 2 effective digits. */ - public UInt256 signedMod(final UInt256 modulus) { - if (this.isZero() || modulus.isZero()) return ZERO; - int[] x = new int[N_LIMBS]; - int[] y = new int[N_LIMBS]; - absInto(x, this.limbs, N_LIMBS); - absInto(y, modulus.limbs, N_LIMBS); - int[] r = knuthRemainder(x, y); - if (isNeg(this.limbs, N_LIMBS)) { - negate(r, N_LIMBS); - return new UInt256(r); - } - return new UInt256(r, modulus.length); + public boolean isUInt128() { + return (u2 | u3) == 0; } /** - * Modular addition. + * Does the value fit 3 longs. * - * @param other The integer to add to this. - * @param modulus The modulus of the reduction. - * @return This integer this + other (mod modulus). + * @return true if it has at most 3 effective digits. */ - public UInt256 addMod(final UInt256 other, final UInt256 modulus) { - if (modulus.isZero()) return ZERO; - int[] sum = addWithCarry(this.limbs, this.length, other.limbs, other.length); - int[] rem = knuthRemainder(sum, modulus.limbs); - return new UInt256(rem, modulus.length); + public boolean isUInt192() { + return u3 == 0; } /** - * Modular multiplication. + * Compares two UInt256. * - * @param other The integer to add to this. - * @param modulus The modulus of the reduction. - * @return This integer this + other (mod modulus). + * @param a left UInt256 + * @param b right UInt256 + * @return 0 if a == b, negative if a < b and positive if a > b. */ - public UInt256 mulMod(final UInt256 other, final UInt256 modulus) { - if (this.isZero() || other.isZero() || modulus.isZero()) return ZERO; - int[] result = addMul(this.limbs, this.length, other.limbs, other.length); - result = knuthRemainder(result, modulus.limbs); - return new UInt256(result, modulus.length); + public static int compare(final UInt256 a, final UInt256 b) { + if (a.u3 != b.u3) return Long.compareUnsigned(a.u3, b.u3); + if (a.u2 != b.u2) return Long.compareUnsigned(a.u2, b.u2); + if (a.u1 != b.u1) return Long.compareUnsigned(a.u1, b.u1); + return Long.compareUnsigned(a.u0, b.u0); } + // -------------------------------------------------------------------------- + // endregion + + // region Bitwise Operations + // -------------------------------------------------------------------------- + /** * Bitwise AND operation * @@ -370,16 +341,7 @@ public UInt256 mulMod(final UInt256 other, final UInt256 modulus) { * @return The UInt256 result from the bitwise AND operation */ public UInt256 and(final UInt256 other) { - int[] result = new int[N_LIMBS]; - result[0] = this.limbs[0] & other.limbs[0]; - result[1] = this.limbs[1] & other.limbs[1]; - result[2] = this.limbs[2] & other.limbs[2]; - result[3] = this.limbs[3] & other.limbs[3]; - result[4] = this.limbs[4] & other.limbs[4]; - result[5] = this.limbs[5] & other.limbs[5]; - result[6] = this.limbs[6] & other.limbs[6]; - result[7] = this.limbs[7] & other.limbs[7]; - return new UInt256(result, N_LIMBS); + return new UInt256(u3 & other.u3, u2 & other.u2, u1 & other.u1, u0 & other.u0); } /** @@ -389,16 +351,7 @@ public UInt256 and(final UInt256 other) { * @return The UInt256 result from the bitwise XOR operation */ public UInt256 xor(final UInt256 other) { - int[] result = new int[N_LIMBS]; - result[0] = this.limbs[0] ^ other.limbs[0]; - result[1] = this.limbs[1] ^ other.limbs[1]; - result[2] = this.limbs[2] ^ other.limbs[2]; - result[3] = this.limbs[3] ^ other.limbs[3]; - result[4] = this.limbs[4] ^ other.limbs[4]; - result[5] = this.limbs[5] ^ other.limbs[5]; - result[6] = this.limbs[6] ^ other.limbs[6]; - result[7] = this.limbs[7] ^ other.limbs[7]; - return new UInt256(result, N_LIMBS); + return new UInt256(u3 ^ other.u3, u2 ^ other.u2, u1 ^ other.u1, u0 ^ other.u0); } /** @@ -408,16 +361,7 @@ public UInt256 xor(final UInt256 other) { * @return The UInt256 result from the bitwise OR operation */ public UInt256 or(final UInt256 other) { - int[] result = new int[N_LIMBS]; - result[0] = this.limbs[0] | other.limbs[0]; - result[1] = this.limbs[1] | other.limbs[1]; - result[2] = this.limbs[2] | other.limbs[2]; - result[3] = this.limbs[3] | other.limbs[3]; - result[4] = this.limbs[4] | other.limbs[4]; - result[5] = this.limbs[5] | other.limbs[5]; - result[6] = this.limbs[6] | other.limbs[6]; - result[7] = this.limbs[7] | other.limbs[7]; - return new UInt256(result, N_LIMBS); + return new UInt256(u3 | other.u3, u2 | other.u2, u1 | other.u1, u0 | other.u0); } /** @@ -426,296 +370,1395 @@ public UInt256 or(final UInt256 other) { * @return The UInt256 result from the bitwise NOT operation */ public UInt256 not() { - int[] result = new int[N_LIMBS]; - result[0] = ~this.limbs[0]; - result[1] = ~this.limbs[1]; - result[2] = ~this.limbs[2]; - result[3] = ~this.limbs[3]; - result[4] = ~this.limbs[4]; - result[5] = ~this.limbs[5]; - result[6] = ~this.limbs[6]; - result[7] = ~this.limbs[7]; - return new UInt256(result, N_LIMBS); + return new UInt256(~u3, ~u2, ~u1, ~u0); + } + + /** + * Bitwise shift left. + * + * @param shift The number of bits to shift left (at most 64). + * @return The shifted UInt256. + */ + public UInt256 shiftLeft(final int shift) { + // Unchecked: 0 <= shift < 64 + if (shift == 0) return this; + int invShift = (N_BITS_PER_LIMB - shift); + long z0 = (u0 << shift); + long z1 = (u1 << shift) | u0 >>> invShift; + long z2 = (u2 << shift) | u1 >>> invShift; + long z3 = (u3 << shift) | u2 >>> invShift; + return new UInt256(z3, z2, z1, z0); + } + + /** + * Bitwise shift right. + * + * @param shift The number of bits to shift right (at most 64). + * @return The shifted UInt256. + */ + public UInt256 shiftRight(final int shift) { + // Unchecked: 0 <= shift < 64 + if (shift == 0) return this; + int invShift = (N_BITS_PER_LIMB - shift); + long z3 = (u3 >>> shift); + long z2 = (u2 >>> shift) | u3 << invShift; + long z1 = (u1 >>> shift) | u2 << invShift; + long z0 = (u0 >>> shift) | u1 << invShift; + return new UInt256(z3, z2, z1, z0); } // -------------------------------------------------------------------------- // endregion + // region Arithmetic Operations + // -------------------------------------------------------------------------- + + /** + * Compute the two-complement negative representation of this integer. + * + * @return The negative of this integer. + */ + public UInt256 neg() { + long carry = 1; + long z0 = ~u0 + carry; + carry = (z0 == 0 && carry == 1) ? 1 : 0; + long z1 = ~u1 + carry; + carry = (z1 == 0 && carry == 1) ? 1 : 0; + long z2 = ~u2 + carry; + carry = (z2 == 0 && carry == 1) ? 1 : 0; + long z3 = ~u3 + carry; + return new UInt256(z3, z2, z1, z0); + } + /** - * Simple addition + * Compute the absolute value for a two-complement negative representation of this integer. * - * @param other The UInt256 to add to this. - * @return The UInt256 result from the addition + * @return The absolute value of this integer. + */ + public UInt256 abs() { + return isNegative() ? neg() : this; + } + + /** + * Addition + * + *

Compute the wrapping sum of 2 256-bits integers. + * + * @param other Integer to add to this integer. + * @return The sum. */ public UInt256 add(final UInt256 other) { - return new UInt256( - addWithCarry(this.limbs, this.limbs.length, other.limbs, other.limbs.length)); + if (isZero()) return other; + if (other.isZero()) return this; + return adc(other).UInt256Value(); + } + + /** + * Multiplication + * + *

Compute the wrapping product of 2 256-bits integers. + * + * @param other Integer to multiply with this integer. + * @return The product. + */ + public UInt256 mul(final UInt256 other) { + if (isZero() || other.isZero()) return ZERO; + if (other.isOne()) return this; + if (this.isOne()) return other; + if (u3 != 0) return mul256(other).UInt256Value(); + if (u2 != 0) return mul192(other).UInt256Value(); + if (u1 != 0) return mul128(other); + return mul64(other); + } + + /** + * Unsigned modulo reduction. + * + *

Compute dividend (mod modulus) as unsigned big-endian integer. + * + * @param modulus The modulus of the reduction. + * @return The remainder. + */ + public UInt256 mod(final UInt256 modulus) { + if (isZero()) return ZERO; + if (modulus.u3 != 0) return modulus.asModulus256().reduce(this); + if (modulus.u2 != 0) return modulus.asModulus192().reduce(this); + if (modulus.u1 != 0) return modulus.asModulus128().reduce(this); + if ((modulus.u0 == 0) || (modulus.u0 == 1)) return ZERO; + return modulus.asModulus64().reduce(this); + } + + /** + * Signed modulo reduction. + * + *

In signed modulo reduction, integers are interpretated as fixed 256 bits width two's + * complement signed integers. + * + * @param modulus The modulus of the reduction + * @return The remainder modulo {@code modulus}. + */ + public UInt256 signedMod(final UInt256 modulus) { + if (isZero() || modulus.isZeroOrOne() || modulus.equals(MAX)) return ZERO; + UInt256 a = abs(); + UInt256 m = modulus.abs(); + UInt256 r = a.mod(m); + if (isNegative()) r = r.neg(); + return r; + } + + /** + * Modular addition. + * + * @param other The integer to add to this. + * @param modulus The modulus of the reduction. + * @return This integer this + other (mod modulus). + */ + public UInt256 addMod(final UInt256 other, final UInt256 modulus) { + if (isZero()) return other.mod(modulus); + if (other.isZero()) return this.mod(modulus); + if (modulus.isZeroOrOne()) return ZERO; + if (modulus.u3 != 0) return modulus.asModulus256().sum(this, other); + if (modulus.u2 != 0) return modulus.asModulus192().sum(this, other); + if (modulus.u1 != 0) return modulus.asModulus128().sum(this, other); + return modulus.asModulus64().sum(this, other); + } + + /** + * Modular multiplication. + * + * @param other The integer to add to this. + * @param modulus The modulus of the reduction. + * @return This integer this + other (mod modulus). + */ + public UInt256 mulMod(final UInt256 other, final UInt256 modulus) { + if (this.isZero() || other.isZero() || modulus.isZeroOrOne()) return ZERO; + if (this.isOne()) return other.mod(modulus); + if (other.isOne()) return this.mod(modulus); + if (modulus.u3 != 0) return modulus.asModulus256().mul(this, other); + if (modulus.u2 != 0) return modulus.asModulus192().mul(this, other); + if (modulus.u1 != 0) return modulus.asModulus128().mul(this, other); + return modulus.asModulus64().mul(this, other); } - // region Support (private) Algorithms // -------------------------------------------------------------------------- - private static int nSetLimbs(final int[] x) { - int offset = x.length - 1; - while ((offset >= 0) && (x[offset] == 0)) offset--; - return offset + 1; - } - - private static int compareLimbs(final int[] a, final int aLen, final int[] b, final int bLen) { - int cmp; - if (aLen > bLen) { - for (int i = aLen - 1; i >= bLen; i--) { - cmp = Integer.compareUnsigned(a[i], 0); - if (cmp != 0) return cmp; - } - } else if (aLen < bLen) { - for (int i = bLen - 1; i >= aLen; i--) { - cmp = Integer.compareUnsigned(0, b[i]); - if (cmp != 0) return cmp; - } + // endregion + + // region private basic operations + // + // adc (add and carry): carry, a <- a + b + // mac (multiply accumulate): a <- a + b * c + carryIn + // -------------------------------------------------------------------------- + + private UInt320 shiftLeftWide(final int shift) { + if (shift == 0) return UInt320Value(); + int invShift = (N_BITS_PER_LIMB - shift); + long z0 = (u0 << shift); + long z1 = (u1 << shift) | u0 >>> invShift; + long z2 = (u2 << shift) | u1 >>> invShift; + long z3 = (u3 << shift) | u2 >>> invShift; + long z4 = u3 >>> invShift; + return new UInt320(z4, z3, z2, z1, z0); + } + + private UInt256 shiftDigitsRight() { + return new UInt256(0, u3, u2, u1); + } + + private UInt257 adc(final UInt256 other) { + boolean carry; + if (isZero()) return new UInt257(false, other); + if (other.isZero()) return new UInt257(false, this); + long z0 = u0 + other.u0; + carry = Long.compareUnsigned(z0, u0) < 0; + long z1 = u1 + other.u1 + (carry ? 1L : 0L); + carry = Long.compareUnsigned(z1, u1) < 0 || (z1 == u1 && carry); + long z2 = u2 + other.u2 + (carry ? 1L : 0L); + carry = Long.compareUnsigned(z2, u2) < 0 || (z2 == u2 && carry); + long z3 = u3 + other.u3 + (carry ? 1L : 0L); + carry = Long.compareUnsigned(z3, u3) < 0 || (z3 == u3 && carry); + return new UInt257(carry, new UInt256(z3, z2, z1, z0)); + } + + private UInt256 mac128(final long multiplier, final UInt256 carryIn) { + // Multiply accumulate for 128bits integer (this): + // Returns this * multiplier + (carryIn >>> 64) + long hi, lo, carry; + if (multiplier == 0) return carryIn.shiftDigitsRight(); + + lo = u0 * multiplier; + hi = Math.unsignedMultiplyHigh(u0, multiplier); + long z0 = lo + carryIn.u1; + hi += (Long.compareUnsigned(z0, lo) < 0) ? 1 : 0; + + long z1 = hi + carryIn.u2; + carry = (Long.compareUnsigned(z1, hi) < 0) ? 1 : 0; + lo = u1 * multiplier; + hi = Math.unsignedMultiplyHigh(u1, multiplier); + z1 += lo; + hi += (Long.compareUnsigned(z1, lo) < 0) ? carry + 1 : carry; + + return new UInt256(0, hi, z1, z0); + } + + private UInt256 mac192(final long multiplier, final UInt256 carryIn) { + // Multiply accumulate for 192bits integer (this): + // Returns this * multiplier + (carryIn >>> 64) + long hi, lo, carry; + if (multiplier == 0) return carryIn.shiftDigitsRight(); + + lo = u0 * multiplier; + hi = Math.unsignedMultiplyHigh(u0, multiplier); + long z0 = lo + carryIn.u1; + hi += (Long.compareUnsigned(z0, lo) < 0) ? 1 : 0; + + long z1 = hi + carryIn.u2; + carry = (Long.compareUnsigned(z1, hi) < 0) ? 1 : 0; + lo = u1 * multiplier; + hi = Math.unsignedMultiplyHigh(u1, multiplier); + z1 += lo; + hi += (Long.compareUnsigned(z1, lo) < 0) ? carry + 1 : carry; + + long z2 = hi + carryIn.u3; + carry = (Long.compareUnsigned(z2, hi) < 0) ? 1 : 0; + lo = u2 * multiplier; + hi = Math.unsignedMultiplyHigh(u2, multiplier); + z2 += lo; + hi += (Long.compareUnsigned(z2, lo) < 0) ? carry + 1 : carry; + + return new UInt256(hi, z2, z1, z0); + } + + private UInt320 mac256(final long multiplier, final UInt320 carryIn) { + // Multiply accumulate for 192bits integer (this): + // Returns this * multiplier + carryIn + long hi, lo, carry; + if (multiplier == 0) return carryIn.shiftDigitsRight(); + + lo = u0 * multiplier; + hi = Math.unsignedMultiplyHigh(u0, multiplier); + long z0 = lo + carryIn.u1; + hi += (Long.compareUnsigned(z0, lo) < 0) ? 1 : 0; + carry = 0; + + long z1 = hi + carryIn.u2; + carry = (Long.compareUnsigned(z1, hi) < 0) ? 1 : 0; + lo = u1 * multiplier; + hi = Math.unsignedMultiplyHigh(u1, multiplier); + z1 += lo; + hi += (Long.compareUnsigned(z1, lo) < 0) ? carry + 1 : carry; + + long z2 = hi + carryIn.u3; + carry = (Long.compareUnsigned(z2, hi) < 0) ? 1 : 0; + lo = u2 * multiplier; + hi = Math.unsignedMultiplyHigh(u2, multiplier); + z2 += lo; + hi += (Long.compareUnsigned(z2, lo) < 0) ? carry + 1 : carry; + + long z3 = hi + carryIn.u4; + carry = (Long.compareUnsigned(z3, hi) < 0) ? 1 : 0; + lo = u3 * multiplier; + hi = Math.unsignedMultiplyHigh(u3, multiplier); + z3 += lo; + hi += (Long.compareUnsigned(z3, lo) < 0) ? carry + 1 : carry; + + return new UInt320(hi, z3, z2, z1, z0); + } + + // -------------------------------------------------------------------------- + // endregion + + // region private multiplication + // -------------------------------------------------------------------------- + + private UInt256 mul64(final UInt256 v) { + long lo = u0 * v.u0; + long hi = Math.unsignedMultiplyHigh(u0, v.u0); + return new UInt256(0, 0, hi, lo); + } + + private UInt256 mul128(final UInt256 v) { + long z0; + UInt256 res; + + res = mac128(v.u0, ZERO); + z0 = res.u0; + res = mac128(v.u1, res); + + return new UInt256(res.u2, res.u1, res.u0, z0); + } + + private UInt512 mul192(final UInt256 v) { + UInt256 res; + res = mac192(v.u0, ZERO); + long z0 = res.u0; + res = mac192(v.u1, res); + long z1 = res.u0; + res = mac192(v.u2, res); + + return new UInt512(0, 0, res.u3, res.u2, res.u1, res.u0, z1, z0); + } + + private UInt512 mul256(final UInt256 v) { + UInt320 res; + res = mac256(v.u0, UInt320.ZERO); + long z0 = res.u0; + res = mac256(v.u1, res); + long z1 = res.u0; + res = mac256(v.u2, res); + long z2 = res.u0; + res = mac256(v.u3, res); + return new UInt512(res.u4, res.u3, res.u2, res.u1, res.u0, z2, z1, z0); + } + + // -------------------------------------------------------------------------- + // endregion + + // region private quotient estimation + // -------------------------------------------------------------------------- + + // Lookup table for $\floor{\frac{2^{19} -3 ⋅ 2^8}{d_9 - 256}}$ + private static final short[] LUT = + new short[] { + 2045, 2037, 2029, 2021, 2013, 2005, 1998, 1990, 1983, 1975, 1968, 1960, 1953, 1946, 1938, + 1931, 1924, 1917, 1910, 1903, 1896, 1889, 1883, 1876, 1869, 1863, 1856, 1849, 1843, 1836, + 1830, 1824, 1817, 1811, 1805, 1799, 1792, 1786, 1780, 1774, 1768, 1762, 1756, 1750, 1745, + 1739, 1733, 1727, 1722, 1716, 1710, 1705, 1699, 1694, 1688, 1683, 1677, 1672, 1667, 1661, + 1656, 1651, 1646, 1641, 1636, 1630, 1625, 1620, 1615, 1610, 1605, 1600, 1596, 1591, 1586, + 1581, 1576, 1572, 1567, 1562, 1558, 1553, 1548, 1544, 1539, 1535, 1530, 1526, 1521, 1517, + 1513, 1508, 1504, 1500, 1495, 1491, 1487, 1483, 1478, 1474, 1470, 1466, 1462, 1458, 1454, + 1450, 1446, 1442, 1438, 1434, 1430, 1426, 1422, 1418, 1414, 1411, 1407, 1403, 1399, 1396, + 1392, 1388, 1384, 1381, 1377, 1374, 1370, 1366, 1363, 1359, 1356, 1352, 1349, 1345, 1342, + 1338, 1335, 1332, 1328, 1325, 1322, 1318, 1315, 1312, 1308, 1305, 1302, 1299, 1295, 1292, + 1289, 1286, 1283, 1280, 1276, 1273, 1270, 1267, 1264, 1261, 1258, 1255, 1252, 1249, 1246, + 1243, 1240, 1237, 1234, 1231, 1228, 1226, 1223, 1220, 1217, 1214, 1211, 1209, 1206, 1203, + 1200, 1197, 1195, 1192, 1189, 1187, 1184, 1181, 1179, 1176, 1173, 1171, 1168, 1165, 1163, + 1160, 1158, 1155, 1153, 1150, 1148, 1145, 1143, 1140, 1138, 1135, 1133, 1130, 1128, 1125, + 1123, 1121, 1118, 1116, 1113, 1111, 1109, 1106, 1104, 1102, 1099, 1097, 1095, 1092, 1090, + 1088, 1086, 1083, 1081, 1079, 1077, 1074, 1072, 1070, 1068, 1066, 1064, 1061, 1059, 1057, + 1055, 1053, 1051, 1049, 1047, 1044, 1042, 1040, 1038, 1036, 1034, 1032, 1030, 1028, 1026, + 1024, + }; + + private static long reciprocal(final long x) { + // Unchecked: x >= (1 << 63) + long x0 = x & 1L; + int x9 = (int) (x >>> 55); + long x40 = 1 + (x >>> 24); + long x63 = (x + 1) >>> 1; + long v0 = LUT[x9 - 256] & 0xFFFFL; + long v1 = (v0 << 11) - ((v0 * v0 * x40) >>> 40) - 1; + long v2 = (v1 << 13) + ((v1 * ((1L << 60) - v1 * x40)) >>> 47); + long e = ((v2 >>> 1) & (-x0)) - v2 * x63; + long s = Math.unsignedMultiplyHigh(v2, e); + long v3 = (s >>> 1) + (v2 << 31); + long t0 = v3 * x; + long t1 = Math.unsignedMultiplyHigh(v3, x); + t0 += x; + if (Long.compareUnsigned(t0, x) < 0) t1++; + t1 += x; + long v4 = v3 - t1; + return v4; + } + + // private static long reciprocal2(final long x1, final long x0) { + // // Unchecked: >= (1 << 127) + // long v = reciprocal(x1); + // long p = x1 * v + x0; + // if (Long.compareUnsigned(p, x0) < 0) { + // v--; + // if (Long.compareUnsigned(p, x1) >= 0) { + // v--; + // p -= x1; + // } + // p -= x1; + // } + // long t0 = v * x0; + // long t1 = Math.unsignedMultiplyHigh(v, x0); + // p += t1; + // if (Long.compareUnsigned(p, t1) < 0) { + // v--; + // int cmp = Long.compareUnsigned(p, x1); + // if ((cmp > 0) || ((cmp == 0) && (Long.compareUnsigned(t0, x0) >= 0))) v--; + // } + // return v; + // } + + private static DivEstimate div2by1(final long x1, final long x0, final long y, final long yInv) { + long z1 = x1; + long z0 = x0; + + // wrapping umul z1 * yInv + long q0 = z1 * yInv; + long q1 = Math.unsignedMultiplyHigh(z1, yInv); + + // wrapping uadd + + <1, 0> + long sum = q0 + z0; + long carry = ((q0 & z0) | ((q0 | z0) & ~sum)) >>> 63; + q0 = sum; + q1 += z1 + carry + 1; + + z0 -= q1 * y; + if (Long.compareUnsigned(z0, q0) > 0) { + q1 -= 1; + z0 += y; } - for (int i = Math.min(aLen, bLen) - 1; i >= 0; i--) { - cmp = Integer.compareUnsigned(a[i], b[i]); - if (cmp != 0) return cmp; + if (Long.compareUnsigned(z0, y) >= 0) { + q1 += 1; + z0 -= y; } - return 0; + return new DivEstimate(q1, z0); } - private static boolean isNeg(final int[] x, final int xLen) { - return x[xLen - 1] < 0; + private static long mod2by1(final long x1, final long x0, final long y, final long yInv) { + long z1 = x1; + long z0 = x0; + // wrapping umul z1 * yInv + long q0 = z1 * yInv; + long q1 = Math.unsignedMultiplyHigh(z1, yInv); + + // wrapping uadd + + <1, 0> + long sum = q0 + z0; + long carry = ((q0 & z0) | ((q0 | z0) & ~sum)) >>> 63; + q0 = sum; + q1 += z1 + carry + 1; + + z0 -= q1 * y; + if (Long.compareUnsigned(z0, q0) > 0) { + q1 -= 1; + z0 += y; + } + if (Long.compareUnsigned(z0, y) >= 0) { + q1 += 1; + z0 -= y; + } + return z0; } - private static void negate(final int[] x, final int xLen) { - int carry = 1; - for (int i = 0; i < xLen; i++) { - x[i] = ~x[i] + carry; - carry = (x[i] == 0 && carry == 1) ? 1 : 0; + // private static Div2Estimate div3by2( + // final long x2, final long x1, final long x0, final long y1, final long y0, final long yInv) + // { + // // divided by . + // // Requires < otherwise quotient overflows. + // long overflow; // carry or borrow + // long res; // sum or diff + // long z2 = x2; + // long z1 = x1; + // long z0 = x0; + + // // = z2 * yInv + + // long q0 = z2 * yInv; + // long q1 = Math.unsignedMultiplyHigh(z2, yInv); + // res = q0 + z1; + // overflow = ((q0 & z1) | ((q0 | z1) & ~res)) >>> 63; + // q0 = res; + // q1 = q1 + z2 + overflow; + + // // r1 <- z1 - q1 * y1 mod B + // z1 -= q1 * y1; + + // // wrapping sub − q1*y0 − + // long t0 = q1 * y0; + // long t1 = Math.unsignedMultiplyHigh(q1, y0); + + // res = z0 - t0; + // overflow = ((~z0 & t0) | ((~z0 | t0) & res)) >>> 63; + // z0 = res; + // z1 -= (t1 + overflow); + + // res = z0 - y0; + // overflow = ((~z0 & y0) | ((~z0 | y0) & res)) >>> 63; + // z0 = res; + // z1 -= (y1 + overflow); + + // // Adjustments + // q1 += 1; + // if (Long.compareUnsigned(z1, q0) >= 0) { + // q1 -= 1; + // res = z0 + y0; + // overflow = ((z0 & y0) | ((z0 | y0) & ~res)) >>> 63; + // z0 = res; + // z1 += y1 + overflow; + // } + + // int cmp = Long.compareUnsigned(z1, y1); + // if ((cmp > 0) || ((cmp == 0) && (Long.compareUnsigned(z0, y0) >= 0))) { + // q1 += 1; + // res = z0 - y0; + // overflow = ((~z0 & y0) | ((~z0 | y0) & res)) >>> 63; + // z0 = res; + // z1 -= (y1 + overflow); + // } + // return new Div2Estimate(q1, z1, z0); + // } + + // -------------------------------------------------------------------------- + // endregion + + // region Records + // -------------------------------------------------------------------------- + record UInt257(boolean carry, UInt256 u) { + boolean isUInt64() { + return !carry && u.isUInt64(); + } + + boolean isUInt256() { + return !carry; + } + + UInt256 UInt256Value() { + return u; + } + + UInt320 shiftLeftWide(final int shift) { + long u4 = (carry ? 1L : 0L); + if (shift == 0) return new UInt320(u4, u.u3, u.u2, u.u1, u.u0); + int invShift = (N_BITS_PER_LIMB - shift); + long z0 = (u.u0 << shift); + long z1 = (u.u1 << shift) | u.u0 >>> invShift; + long z2 = (u.u2 << shift) | u.u1 >>> invShift; + long z3 = (u.u3 << shift) | u.u2 >>> invShift; + long z4 = (u4 << shift) | u.u3 >>> invShift; + return new UInt320(z4, z3, z2, z1, z0); } } - private static void absInplace(final int[] x, final int xLen) { - if (isNeg(x, xLen)) negate(x, xLen); + record UInt128(long u1, long u0) {} + + record UInt192(long u2, long u1, long u0) {} + + record UInt320(long u4, long u3, long u2, long u1, long u0) { + static final UInt320 ZERO = new UInt320(0, 0, 0, 0, 0); + + // UInt256 UInt256ValueHigh() { + // return new UInt256(u4, u3, u2, u1); + // } + + UInt320 shiftDigitsRight() { + return new UInt320(0, u4, u3, u2, u1); + } + } + + record UInt512(long u7, long u6, long u5, long u4, long u3, long u2, long u1, long u0) { + boolean isUInt64() { + return (u7 & u6 & u5 & u4 & u3 & u2 & u1) == 0; + } + + UInt256 UInt256Value() { + return new UInt256(u3, u2, u1, u0); + } + + UInt576 shiftLeftWide(final int shift) { + if (shift == 0) return new UInt576(0, u7, u6, u5, u4, u3, u2, u1, u0); + int invShift = (N_BITS_PER_LIMB - shift); + long z0 = (u0 << shift); + long z1 = (u1 << shift) | u0 >>> invShift; + long z2 = (u2 << shift) | u1 >>> invShift; + long z3 = (u3 << shift) | u2 >>> invShift; + long z4 = (u4 << shift) | u3 >>> invShift; + long z5 = (u5 << shift) | u4 >>> invShift; + long z6 = (u6 << shift) | u5 >>> invShift; + long z7 = (u7 << shift) | u6 >>> invShift; + long z8 = u7 >>> invShift; + return new UInt576(z8, z7, z6, z5, z4, z3, z2, z1, z0); + } } - private static void absInto(final int[] dst, final int[] src, final int srcLen) { - System.arraycopy(src, 0, dst, 0, srcLen); - absInplace(dst, dst.length); + record UInt576(long u8, long u7, long u6, long u5, long u4, long u3, long u2, long u1, long u0) {} + + private record DivEstimate(long q, long r) {} + + record Div2Estimate(long q, long r1, long r0) {} + + // -------------------------------------------------------------------------- + // endregion + + // region 64bits Modulus + // -------------------------------------------------------------------------- + record Modulus64(long u0) { + Modulus64 shiftLeft(final int shift) { + return (shift == 0) ? this : new Modulus64(u0 << shift); + } + + UInt256 reduce(final UInt256 that) { + if (that.isUInt64()) { + return UInt256.fromLong(Long.remainderUnsigned(that.u0, u0)); + } + int shift = Long.numberOfLeadingZeros(u0); + Modulus64 m = shiftLeft(shift); + long inv = reciprocal(m.u0); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 reduce(final UInt512 that) { + if (that.isUInt64()) return UInt256.fromLong(Long.remainderUnsigned(that.u0(), u0)); + int shift = Long.numberOfLeadingZeros(u0); + Modulus64 m = shiftLeft(shift); + long inv = reciprocal(m.u0); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 sum(final UInt256 a, final UInt256 b) { + UInt257 sum = a.adc(b); + if (sum.isUInt64()) return UInt256.fromLong(Long.remainderUnsigned(sum.u().u0, u0)); + int shift = Long.numberOfLeadingZeros(u0); + Modulus64 m = shiftLeft(shift); + long inv = reciprocal(m.u0); + return m.reduceNormalised(sum, shift, inv); + } + + UInt256 mul(final UInt256 a, final UInt256 b) { + // multiply-reduce + if (a.isUInt64() && b.isUInt64()) { + UInt256 prod = a.mul64(b); + if (prod.isUInt64()) return UInt256.fromLong(Long.remainderUnsigned(prod.u0, u0)); + return reduce(prod); + } + // reduce-multiply-reduce + int shift = Long.numberOfLeadingZeros(u0); + Modulus64 m = shiftLeft(shift); + long inv = reciprocal(m.u0); + UInt256 x = (a.isUInt64()) ? a : m.reduceNormalised(a, shift, inv); + UInt256 y = (b.isUInt64()) ? b : m.reduceNormalised(b, shift, inv); + UInt256 prod = x.mul64(y); + return prod.isUInt64() + ? UInt256.fromLong(Long.remainderUnsigned(prod.u0, u0)) + : m.reduceNormalised(prod, shift, inv); + } + + private long reduceStep(final long v1, final long v0, final long inv) { + return (v1 == u0) ? v0 : mod2by1(v1, v0, u0, inv); + } + + private UInt256 reduceNormalised(final UInt256 that, final int shift, final long inv) { + long r; + UInt320 v = that.shiftLeftWide(shift); + if (v.u4 != 0 || Long.compareUnsigned(v.u3, u0) > 0) { + r = reduceStep(v.u4, v.u3, inv); + r = reduceStep(r, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else if (v.u3 != 0 || Long.compareUnsigned(v.u2, u0) > 0) { + r = reduceStep(v.u3, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else if (v.u2 != 0 || Long.compareUnsigned(v.u1, u0) > 0) { + r = reduceStep(v.u2, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else { + r = reduceStep(v.u1, v.u0, inv); + } + return UInt256.fromLong(r >>> shift); + } + + private UInt256 reduceNormalised(final UInt257 that, final int shift, final long inv) { + long r; + UInt320 v = that.shiftLeftWide(shift); + if (v.u4 != 0 || Long.compareUnsigned(v.u3, u0) > 0) { + r = reduceStep(v.u4, v.u3, inv); + r = reduceStep(r, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else if (v.u3 != 0 || Long.compareUnsigned(v.u2, u0) > 0) { + r = reduceStep(v.u3, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else if (v.u2 != 0 || Long.compareUnsigned(v.u1, u0) > 0) { + r = reduceStep(v.u2, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else { + r = reduceStep(v.u1, v.u0, inv); + } + return UInt256.fromLong(r >>> shift); + } + + private UInt256 reduceNormalised(final UInt512 that, final int shift, final long inv) { + long r; + UInt576 v = that.shiftLeftWide(shift); + if (v.u8 != 0 || Long.compareUnsigned(v.u7, u0) > 0) { + r = reduceStep(v.u8, v.u7, inv); + r = reduceStep(r, v.u6, inv); + r = reduceStep(r, v.u5, inv); + r = reduceStep(r, v.u4, inv); + r = reduceStep(r, v.u3, inv); + r = reduceStep(r, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else if (v.u7 != 0 || Long.compareUnsigned(v.u6, u0) > 0) { + r = reduceStep(v.u7, v.u6, inv); + r = reduceStep(r, v.u5, inv); + r = reduceStep(r, v.u4, inv); + r = reduceStep(r, v.u3, inv); + r = reduceStep(r, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else if (v.u6 != 0 || Long.compareUnsigned(v.u5, u0) > 0) { + r = reduceStep(v.u6, v.u5, inv); + r = reduceStep(r, v.u4, inv); + r = reduceStep(r, v.u3, inv); + r = reduceStep(r, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else if (v.u5 != 0 || Long.compareUnsigned(v.u4, u0) > 0) { + r = reduceStep(v.u5, v.u4, inv); + r = reduceStep(r, v.u3, inv); + r = reduceStep(r, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else if (v.u4 != 0 || Long.compareUnsigned(v.u3, u0) > 0) { + r = reduceStep(v.u4, v.u3, inv); + r = reduceStep(r, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else if (v.u3 != 0 || Long.compareUnsigned(v.u2, u0) > 0) { + r = reduceStep(v.u3, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else if (v.u2 != 0 || Long.compareUnsigned(v.u1, u0) > 0) { + r = reduceStep(v.u2, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else { + r = reduceStep(v.u1, v.u0, inv); + } + return UInt256.fromLong(r >>> shift); + } } - private static int numberOfLeadingZeros(final int[] x, final int xLen) { - int leadingIndex = xLen - 1; - while ((leadingIndex >= 0) && (x[leadingIndex] == 0)) leadingIndex--; - return 32 * (xLen - leadingIndex - 1) + Integer.numberOfLeadingZeros(x[leadingIndex]); + // -------------------------------------------------------------------------- + // endregion 64bits Modulus + + // region 128bits Modulus + // -------------------------------------------------------------------------- + record Modulus128(long u1, long u0) { + Modulus128 shiftLeft(final int shift) { + if (shift == 0) return this; + int invShift = N_BITS_PER_LIMB - shift; + return new Modulus128((u1 << shift) | (u0 >>> invShift), u0 << shift); + } + + int compareTo(final UInt256 v) { + if ((v.u3 | v.u2) != 0) return -1; + if (v.u1 != u1) return Long.compareUnsigned(u1, v.u1); + return Long.compareUnsigned(u0, v.u0); + } + + int compareTo(final UInt512 v) { + if ((v.u7 | v.u6 | v.u5 | v.u4 | v.u3 | v.u2) != 0) return -1; + if (v.u1 != u1) return Long.compareUnsigned(u1, v.u1); + return Long.compareUnsigned(u0, v.u0); + } + + UInt256 reduce(final UInt256 that) { + int cmp = compareTo(that); + if (cmp == 0) return ZERO; + if (cmp > 0) return that; + int shift = Long.numberOfLeadingZeros(u1); + Modulus128 m = shiftLeft(shift); + long inv = reciprocal(m.u1); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 reduce(final UInt512 that) { + int cmp = compareTo(that); + if (cmp == 0) return ZERO; + if (cmp > 0) return that.UInt256Value(); + int shift = Long.numberOfLeadingZeros(u1); + Modulus128 m = shiftLeft(shift); + long inv = reciprocal(m.u1); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 sum(final UInt256 a, final UInt256 b) { + UInt257 sum = a.adc(b); + int cmp = sum.isUInt256() ? compareTo(sum.UInt256Value()) : -1; + if (cmp == 0) return ZERO; + if (cmp > 0) return sum.UInt256Value(); + int shift = Long.numberOfLeadingZeros(u1); + Modulus128 m = shiftLeft(shift); + long inv = reciprocal(m.u1); + return m.reduceNormalised(sum, shift, inv); + } + + UInt256 mul(final UInt256 a, final UInt256 b) { + // multiply-reduce + if (a.isUInt128() && b.isUInt128()) { + UInt256 prod = a.mul128(b); + int cmp = compareTo(prod); + if (cmp == 0) return ZERO; + if (cmp > 0) return prod; + return reduce(prod); + } + // reduce-multiply-reduce + int shift = Long.numberOfLeadingZeros(u1); + Modulus128 m = shiftLeft(shift); + long inv = reciprocal(m.u1); + UInt256 x = (a.isUInt128()) ? a : m.reduceNormalised(a, shift, inv); + UInt256 y = (b.isUInt128()) ? b : m.reduceNormalised(b, shift, inv); + UInt256 prod = x.mul128(y); + int cmp = compareTo(prod); + if (cmp == 0) return ZERO; + if (cmp > 0) return prod; + return m.reduceNormalised(prod, shift, inv); + } + + private UInt128 reduceStep(final long v2, final long v1, final long v0, final long inv) { + long borrow, p0, p1, res; + long z2 = v2; + long z1 = v1; + long z0 = v0; + + if (z2 == u1) { + // Overflow case: div2by1 quotient would be <1, 0>, but adjusts to <0, -1> + // = -1 * u0 = + res = z0 + u0; + borrow = ((~z0 & ~u0) | ((~z0 | ~u0) & res)) >>> 63; + p1 = u0 - 1 + borrow; + z0 = res; + z1 = z1 - p1 + u1; + } else { + DivEstimate qr = div2by1(z2, z1, u1, inv); + z2 = 0; + z1 = qr.r; + + if (qr.q != 0) { + // Multiply-subtract: highest limb is already substracted + // = * q + p0 = u0 * qr.q; + p1 = Math.unsignedMultiplyHigh(u0, qr.q); + res = z0 - p0; + p1 += ((~z0 & p0) | ((~z0 | p0) & res)) >>> 63; + z0 = res; + + // Propagate overflows (borrows) + res = z1 - p1; + borrow = ((~z1 & p1) | ((~z1 | p1) & res)) >>> 63; + z1 = res; + + if (borrow != 0) { // unlikely + // Add back + res = z0 + u0; + long carry = (Long.compareUnsigned(res, z0) < 0) ? 1 : 0; + z0 = res; + res = z1 + u1 + carry; + carry = (Long.compareUnsigned(res, z1) < 0 || (u1 == -1 && carry == 1)) ? 1 : 0; + z1 = res; + if (carry == 0) { // unlikely: add back again + // Add back + res = z0 + u0; + carry = (Long.compareUnsigned(res, z0) < 0) ? 1 : 0; + z0 = res; + z1 = z1 + u1 + carry; + } + } + } + } + return new UInt128(z1, z0); + } + + private UInt256 reduceNormalised(final UInt256 that, final int shift, final long inv) { + UInt128 r; + UInt320 v = that.shiftLeftWide(shift); + if (v.u4 != 0 || Long.compareUnsigned(v.u3, u1) >= 0) { + r = reduceStep(v.u4, v.u3, v.u2, inv); + r = reduceStep(r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else if (v.u3 != 0 || Long.compareUnsigned(v.u2, u1) >= 0) { + r = reduceStep(v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u2, v.u1, v.u0, inv); + } + return new UInt256(0, 0, r.u1, r.u0).shiftRight(shift); + } + + private UInt256 reduceNormalised(final UInt257 that, final int shift, final long inv) { + UInt128 r; + UInt320 v = that.shiftLeftWide(shift); + if (v.u4 != 0 || Long.compareUnsigned(v.u3, u1) >= 0) { + r = reduceStep(v.u4, v.u3, v.u2, inv); + r = reduceStep(r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else if (v.u3 != 0 || Long.compareUnsigned(v.u2, u1) >= 0) { + r = reduceStep(v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u2, v.u1, v.u0, inv); + } + return new UInt256(0, 0, r.u1, r.u0).shiftRight(shift); + } + + private UInt256 reduceNormalised(final UInt512 that, final int shift, final long inv) { + UInt128 r; + UInt576 v = that.shiftLeftWide(shift); + if (v.u8 != 0 || Long.compareUnsigned(v.u7, u1) >= 0) { + r = reduceStep(v.u8, v.u7, v.u6, inv); + r = reduceStep(r.u1, r.u0, v.u5, inv); + r = reduceStep(r.u1, r.u0, v.u4, inv); + r = reduceStep(r.u1, r.u0, v.u3, inv); + r = reduceStep(r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else if (v.u7 != 0 || Long.compareUnsigned(v.u6, u1) >= 0) { + r = reduceStep(v.u7, v.u6, v.u5, inv); + r = reduceStep(r.u1, r.u0, v.u4, inv); + r = reduceStep(r.u1, r.u0, v.u3, inv); + r = reduceStep(r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else if (v.u6 != 0 || Long.compareUnsigned(v.u5, u1) >= 0) { + r = reduceStep(v.u6, v.u5, v.u4, inv); + r = reduceStep(r.u1, r.u0, v.u3, inv); + r = reduceStep(r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else if (v.u5 != 0 || Long.compareUnsigned(v.u4, u1) >= 0) { + r = reduceStep(v.u5, v.u4, v.u3, inv); + r = reduceStep(r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else if (v.u4 != 0 || Long.compareUnsigned(v.u3, u1) >= 0) { + r = reduceStep(v.u4, v.u3, v.u2, inv); + r = reduceStep(r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else if (v.u3 != 0 || Long.compareUnsigned(v.u2, u1) >= 0) { + r = reduceStep(v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u2, v.u1, v.u0, inv); + } + return new UInt256(0, 0, r.u1, r.u0).shiftRight(shift); + } } - private static void shiftLeftInto( - final int[] result, final int[] x, final int xLen, final int shift) { - // Unchecked: result should be initialised with zeroes - // Unchecked: result length should be at least x.length + limbShift - int limbShift = shift / N_BITS_PER_LIMB; - int bitShift = shift % N_BITS_PER_LIMB; - if (bitShift == 0) { - System.arraycopy(x, 0, result, limbShift, xLen); - return; - } - - int j = limbShift; - int carry = 0; - for (int i = 0; i < xLen; ++i, ++j) { - result[j] = (x[i] << bitShift) | carry; - carry = x[i] >>> (32 - bitShift); - } - if (carry != 0) result[j] = carry; // last carry - } - - private static void shiftRightInto( - final int[] result, final int[] x, final int xLen, final int shift) { - // Unchecked: result length should be at least x.length - limbShift - int limbShift = shift / 32; - int bitShift = shift % 32; - int nLimbs = xLen - limbShift; - if (nLimbs <= 0) return; - - if (bitShift == 0) { - System.arraycopy(x, limbShift, result, 0, nLimbs); - return; - } - - int carry = 0; - for (int i = nLimbs - 1 + limbShift, j = nLimbs - 1; j >= 0; i--, j--) { - int r = (x[i] >>> bitShift) | carry; - result[j] = r; - carry = x[i] << (32 - bitShift); - } - } - - private static int[] addWithCarry(final int[] x, final int xLen, final int[] y, final int yLen) { - // Step 1: Add with carry - int[] a; - int[] b; - int aLen; - int bLen; - if (xLen < yLen) { - a = y; - aLen = yLen; - b = x; - bLen = xLen; - } else { - a = x; - aLen = xLen; - b = y; - bLen = yLen; - } - int[] sum = new int[aLen + 1]; - long carry = 0; - for (int i = 0; i < bLen; i++) { - long ai = a[i] & MASK_L; - long bi = b[i] & MASK_L; - long s = ai + bi + carry; - sum[i] = (int) s; - carry = s >>> 32; - } - int icarry = (int) carry; - for (int i = bLen; i < aLen; i++) { - sum[i] = a[i] + icarry; - icarry = (a[i] != 0 && sum[i] == 0) ? 1 : 0; - } - sum[aLen] = icarry; - return sum; - } - - private static int[] addMul(final int[] a, final int aLen, final int[] b, final int bLen) { - // Shortest in outer loop, swap if needed - int[] x; - int xLen; - int[] y; - int yLen; - if (a.length < b.length) { - x = b; - xLen = bLen; - y = a; - yLen = aLen; - } else { - x = a; - xLen = aLen; - y = b; - yLen = bLen; - } - int[] lhs = new int[xLen + yLen + 1]; - - // Main algo - for (int i = 0; i < yLen; i++) { - long carry = 0; - long yi = y[i] & MASK_L; - - int k = i; - for (int j = 0; j < xLen; j++, k++) { - long prod = yi * (x[j] & MASK_L); - long sum = (lhs[k] & MASK_L) + prod + carry; - lhs[k] = (int) sum; - carry = sum >>> 32; + // -------------------------------------------------------------------------- + // endregion 128bits Modulus + + // region 192bits Modulus + // -------------------------------------------------------------------------- + record Modulus192(long u2, long u1, long u0) { + Modulus192 shiftLeft(final int shift) { + if (shift == 0) return this; + int invShift = N_BITS_PER_LIMB - shift; + long z0 = u0 << shift; + long z1 = (u1 << shift) | (u0 >>> invShift); + long z2 = (u2 << shift) | (u1 >>> invShift); + return new Modulus192(z2, z1, z0); + } + + int compareTo(final UInt256 v) { + if (v.u3 != 0) return -1; + if (v.u2 != u2) return Long.compareUnsigned(u2, v.u2); + if (v.u1 != u1) return Long.compareUnsigned(u1, v.u1); + return Long.compareUnsigned(u0, v.u0); + } + + int compareTo(final UInt512 v) { + if ((v.u7 | v.u6 | v.u5 | v.u4 | v.u3) != 0) return -1; + if (v.u2 != u2) return Long.compareUnsigned(u2, v.u2); + if (v.u1 != u1) return Long.compareUnsigned(u1, v.u1); + return Long.compareUnsigned(u0, v.u0); + } + + UInt256 reduce(final UInt256 that) { + int cmp = compareTo(that); + if (cmp == 0) return ZERO; + if (cmp > 0) return that; + int shift = Long.numberOfLeadingZeros(u2); + Modulus192 m = shiftLeft(shift); + long inv = reciprocal(m.u2); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 reduce(final UInt512 that) { + int cmp = compareTo(that); + if (cmp == 0) return ZERO; + if (cmp > 0) return that.UInt256Value(); + int shift = Long.numberOfLeadingZeros(u2); + Modulus192 m = shiftLeft(shift); + long inv = reciprocal(m.u2); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 sum(final UInt256 a, final UInt256 b) { + UInt257 sum = a.adc(b); + if (!sum.carry()) { + int cmp = compareTo(sum.UInt256Value()); + if (cmp == 0) return ZERO; + if (cmp > 0) return sum.UInt256Value(); } + int shift = Long.numberOfLeadingZeros(u2); + Modulus192 m = shiftLeft(shift); + long inv = reciprocal(m.u2); + return m.reduceNormalised(sum, shift, inv); + } - // propagate leftover carry - while (carry != 0 && k < lhs.length) { - long sum = (lhs[k] & MASK_L) + carry; - lhs[k] = (int) sum; - carry = sum >>> 32; - k++; + UInt256 mul(final UInt256 a, final UInt256 b) { + // multiply-reduce + if (a.isUInt192() && b.isUInt192()) { + UInt512 prod = a.mul192(b); + int cmp = compareTo(prod); + if (cmp == 0) return ZERO; + if (cmp > 0) return prod.UInt256Value(); + return reduce(prod); + } + // reduce-multiply-reduce + int shift = Long.numberOfLeadingZeros(u2); + Modulus192 m = shiftLeft(shift); + long inv = reciprocal(m.u2); + UInt256 x = (a.isUInt192()) ? a : m.reduceNormalised(a, shift, inv); + UInt256 y = (b.isUInt192()) ? b : m.reduceNormalised(b, shift, inv); + UInt512 prod = x.mul192(y); + int cmp = compareTo(prod); + if (cmp == 0) return ZERO; + if (cmp > 0) return prod.UInt256Value(); + return m.reduceNormalised(prod, shift, inv); + } + + private UInt192 reduceStep( + final long v3, final long v2, final long v1, final long v0, final long inv) { + long borrow, p0, p1, p2, res; + // Divide step -> get highest 2 limbs. + long z3 = v3; + long z2 = v2; + long z1 = v1; + long z0 = v0; + + if (z3 == u2) { + // Overflow case: div2by1 quotient would be <1, 0>, but adjusts to <0, -1> + // = -1 * u0 = + res = z0 + u0; + borrow = ((~z0 & ~u0) | ((~z0 | ~u0) & res)) >>> 63; + p1 = u0 - 1 + borrow; + z0 = res; + + res = z1 - p1; + borrow = ((~z1 & p1) | ((~z1 | p1) & res)) >>> 63; + p1 = u1 - 1 + borrow; + z1 = res + u1; + borrow = ((~res & ~u1) | ((~res | ~u1) & z1)) >>> 63; + + z2 = z2 - p1 + u2 - borrow; + z3 = 0; + // borrow = ((~z2 & p1) | ((~z2 | p1) & res)) >>> 63; + // p1 = u2 - 1 + borrow; + // borrow = ((~res & ~u1) | ((~res | ~u1) & z1)) >>> 63; + // assert p1 + borrow == z3 : "Division did not cancel top digit" + } else { + DivEstimate qr = div2by1(z3, z2, u2, inv); + z3 = 0; + z2 = qr.r; + + if (qr.q != 0) { + // Multiply-subtract: already have highest 2 limbs + // = * q + p0 = u0 * qr.q; + p1 = Math.unsignedMultiplyHigh(u0, qr.q); + res = z0 - p0; + p1 += ((~z0 & p0) | ((~z0 | p0) & res)) >>> 63; + z0 = res; + + p0 = u1 * qr.q; + p2 = Math.unsignedMultiplyHigh(u1, qr.q); + res = z1 - p0; + p2 += ((~z1 & p0) | ((~z1 | p0) & res)) >>> 63; + z1 = res - p1; + borrow = ((~res & p1) | ((~res | p1) & z1)) >>> 63; + + // Propagate overflows (borrows) + res = z2 - p2 - borrow; + borrow = ((~z2 & p2) | ((~z2 | p2) & res)) >>> 63; + z2 = res; + + if (borrow != 0) { // unlikely + // Add back + res = z0 + u0; + long carry = (Long.compareUnsigned(res, z0) < 0) ? 1 : 0; + z0 = res; + res = z1 + u1 + carry; + carry = (Long.compareUnsigned(res, z1) < 0 || (u1 == -1 && carry == 1)) ? 1 : 0; + z1 = res; + res = z2 + u2 + carry; + carry = (Long.compareUnsigned(res, z2) < 0 || (u2 == -1 && carry == 1)) ? 1 : 0; + z2 = res; + + if (carry == 0) { // unlikely: add back again + // Add back + res = z0 + u0; + carry = (Long.compareUnsigned(res, z0) < 0) ? 1 : 0; + z0 = res; + res = z1 + u1 + carry; + carry = (Long.compareUnsigned(res, z1) < 0 || (u1 == -1 && carry == 1)) ? 1 : 0; + z1 = res; + z2 = z2 + u2 + carry; + } + } + } } + return new UInt192(z2, z1, z0); } - return lhs; - } - - private static int[] knuthRemainder(final int[] dividend, final int[] modulus) { - int[] result = new int[N_LIMBS]; - int divLen = nSetLimbs(dividend); - int modLen = nSetLimbs(modulus); - int cmp = compareLimbs(dividend, divLen, modulus, modLen); - if (cmp < 0) { - System.arraycopy(dividend, 0, result, 0, divLen); - return result; - } else if (cmp == 0) { - return result; - } - - int shift = numberOfLeadingZeros(modulus, modLen); - int limbShift = shift / 32; - int n = modLen - limbShift; - if (n == 0) return result; - if (n == 1) { - if (divLen == 1) { - result[0] = Integer.remainderUnsigned(dividend[0], modulus[0]); - return result; + + private UInt256 reduceNormalised(final UInt256 that, final int shift, final long inv) { + UInt192 r; + UInt320 v = that.shiftLeftWide(shift); + if (v.u4 != 0 || Long.compareUnsigned(v.u3, u2) >= 0) { + r = reduceStep(v.u4, v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u3, v.u2, v.u1, v.u0, inv); } - long d = modulus[0] & MASK_L; - long rem = 0; - // Process from most significant limb downwards - for (int i = divLen - 1; i >= 0; i--) { - long cur = (rem << 32) | (dividend[i] & MASK_L); - rem = Long.remainderUnsigned(cur, d); + return new UInt256(0, r.u2, r.u1, r.u0).shiftRight(shift); + } + + private UInt256 reduceNormalised(final UInt257 that, final int shift, final long inv) { + UInt192 r; + UInt320 v = that.shiftLeftWide(shift); + if (v.u4 != 0 || Long.compareUnsigned(v.u3, u2) >= 0) { + r = reduceStep(v.u4, v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u3, v.u2, v.u1, v.u0, inv); } - result[0] = (int) rem; - result[1] = (int) (rem >>> 32); - return result; - } - // Normalize - int m = divLen - n; - int bitShift = shift % 32; - int[] vLimbs = new int[n]; - shiftLeftInto(vLimbs, modulus, modLen, shift); - int[] uLimbs = new int[divLen + 1]; - shiftLeftInto(uLimbs, dividend, divLen, bitShift); - - long[] vLimbsAsLong = new long[n]; - for (int i = 0; i < n; i++) { - vLimbsAsLong[i] = vLimbs[i] & MASK_L; - } - - // Main division loop - long vn1 = vLimbsAsLong[n - 1]; - long vn2 = vLimbsAsLong[n - 2]; - for (int j = m; j >= 0; j--) { - long ujn = (uLimbs[j + n] & MASK_L); - long ujn1 = (uLimbs[j + n - 1] & MASK_L); - long ujn2 = (uLimbs[j + n - 2] & MASK_L); - - long dividendPart = (ujn << 32) | ujn1; - // Check that no need for Unsigned version of divrem. - long qhat = Long.divideUnsigned(dividendPart, vn1); - long rhat = Long.remainderUnsigned(dividendPart, vn1); - - while (qhat == 0x1_0000_0000L || Long.compareUnsigned(qhat * vn2, (rhat << 32) | ujn2) > 0) { - qhat--; - rhat += vn1; - if (rhat >= 0x1_0000_0000L) break; + return new UInt256(0, r.u2, r.u1, r.u0).shiftRight(shift); + } + + private UInt256 reduceNormalised(final UInt512 that, final int shift, final long inv) { + UInt192 r; + UInt576 v = that.shiftLeftWide(shift); + if (v.u8 != 0 || Long.compareUnsigned(v.u7, u2) >= 0) { + r = reduceStep(v.u8, v.u7, v.u6, v.u5, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u4, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u3, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else if (v.u7 != 0 || Long.compareUnsigned(v.u6, u2) >= 0) { + r = reduceStep(v.u7, v.u6, v.u5, v.u4, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u3, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else if (v.u6 != 0 || Long.compareUnsigned(v.u5, u2) >= 0) { + r = reduceStep(v.u6, v.u5, v.u4, v.u3, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else if (v.u5 != 0 || Long.compareUnsigned(v.u4, u2) >= 0) { + r = reduceStep(v.u5, v.u4, v.u3, v.u2, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else if (v.u4 != 0 || Long.compareUnsigned(v.u3, u2) >= 0) { + r = reduceStep(v.u4, v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u3, v.u2, v.u1, v.u0, inv); } + return new UInt256(0, r.u2, r.u1, r.u0).shiftRight(shift); + } + } + + // -------------------------------------------------------------------------- + // endregion 192bits Modulus + + // region 256bits Modulus + // -------------------------------------------------------------------------- + record Modulus256(long u3, long u2, long u1, long u0) { + Modulus256 shiftLeft(final int shift) { + if (shift == 0) return this; + int invShift = N_BITS_PER_LIMB - shift; + long z0 = u0 << shift; + long z1 = (u1 << shift) | (u0 >>> invShift); + long z2 = (u2 << shift) | (u1 >>> invShift); + long z3 = (u3 << shift) | (u2 >>> invShift); + return new Modulus256(z3, z2, z1, z0); + } + + int compareTo(final UInt256 v) { + if (v.u3 != u3) return Long.compareUnsigned(u3, v.u3); + if (v.u2 != u2) return Long.compareUnsigned(u2, v.u2); + if (v.u1 != u1) return Long.compareUnsigned(u1, v.u1); + return Long.compareUnsigned(u0, v.u0); + } - // Multiply-subtract qhat*v from u slice - long borrow = 0; - for (int i = 0; i < n; i++) { - long prod = vLimbsAsLong[i] * qhat; - long sub = (uLimbs[i + j] & MASK_L) - (prod & MASK_L) - borrow; - uLimbs[i + j] = (int) sub; - borrow = (prod >>> 32) - (sub >> 32); + int compareTo(final UInt512 v) { + if ((v.u7 | v.u6 | v.u5 | v.u4) != 0) return -1; + if (v.u3 != u3) return Long.compareUnsigned(u3, v.u3); + if (v.u2 != u2) return Long.compareUnsigned(u2, v.u2); + if (v.u1 != u1) return Long.compareUnsigned(u1, v.u1); + return Long.compareUnsigned(u0, v.u0); + } + + UInt256 reduce(final UInt256 that) { + int cmp = compareTo(that); + if (cmp == 0) return ZERO; + if (cmp > 0) return that; + int shift = Long.numberOfLeadingZeros(u3); + Modulus256 m = shiftLeft(shift); + long inv = reciprocal(m.u3); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 reduce(final UInt512 that) { + int cmp = compareTo(that); + if (cmp == 0) return ZERO; + if (cmp > 0) return that.UInt256Value(); + int shift = Long.numberOfLeadingZeros(u3); + Modulus256 m = shiftLeft(shift); + long inv = reciprocal(m.u3); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 sum(final UInt256 a, final UInt256 b) { + UInt257 sum = a.adc(b); + if (!sum.carry()) { + int cmp = compareTo(sum.UInt256Value()); + if (cmp == 0) return ZERO; + if (cmp > 0) return sum.UInt256Value(); } - long sub = (uLimbs[j + n] & MASK_L) - borrow; - uLimbs[j + n] = (int) sub; - - if (sub < 0) { - // Add back - long carry = 0; - for (int i = 0; i < n; i++) { - long sum = (uLimbs[i + j] & MASK_L) + vLimbsAsLong[i] + carry; - uLimbs[i + j] = (int) sum; - carry = sum >>> 32; + int shift = Long.numberOfLeadingZeros(u3); + Modulus256 m = shiftLeft(shift); + long inv = reciprocal(m.u3); + return m.reduceNormalised(sum, shift, inv); + } + + UInt256 mul(final UInt256 a, final UInt256 b) { + // multiply-reduce + UInt512 prod = a.mul256(b); + int cmp = compareTo(prod); + if (cmp == 0) return ZERO; + if (cmp > 0) return prod.UInt256Value(); + return reduce(prod); + } + + private UInt256 reduceStep( + final long v4, final long v3, final long v2, final long v1, final long v0, final long inv) { + long borrow, p0, p1, p2, res; + long z4 = v4; + long z3 = v3; + long z2 = v2; + long z1 = v1; + long z0 = v0; + + if (z4 == u3) { + // Overflow case: div2by1 quotient would be <1, 0>, but adjusts to <0, -1> + // = -1 * u0 = + res = z0 + u0; + borrow = ((~z0 & ~u0) | ((~z0 | ~u0) & res)) >>> 63; + p1 = u0 - 1 + borrow; + z0 = res; + + res = z1 - p1; + borrow = ((~z1 & p1) | ((~z1 | p1) & res)) >>> 63; + p1 = u1 - 1 + borrow; + z1 = res + u1; + borrow = ((~res & ~u1) | ((~res | ~u1) & z1)) >>> 63; + + res = z2 - p1 - borrow; + borrow = ((~z2 & p1) | ((~z2 | p1) & res)) >>> 63; + p1 = u2 - 1 + borrow; + z2 = res + u2; + borrow = ((~res & ~u2) | ((~res | ~u2) & z2)) >>> 63; + + z3 = z3 - p1 + u3 - borrow; + } else { + DivEstimate qr = div2by1(z4, z3, u3, inv); + z3 = qr.r; + + // Multiply-subtract: already have highest 1 limbs + // = * q + p0 = u0 * qr.q; + p1 = Math.unsignedMultiplyHigh(u0, qr.q); + res = z0 - p0; + p1 += ((~z0 & p0) | ((~z0 | p0) & res)) >>> 63; + z0 = res; + + p0 = u1 * qr.q; + p2 = Math.unsignedMultiplyHigh(u1, qr.q); + res = z1 - p0; + p2 += ((~z1 & p0) | ((~z1 | p0) & res)) >>> 63; + z1 = res - p1; + borrow = ((~res & p1) | ((~res | p1) & z1)) >>> 63; + + p0 = u2 * qr.q; + p1 = Math.unsignedMultiplyHigh(u2, qr.q); + res = z2 - p0 - borrow; + p1 += ((~z2 & p0) | ((~z2 | p0) & res)) >>> 63; + z2 = res - p2; + borrow = ((~res & p2) | ((~res | p2) & z2)) >>> 63; + + // Propagate overflows (borrows) + res = z3 - p1 - borrow; + borrow = ((~z3 & p1) | ((~z3 | p1) & res)) >>> 63; + z3 = res; + + if (borrow != 0) { // unlikely + // Add back + res = z0 + u0; + long carry = (Long.compareUnsigned(res, z0) < 0) ? 1 : 0; + z0 = res; + res = z1 + u1 + carry; + carry = (Long.compareUnsigned(res, z1) < 0 || (u1 == -1 && carry == 1)) ? 1 : 0; + z1 = res; + res = z2 + u2 + carry; + carry = (Long.compareUnsigned(res, z2) < 0 || (u2 == -1 && carry == 1)) ? 1 : 0; + z2 = res; + res = z3 + u3 + carry; + carry = (Long.compareUnsigned(res, z3) < 0 || (u3 == -1 && carry == 1)) ? 1 : 0; + z3 = res; + + if (carry == 0) { // unlikely: add back again + // Add back + res = z0 + u0; + carry = (Long.compareUnsigned(res, z0) < 0) ? 1 : 0; + z0 = res; + res = z1 + u1 + carry; + carry = (Long.compareUnsigned(res, z1) < 0 || (u1 == -1 && carry == 1)) ? 1 : 0; + z1 = res; + res = z2 + u2 + carry; + carry = (Long.compareUnsigned(res, z2) < 0 || (u2 == -1 && carry == 1)) ? 1 : 0; + z2 = res; + z3 = z3 + u3 + carry; + } } - uLimbs[j + n] = (int) (uLimbs[j + n] + carry); } + return new UInt256(z3, z2, z1, z0); + } + + private UInt256 reduceNormalised(final UInt256 that, final int shift, final long inv) { + UInt320 v = that.shiftLeftWide(shift); + return reduceStep(v.u4, v.u3, v.u2, v.u1, v.u0, inv).shiftRight(shift); + } + + private UInt256 reduceNormalised(final UInt257 that, final int shift, final long inv) { + UInt320 v = that.shiftLeftWide(shift); + return reduceStep(v.u4, v.u3, v.u2, v.u1, v.u0, inv).shiftRight(shift); + } + + private UInt256 reduceNormalised(final UInt512 that, final int shift, final long inv) { + UInt256 r; + UInt576 v = that.shiftLeftWide(shift); + if (v.u8 != 0 || Long.compareUnsigned(v.u7, u3) >= 0) { + r = reduceStep(v.u8, v.u7, v.u6, v.u5, v.u4, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u3, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u0, inv); + } else if (v.u7 != 0 || Long.compareUnsigned(v.u6, u3) >= 0) { + r = reduceStep(v.u7, v.u6, v.u5, v.u4, v.u3, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u0, inv); + } else if (v.u6 != 0 || Long.compareUnsigned(v.u5, u3) >= 0) { + r = reduceStep(v.u6, v.u5, v.u4, v.u3, v.u2, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u0, inv); + } else if (v.u5 != 0 || Long.compareUnsigned(v.u4, u3) >= 0) { + r = reduceStep(v.u5, v.u4, v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u4, v.u3, v.u2, v.u1, v.u0, inv); + } + return r.shiftRight(shift); } - // Unnormalize remainder - shiftRightInto(result, uLimbs, n, bitShift); - return result; } // -------------------------------------------------------------------------- - // endregion + // endregion 256bits Modulus } diff --git a/evm/src/test/java/org/hyperledger/besu/evm/UInt256Prop.java b/evm/src/test/java/org/hyperledger/besu/evm/UInt256Prop.java new file mode 100644 index 00000000000..3cfc4a7ad13 --- /dev/null +++ b/evm/src/test/java/org/hyperledger/besu/evm/UInt256Prop.java @@ -0,0 +1,273 @@ +/* + * Copyright contributors to Besu. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package org.hyperledger.besu.evm; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.math.BigInteger; +import java.util.Arrays; + +import net.jqwik.api.Arbitraries; +import net.jqwik.api.Arbitrary; +import net.jqwik.api.ForAll; +import net.jqwik.api.Property; +import net.jqwik.api.Provide; +import org.apache.tuweni.bytes.Bytes; +import org.apache.tuweni.bytes.Bytes32; + +public class UInt256Prop { + @Provide + Arbitrary unsigned1to32() { + return Arbitraries.bytes() + .array(byte[].class) + .ofMinSize(1) + .ofMaxSize(32) + .map(UInt256Prop::clampUnsigned32); + } + + @Provide + Arbitrary unsigned0to64() { + return Arbitraries.bytes() + .array(byte[].class) + .ofMinSize(0) + .ofMaxSize(64) + .map(UInt256Prop::clampUnsigned32); + } + + @Provide + Arbitrary singleLimbUnsigned1to4() { + return Arbitraries.bytes() + .array(byte[].class) + .ofMinSize(1) + .ofMaxSize(4) + .map(UInt256Prop::clampUnsigned32); + } + + @Provide + Arbitrary shifts() { + return Arbitraries.integers().between(-512, 512); + } + + @Property + void property_roundTripUnsigned_toFromBytesBE(@ForAll("unsigned0to64") final byte[] any) { + // Arrange + final byte[] be = clampUnsigned32(any); + + // Act + final UInt256 u = UInt256.fromBytesBE(be); + final byte[] back = u.toBytesBE(); + + // Assert + assertThat(back).hasSize(32); + byte[] expected = bigUnsignedToBytes32(toBigUnsigned(be)); + assertThat(back).containsExactly(expected); + } + + @Property + void property_equals_compare_consistent( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] b) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + + // Act + final int cmp = UInt256.compare(ua, ub); + final boolean eq = ua.equals(ub); + + // Assert + assertThat(cmp == 0).isEqualTo(eq); + + BigInteger ba = toBigUnsigned(a); + BigInteger bb = toBigUnsigned(b); + int bc = ba.compareTo(bb); + assertThat(Integer.signum(cmp)).isEqualTo(Integer.signum(bc)); + } + + @Property + void property_mod_matchesBigInteger( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] m) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 um = UInt256.fromBytesBE(m); + + // Act + final byte[] got = ua.mod(um).toBytesBE(); + + // Assert + BigInteger A = toBigUnsigned(a); + BigInteger M = toBigUnsigned(m); + byte[] exp = (M.signum() == 0) ? Bytes32.ZERO.toArrayUnsafe() : bigUnsignedToBytes32(A.mod(M)); + assertThat(got).containsExactly(exp); + } + + @Property + void property_mod_singleLimb_matchesBigInteger( + @ForAll("singleLimbUnsigned1to4") final byte[] a, + @ForAll("singleLimbUnsigned1to4") final byte[] m) { + + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 um = UInt256.fromBytesBE(m); + + // Act + final byte[] got = ua.mod(um).toBytesBE(); + + // Assert + BigInteger A = toBigUnsigned(a); + BigInteger M = toBigUnsigned(m); + byte[] exp = (M.signum() == 0) ? Bytes32.ZERO.toArrayUnsafe() : bigUnsignedToBytes32(A.mod(M)); + assertThat(got).containsExactly(exp); + } + + @Property + void property_signedMod_matchesEvmSemantics( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] m) { + + // Arrange + final byte[] a32 = Bytes32.leftPad(Bytes.wrap(a)).toArrayUnsafe(); + final byte[] m32 = Bytes32.leftPad(Bytes.wrap(m)).toArrayUnsafe(); + final BigInteger A = new BigInteger(a32); + final BigInteger M = new BigInteger(m32); + final UInt256 ua = UInt256.fromBytesBE(a32); + final UInt256 um = UInt256.fromBytesBE(m32); + + // Act + byte[] got = ua.signedMod(um).toBytesBE(); + + // Assert + byte[] expected = + (M.signum() == 0) ? Bytes32.ZERO.toArrayUnsafe() : computeSignedModExpected(A, M); + + assertThat(got).containsExactly(expected); + } + + @Property + void property_addMod_matchesBigInteger( + @ForAll("unsigned1to32") final byte[] a, + @ForAll("unsigned1to32") final byte[] b, + @ForAll("unsigned1to32") final byte[] m) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + final UInt256 um = UInt256.fromBytesBE(m); + + // Act + byte[] got = ua.addMod(ub, um).toBytesBE(); + + // Assert + BigInteger A = toBigUnsigned(a); + BigInteger B = toBigUnsigned(b); + BigInteger M = toBigUnsigned(m); + byte[] exp = + (M.signum() == 0) ? Bytes32.ZERO.toArrayUnsafe() : bigUnsignedToBytes32(A.add(B).mod(M)); + assertThat(got).containsExactly(exp); + } + + @Property + void property_mulMod_matchesBigInteger( + @ForAll("unsigned1to32") final byte[] a, + @ForAll("unsigned1to32") final byte[] b, + @ForAll("unsigned1to32") final byte[] m) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + final UInt256 um = UInt256.fromBytesBE(m); + + // Act + byte[] got = ua.mulMod(ub, um).toBytesBE(); + + // Assert + BigInteger A = toBigUnsigned(a); + BigInteger B = toBigUnsigned(b); + BigInteger M = toBigUnsigned(m); + byte[] exp = + (M.signum() == 0) + ? Bytes32.ZERO.toArrayUnsafe() + : bigUnsignedToBytes32(A.multiply(B).mod(M)); + assertThat(got).containsExactly(exp); + } + + @Property + void property_divByZero_invariants() { + // Arrange + UInt256 x = UInt256.fromBytesBE(new byte[] {1, 2, 3, 4}); + UInt256 zero = UInt256.ZERO; + + // Act & Assert + assertThat(x.mod(zero).toBytesBE()).containsExactly(Bytes32.ZERO.toArrayUnsafe()); + assertThat(x.signedMod(zero).toBytesBE()).containsExactly(Bytes32.ZERO.toArrayUnsafe()); + assertThat(x.addMod(x, zero).toBytesBE()).containsExactly(Bytes32.ZERO.toArrayUnsafe()); + assertThat(x.mulMod(x, zero).toBytesBE()).containsExactly(Bytes32.ZERO.toArrayUnsafe()); + } + + private static byte[] clampUnsigned32(final byte[] any) { + if (any.length == 0) { + return new byte[] {0}; + } + int len = Math.max(1, Math.min(32, any.length)); + byte[] out = new byte[len]; + System.arraycopy(any, 0, out, 0, len); + return out; + } + + private static byte[] bigUnsignedToBytes32(final BigInteger x) { + BigInteger y = x.mod(BigInteger.ONE.shiftLeft(256)); + + byte[] ba = y.toByteArray(); + if (ba.length == 0) { + return new byte[32]; + } + + if (ba.length == 32) { + return ba; + } + + if (ba.length < 32) { + byte[] out = new byte[32]; + System.arraycopy(ba, 0, out, 32 - ba.length, ba.length); + return out; + } + + // If bigger than 32, take lower 32 bytes. + byte[] out = new byte[32]; + System.arraycopy(ba, ba.length - 32, out, 0, 32); + + return out; + } + + private static BigInteger toBigUnsigned(final byte[] be) { + return new BigInteger(1, be); + } + + private static byte[] computeSignedModExpected(final BigInteger A, final BigInteger M) { + + BigInteger r = A.abs().mod(M.abs()); + + if (A.signum() < 0 && r.signum() != 0) { + return padNegative(r); + } + + return bigUnsignedToBytes32(r); + } + + private static byte[] padNegative(final BigInteger r) { + BigInteger neg = r.negate(); + byte[] rb = neg.toByteArray(); + byte[] padded = new byte[32]; + Arrays.fill(padded, (byte) 0xFF); + System.arraycopy(rb, 0, padded, 32 - rb.length, rb.length); + return padded; + } +} diff --git a/evm/src/test/java/org/hyperledger/besu/evm/UInt256PropertyBasedTest.java b/evm/src/test/java/org/hyperledger/besu/evm/UInt256PropertyBasedTest.java index 97744492bbe..3cf8d04cc20 100644 --- a/evm/src/test/java/org/hyperledger/besu/evm/UInt256PropertyBasedTest.java +++ b/evm/src/test/java/org/hyperledger/besu/evm/UInt256PropertyBasedTest.java @@ -123,6 +123,10 @@ void property_mod_matchesBigInteger( BigInteger A = toBigUnsigned(a); BigInteger M = toBigUnsigned(m); byte[] exp = (M.signum() == 0) ? Bytes32.ZERO.toArrayUnsafe() : bigUnsignedToBytes32(A.mod(M)); + if (!Arrays.equals(got, exp)) + System.out.println( + String.format( + "%s %% %s == %s", ua.toHexString(), um.toHexString(), ua.mod(um).toHexString())); assertThat(got).containsExactly(exp); } @@ -511,7 +515,7 @@ void property_xor_involutive( void property_xor_with_allOnes_is_complement(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act final UInt256 result = ua.xor(allOnes); @@ -574,7 +578,7 @@ void property_xor_specific_patterns() { (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55 }); // 01010101... - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - 0xAA XOR 0x55 = 0xFF assertThat(pattern1.xor(pattern2)).isEqualTo(allOnes); @@ -687,7 +691,7 @@ void property_or_idempotent(@ForAll("unsigned1to32") final byte[] a) { void property_or_with_allOnes(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - A | 0xFF...FF = 0xFF...FF (domination) assertThat(ua.or(allOnes)).isEqualTo(allOnes); assertThat(allOnes.or(ua)).isEqualTo(allOnes); @@ -742,7 +746,7 @@ void property_or_with_complement_is_allOnes(@ForAll("unsigned1to32") final byte[ complementBytes[i] = (byte) ~aBytes32[i]; } final UInt256 complement = UInt256.fromBytesBE(complementBytes); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - A | ~A = 0xFF...FF assertThat(ua.or(complement)).isEqualTo(allOnes); } @@ -774,7 +778,7 @@ void property_or_specific_patterns() { (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55 }); // 01010101... - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - 0xAA OR 0x55 = 0xFF assertThat(pattern1.or(pattern2)).isEqualTo(allOnes); // Verify with Bytes implementation @@ -896,7 +900,7 @@ void property_not_involutive(@ForAll("unsigned1to32") final byte[] a) { void property_not_different_from_original(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act final UInt256 notA = ua.not(); @@ -922,7 +926,7 @@ void property_not_with_or_is_allOnes(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); // Act & Assert - A | ~A = 0xFF...FF - assertThat(ua.or(ua.not())).isEqualTo(UInt256.ALL_ONES); + assertThat(ua.or(ua.not())).isEqualTo(UInt256.MAX); } @Property @@ -930,7 +934,7 @@ void property_not_with_xor_is_allOnes(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); // Act & Assert - A ^ ~A = 0xFF...FF - assertThat(ua.xor(ua.not())).isEqualTo(UInt256.ALL_ONES); + assertThat(ua.xor(ua.not())).isEqualTo(UInt256.MAX); } @Property @@ -963,7 +967,7 @@ void property_not_de_morgans_or( void property_not_zero_is_allOnes() { // Arrange final UInt256 zero = UInt256.ZERO; - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - ~0 = 0xFF...FF assertThat(zero.not()).isEqualTo(allOnes); @@ -972,7 +976,7 @@ void property_not_zero_is_allOnes() { @Property void property_not_allOnes_is_zero() { // Arrange - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; final UInt256 zero = UInt256.ZERO; // Act & Assert - ~0xFF...FF = 0 @@ -1036,7 +1040,7 @@ void property_not_each_bit_flipped(@ForAll("unsigned1to32") final byte[] a) { void property_not_xor_equivalence(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - ~A = A ^ 0xFF...FF assertThat(ua.not()).isEqualTo(ua.xor(allOnes)); @@ -1046,7 +1050,7 @@ void property_not_xor_equivalence(@ForAll("unsigned1to32") final byte[] a) { void property_not_sum_with_original(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act // When we add A + ~A (bitwise), we should get all 1s in each bit position diff --git a/evm/src/test/java/org/hyperledger/besu/evm/UInt256Test.java b/evm/src/test/java/org/hyperledger/besu/evm/UInt256Test.java index 080b684ac34..16320fb54c4 100644 --- a/evm/src/test/java/org/hyperledger/besu/evm/UInt256Test.java +++ b/evm/src/test/java/org/hyperledger/besu/evm/UInt256Test.java @@ -15,7 +15,6 @@ package org.hyperledger.besu.evm; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; import java.math.BigInteger; import java.util.Arrays; @@ -24,24 +23,23 @@ import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; public class UInt256Test { - static final int SAMPLE_SIZE = 300; + static final int SAMPLE_SIZE = 3; - private Bytes32 bigIntTo32B(final BigInteger x) { - byte[] a = x.toByteArray(); + private Bytes32 bigIntTo32B(final BigInteger y) { + byte[] a = y.toByteArray(); if (a.length > 32) return Bytes32.wrap(a, a.length - 32); return Bytes32.leftPad(Bytes.wrap(a)); } - private Bytes32 bigIntToSigned32B(final BigInteger x) { - if (x.signum() >= 0) return bigIntTo32B(x); + private Bytes32 bigIntTo32B(final BigInteger x, final int sign) { + if (sign >= 0) return bigIntTo32B(x); byte[] a = new byte[32]; Arrays.fill(a, (byte) 0xFF); byte[] b = x.toByteArray(); System.arraycopy(b, 0, a, 32 - b.length, b.length); + if (a.length > 32) return Bytes32.wrap(a, a.length - 32); return Bytes32.leftPad(Bytes.wrap(a)); } @@ -63,22 +61,22 @@ public void fromInts() { public void fromBytesBE() { byte[] input; UInt256 result; - int[] expectedLimbs; + UInt256 expected; input = new byte[] {-128, 0, 0, 0}; result = UInt256.fromBytesBE(input); - expectedLimbs = new int[] {-2147483648, 0, 0, 0, 0, 0, 0, 0}; - assertThat(result.limbs()).as("4b-neg-limbs").isEqualTo(expectedLimbs); + expected = new UInt256(0, 0, 0, 2147483648L); + assertThat(result).as("4b-neg-limbs").isEqualTo(expected); input = new byte[] {0, 0, 1, 1, 1}; result = UInt256.fromBytesBE(input); - expectedLimbs = new int[] {1 + 256 + 65536, 0, 0, 0, 0, 0, 0, 0}; - assertThat(result.limbs()).as("3b-limbs").isEqualTo(expectedLimbs); + expected = new UInt256(0, 0, 0, 1 + 256 + 65536); + assertThat(result).as("3b-limbs").isEqualTo(expected); - input = new byte[] {1, 0, 0, 0, 0, 1, 1, 1}; + input = new byte[] {1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1}; result = UInt256.fromBytesBE(input); - expectedLimbs = new int[] {1 + 256 + 65536, 16777216, 0, 0, 0, 0, 0, 0}; - assertThat(result.limbs()).as("8b-limbs").isEqualTo(expectedLimbs); + expected = new UInt256(0, 0, 16777216, 1 + 256 + 65536); + assertThat(result).as("8b-limbs").isEqualTo(expected); input = new byte[] { @@ -86,8 +84,8 @@ public void fromBytesBE() { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; result = UInt256.fromBytesBE(input); - expectedLimbs = new int[] {0, 0, 0, 0, 0, 0, 0, 16777216}; - assertThat(result.limbs()).as("32b-limbs").isEqualTo(expectedLimbs); + expected = new UInt256(72057594037927936L, 0, 0, 0); + assertThat(result).as("32b-limbs").isEqualTo(expected); input = new byte[] { @@ -95,8 +93,15 @@ public void fromBytesBE() { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; result = UInt256.fromBytesBE(input); - expectedLimbs = new int[] {0, 0, 0, 0, 0, 0, 257, 0}; - assertThat(result.limbs()).as("32b-padded-limbs").isEqualTo(expectedLimbs); + expected = new UInt256(257, 0, 0, 0); + assertThat(result).as("32b-padded-limbs").isEqualTo(expected); + + Bytes inputBytes = + Bytes.fromHexString("0x000000000000000000000000ffffffffffffffffffffffffffffffffffffffff"); + input = inputBytes.toArrayUnsafe(); + result = UInt256.fromBytesBE(input); + expected = new UInt256(0, 4294967295L, -1L, -1L); + assertThat(result).as("32b-case2-limbs").isEqualTo(expected); } @Test @@ -208,6 +213,112 @@ public void modB() { assertThat(remainder).isEqualTo(expected); } + @Test + public void modC() { + BigInteger big_number = new BigInteger("1000000000000000000000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("ff00000000000000", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modD() { + BigInteger big_number = new BigInteger("ff00000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("100000000000000000000000000000000", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modE() { + BigInteger big_number = new BigInteger("ff00000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("100000000000000000000000000000001", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modF() { + BigInteger big_number = new BigInteger("1000000000000000000000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("ff000000000000000000000000000000", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modG() { + BigInteger big_number = new BigInteger("1000000000000000000000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("100000000000000000000000000000001", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modH() { + BigInteger big_number = + new BigInteger("000000000000000000ff00000000000000000000000000000000000000000000", 16); + BigInteger big_modulus = + new BigInteger("0000000000000000000000000000000000fe0000000000000000000000000001", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modI() { + // modulus 128 with overflow case + BigInteger big_number = new BigInteger("020000000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("02000000000000000000", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modJ() { + // modulus 128 with overflow case -> 2 add back in quotient estimate div2by1. + BigInteger big_number = new BigInteger("10000000000000000010000000000000000", 16); + BigInteger big_modulus = new BigInteger("200000000000000ff", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modK() { + // modulus 128 with overflow case -> 2 add back in quotient estimate div2by1. + BigInteger big_number = + new BigInteger("ff000000000000000000000000000000000000000000000000000000", 16); + BigInteger big_modulus = + new BigInteger("1000000000000000000000002000000000000000000000000", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + @Test public void modGeneralState() { BigInteger big_number = new BigInteger("cea0c5cc171fa61277e5604a3bc8aef4de3d3882", 16); @@ -268,6 +379,30 @@ public void referenceTest459() { assertThat(remainder).isEqualTo(expected); } + @Test + public void ExecutionSpecStateTest_453() { + byte[] xArr = + new byte[] { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -2 + }; + byte[] mArr = + new byte[] { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 + }; + BigInteger xbig = new BigInteger(1, xArr); + BigInteger ybig = new BigInteger(1, xArr); + BigInteger mbig = new BigInteger(1, mArr); + UInt256 x = UInt256.fromBytesBE(xArr); + UInt256 y = UInt256.fromBytesBE(xArr); + UInt256 m = UInt256.fromBytesBE(mArr); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(x.addMod(y, m).toBytesBE())); + Bytes32 expected = + BigInteger.ZERO.compareTo(mbig) == 0 ? Bytes32.ZERO : bigIntTo32B(xbig.add(ybig).mod(mbig)); + assertThat(remainder).isEqualTo(expected); + } + @Test public void addMod() { final Random random = new Random(42); @@ -292,10 +427,47 @@ public void addMod() { BigInteger.ZERO.compareTo(cInt) == 0 ? Bytes32.ZERO : bigIntTo32B(aInt.add(bInt).mod(cInt)); + if (!remainder.equals(expected)) + System.out.println(String.format("%s + %s == %s (mod %s)", a, b, a.add(b), c)); assertThat(remainder).isEqualTo(expected); } } + @Test + public void mulMod_ExecutionSpecStateTest_457() { + Bytes value0 = + Bytes.fromHexString("0x000000000000000000000000ffffffffffffffffffffffffffffffffffffffff"); + Bytes value1 = + Bytes.fromHexString("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); + Bytes value2 = + Bytes.fromHexString("0x000000000000000000000000ffffffffffffffffffffffffffffffffffffffff"); + BigInteger aInt = new BigInteger(1, value0.toArrayUnsafe()); + BigInteger bInt = new BigInteger(1, value1.toArrayUnsafe()); + BigInteger cInt = new BigInteger(1, value2.toArrayUnsafe()); + UInt256 a = UInt256.fromBytesBE(value0.toArrayUnsafe()); + UInt256 b = UInt256.fromBytesBE(value1.toArrayUnsafe()); + UInt256 c = UInt256.fromBytesBE(value2.toArrayUnsafe()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(a.mulMod(b, c).toBytesBE())); + Bytes32 expected = bigIntTo32B(aInt.multiply(bInt).mod(cInt)); + assertThat(remainder).isEqualTo(expected); + + value0 = + Bytes.fromHexString("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); + value1 = + Bytes.fromHexString("0xffffffffffffffffffffffffb195148ca348dc57a7331852b390ccefa7b0c18b"); + value2 = + Bytes.fromHexString("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); + aInt = new BigInteger(1, value0.toArrayUnsafe()); + bInt = new BigInteger(1, value1.toArrayUnsafe()); + cInt = new BigInteger(1, value2.toArrayUnsafe()); + a = UInt256.fromBytesBE(value0.toArrayUnsafe()); + b = UInt256.fromBytesBE(value1.toArrayUnsafe()); + c = UInt256.fromBytesBE(value2.toArrayUnsafe()); + remainder = Bytes32.leftPad(Bytes.wrap(a.mulMod(b, c).toBytesBE())); + expected = bigIntTo32B(aInt.multiply(bInt).mod(cInt)); + assertThat(remainder).isEqualTo(expected); + } + @Test public void mulMod() { final Random random = new Random(123); @@ -320,40 +492,34 @@ public void mulMod() { BigInteger.ZERO.compareTo(cInt) == 0 ? Bytes32.ZERO : bigIntTo32B(aInt.multiply(bInt).mod(cInt)); + if (!remainder.equals(expected)) + System.out.println( + String.format("%s * %s (mod %s)", a.toHexString(), b.toHexString(), c.toHexString())); assertThat(remainder).isEqualTo(expected); } } - @Test - public void signedMod_no_padding() { - Bytes aBytes = - Bytes.fromHexString("0xe8e8e8e2000100000009ea02000000000000ff3ffffff80000001000220000"); - Bytes bBytes = - Bytes.fromHexString("0x8000000000000000000000000000000000000000000000000000000000000000"); - Bytes32 expected = - Bytes32.leftPad( - Bytes.fromHexString( - "0x00e8e8e8e2000100000009ea02000000000000ff3ffffff80000001000220000")); - UInt256 a = UInt256.fromBytesBE(aBytes.toArrayUnsafe()); - UInt256 b = UInt256.fromBytesBE(bBytes.toArrayUnsafe()); - Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(a.signedMod(b).toBytesBE())); - assertThat(remainder).isEqualTo(expected); - } - @Test public void signedMod() { final Random random = new Random(432); for (int i = 0; i < SAMPLE_SIZE; i++) { int aSize = random.nextInt(1, 33); int bSize = random.nextInt(1, 33); + boolean neg = random.nextBoolean(); byte[] aArray = new byte[aSize]; byte[] bArray = new byte[bSize]; random.nextBytes(aArray); random.nextBytes(bArray); + if ((aSize < 32) && (neg)) { + byte[] tmp = new byte[32]; + Arrays.fill(tmp, (byte) 0xFF); + System.arraycopy(aArray, 0, tmp, 32 - aArray.length, aArray.length); + aArray = tmp; + } UInt256 a = UInt256.fromBytesBE(aArray); UInt256 b = UInt256.fromBytesBE(bArray); - BigInteger aInt = aArray.length < 32 ? new BigInteger(1, aArray) : new BigInteger(aArray); - BigInteger bInt = bArray.length < 32 ? new BigInteger(1, bArray) : new BigInteger(bArray); + BigInteger aInt = a.isNegative() ? new BigInteger(aArray) : new BigInteger(1, aArray); + BigInteger bInt = b.isNegative() ? new BigInteger(bArray) : new BigInteger(1, bArray); Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(a.signedMod(b).toBytesBE())); Bytes32 expected; BigInteger rem = BigInteger.ZERO; @@ -362,285 +528,12 @@ public void signedMod() { rem = aInt.abs().mod(bInt.abs()); if ((aInt.compareTo(BigInteger.ZERO) < 0) && (rem.compareTo(BigInteger.ZERO) != 0)) { rem = rem.negate(); - expected = bigIntToSigned32B(rem); + expected = bigIntTo32B(rem, -1); } else { - expected = bigIntTo32B(rem); + expected = bigIntTo32B(rem, 1); } } assertThat(remainder).isEqualTo(expected); } } - - @Test - void testFromBytesBE_emptyArray() { - UInt256 result = UInt256.fromBytesBE(new byte[0]); - assertThat(result).isEqualTo(UInt256.ZERO); - assertThat(result.isZero()).isTrue(); - } - - @Test - void testFromBytesBE_singleZeroByte() { - UInt256 result = UInt256.fromBytesBE(new byte[] {0}); - assertThat(result).isEqualTo(UInt256.ZERO); - assertThat(result.intValue()).isEqualTo(0); - } - - @Test - void testFromBytesBE_singleByte() { - UInt256 result = UInt256.fromBytesBE(new byte[] {0x42}); - assertThat(result.intValue()).isEqualTo(0x42); - assertThat(result.longValue()).isEqualTo(0x42L); - } - - @Test - void testFromBytesBE_twoBytesFF() { - UInt256 result = UInt256.fromBytesBE(new byte[] {(byte) 0xFF, (byte) 0xFF}); - assertThat(result.intValue()).isEqualTo(0xFFFF); - assertThat(result.longValue()).isEqualTo(0xFFFFL); - } - - @Test - void testFromBytesBE_fourBytes() { - UInt256 result = UInt256.fromBytesBE(new byte[] {0x01, 0x02, 0x03, 0x04}); - assertThat(result.intValue()).isEqualTo(0x01020304); - } - - @Test - void testFromBytesBE_eightBytes() { - UInt256 result = - UInt256.fromBytesBE(new byte[] {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}); - assertThat(result.longValue()).isEqualTo(0x0102030405060708L); - } - - @Test - void testFromBytesBE_exactly32Bytes_allZeros() { - byte[] bytes = new byte[32]; // all zeros - UInt256 result = UInt256.fromBytesBE(bytes); - assertThat(result).isEqualTo(UInt256.ZERO); - assertThat(result.isZero()).isTrue(); - } - - @Test - void testFromBytesBE_exactly32Bytes_allOnes() { - byte[] bytes = new byte[32]; - for (int i = 0; i < 32; i++) { - bytes[i] = (byte) 0xFF; - } - UInt256 result = UInt256.fromBytesBE(bytes); - - // Should be MAX_UINT256 (2^256 - 1) - byte[] resultBytes = result.toBytesBE(); - assertArrayEquals(bytes, resultBytes); - } - - @Test - void testFromBytesBE_exactly32Bytes_one() { - byte[] bytes = new byte[32]; - bytes[31] = 0x01; // least significant byte - UInt256 result = UInt256.fromBytesBE(bytes); - - assertThat(result.intValue()).isEqualTo(1); - assertThat(result.longValue()).isEqualTo(1L); - } - - @Test - void testFromBytesBE_exactly32Bytes_pattern() { - byte[] bytes = new byte[32]; - // Create pattern: 0x0102030405060708...1F20 - for (int i = 0; i < 32; i++) { - bytes[i] = (byte) (i + 1); - } - UInt256 result = UInt256.fromBytesBE(bytes); - - // Verify round-trip - byte[] resultBytes = result.toBytesBE(); - assertArrayEquals(bytes, resultBytes); - } - - @Test - void testFromBytesBE_exactly32Bytes_highBitSet() { - byte[] bytes = new byte[32]; - bytes[0] = (byte) 0x80; // high bit set (but still unsigned) - UInt256 result = UInt256.fromBytesBE(bytes); - - // Verify it's treated as unsigned (not negative) - byte[] resultBytes = result.toBytesBE(); - assertArrayEquals(bytes, resultBytes); - } - - @Test - void testFromBytesBE_roundTrip_variousLengths() { - for (int len = 1; len <= 32; len++) { - byte[] original = new byte[len]; - for (int i = 0; i < len; i++) { - original[i] = (byte) (i + 1); - } - - UInt256 value = UInt256.fromBytesBE(original); - byte[] result = value.toBytesBE(); - - // Result is always 32 bytes, so compare with left-padded original - byte[] expected = new byte[32]; - System.arraycopy(original, 0, expected, 32 - len, len); - - assertArrayEquals(expected, result, "Failed for length " + len); - } - } - - @Test - void testFromBytesBE_leadingZeros() { - // Leading zeros should be handled correctly - byte[] bytes = new byte[] {0x00, 0x00, 0x00, 0x01, 0x02, 0x03}; - UInt256 result = UInt256.fromBytesBE(bytes); - - assertThat(result.intValue()).isEqualTo(0x010203); - } - - @Test - void testFromBytesBE_maxInt() { - byte[] bytes = new byte[] {0x7F, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF}; - UInt256 result = UInt256.fromBytesBE(bytes); - - assertThat(result.intValue()).isEqualTo(Integer.MAX_VALUE); - } - - @Test - void testFromBytesBE_maxLong() { - byte[] bytes = - new byte[] { - 0x7F, - (byte) 0xFF, - (byte) 0xFF, - (byte) 0xFF, - (byte) 0xFF, - (byte) 0xFF, - (byte) 0xFF, - (byte) 0xFF - }; - UInt256 result = UInt256.fromBytesBE(bytes); - - assertThat(result.longValue()).isEqualTo(Long.MAX_VALUE); - } - - @Test - void testFromBytesBE_unsignedIntMax() { - // 0xFFFFFFFF as unsigned = 4294967295 - byte[] bytes = new byte[] {(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF}; - UInt256 result = UInt256.fromBytesBE(bytes); - - assertThat(result.longValue()).isEqualTo(0xFFFFFFFFL); - } - - @Test - void testFromBytesBE_unsignedLongMax() { - // 0xFFFFFFFFFFFFFFFF as unsigned - byte[] bytes = - new byte[] { - (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, - (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF - }; - UInt256 result = UInt256.fromBytesBE(bytes); - - // When converted back to long, should get the bit pattern - assertThat(result.longValue()).isEqualTo(-1L); // all bits set - } - - @Test - void testFromBytesBE_boundaryValues() { - // Test 1, 2, 3, 4, 8, 16, 32 bytes - int[] lengths = {1, 2, 3, 4, 8, 16, 32}; - - for (int len : lengths) { - byte[] bytes = new byte[len]; - bytes[len - 1] = (byte) 0xFF; // set last byte - - UInt256 result = UInt256.fromBytesBE(bytes); - assertThat(result.intValue() & 0xFF).isEqualTo(0xFF); - } - } - - @Test - void testFromBytesBE_comparisonWithBigInteger() { - byte[] bytes = - new byte[] {0x12, 0x34, 0x56, 0x78, (byte) 0x9A, (byte) 0xBC, (byte) 0xDE, (byte) 0xF0}; - - UInt256 result = UInt256.fromBytesBE(bytes); - java.math.BigInteger expected = new java.math.BigInteger(1, bytes); - - assertThat(result.toBigInteger()).isEqualTo(expected); - } - - @ParameterizedTest - @ValueSource(ints = {0, 1, 127, 128, 255, 256, 65535, 65536, Integer.MAX_VALUE}) - void testFromBytesBE_knownIntegers(final int value) { - // Convert int to bytes (big-endian) - byte[] bytes = new byte[4]; - bytes[0] = (byte) (value >>> 24); - bytes[1] = (byte) (value >>> 16); - bytes[2] = (byte) (value >>> 8); - bytes[3] = (byte) value; - - UInt256 result = UInt256.fromBytesBE(bytes); - assertThat(result.intValue()).isEqualTo(value); - } - - @Test - void testFromBytesBE_powerOfTwo() { - // Test 2^8, 2^16, 2^32, 2^64, 2^128, 2^255 - - // 2^8 = 256 - byte[] bytes8 = new byte[] {0x01, 0x00}; - assertThat(UInt256.fromBytesBE(bytes8).intValue()).isEqualTo(256); - - // 2^16 = 65536 - byte[] bytes16 = new byte[] {0x01, 0x00, 0x00}; - assertThat(UInt256.fromBytesBE(bytes16).intValue()).isEqualTo(65536); - - // 2^32 - byte[] bytes32 = new byte[] {0x01, 0x00, 0x00, 0x00, 0x00}; - assertThat(UInt256.fromBytesBE(bytes32).longValue()).isEqualTo(0x100000000L); - } - - @Test - void testFromBytesBE_alternatingPattern() { - // 0xAA pattern - byte[] bytesAA = new byte[32]; - for (int i = 0; i < 32; i++) { - bytesAA[i] = (byte) 0xAA; - } - UInt256 resultAA = UInt256.fromBytesBE(bytesAA); - assertArrayEquals(bytesAA, resultAA.toBytesBE()); - - // 0x55 pattern - byte[] bytes55 = new byte[32]; - for (int i = 0; i < 32; i++) { - bytes55[i] = (byte) 0x55; - } - UInt256 result55 = UInt256.fromBytesBE(bytes55); - assertArrayEquals(bytes55, result55.toBytesBE()); - } - - @Test - void testFromBytesBE_consistency() { - // Verify same bytes always produce same result - byte[] bytes = new byte[] {0x01, 0x02, 0x03, 0x04, 0x05}; - - UInt256 result1 = UInt256.fromBytesBE(bytes); - UInt256 result2 = UInt256.fromBytesBE(bytes); - - assertThat(result1).isEqualTo(result2); - assertThat(result1.hashCode()).isEqualTo(result2.hashCode()); - } - - @Test - void testFromBytesBE_differentLengthsSameValue() { - // Leading zeros should not affect value - byte[] bytes1 = new byte[] {0x01, 0x02, 0x03}; - byte[] bytes2 = new byte[] {0x00, 0x00, 0x01, 0x02, 0x03}; - - UInt256 result1 = UInt256.fromBytesBE(bytes1); - UInt256 result2 = UInt256.fromBytesBE(bytes2); - - assertThat(result1).isEqualTo(result2); - } }