diff --git a/decimal.go b/decimal.go index 880d569..8464195 100644 --- a/decimal.go +++ b/decimal.go @@ -1361,6 +1361,21 @@ func (d Decimal) LessThanOrEqual(d2 Decimal) bool { return cmp == -1 || cmp == 0 } +// Clamp returns min if d is less than min, max if d is greater than max, +// and d otherwise. If min > max, it will panic with an error message. +func (d Decimal) Clamp(min, max Decimal) Decimal { + if min.GreaterThan(max) { + panic(fmt.Sprintf("decimal: min (%s) is greater than max (%s)", min.String(), max.String())) + } + if d.LessThan(min) { + return min + } + if d.GreaterThan(max) { + return max + } + return d +} + // Sign returns: // // -1 if d < 0 diff --git a/decimal_test.go b/decimal_test.go index 25d95bc..16dd97c 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -3962,3 +3962,231 @@ func ExampleNewFromFloat() { //0.123123123123123 //-10000000000000 } + +func TestDecimal_RoundBank(t *testing.T) { + type testData struct { + value string + places int32 + expected string + } + + tests := []testData{ + {"2.5", 0, "2"}, + {"3.5", 0, "4"}, + {"4.5", 0, "4"}, + {"5.5", 0, "6"}, + {"2.45", 1, "2.4"}, + {"2.55", 1, "2.6"}, + {"1.45", 1, "1.4"}, + {"1.55", 1, "1.6"}, + {"1.65", 1, "1.6"}, + {"-2.5", 0, "-2"}, + {"-3.5", 0, "-4"}, + {"-4.5", 0, "-4"}, + {"-5.5", 0, "-6"}, + {"-2.45", 1, "-2.4"}, + {"-2.55", 1, "-2.6"}, + {"-1.45", 1, "-1.4"}, + {"-1.55", 1, "-1.6"}, + {"-1.65", 1, "-1.6"}, + {"2.4", 0, "2"}, + {"2.6", 0, "3"}, + {"2.499", 0, "2"}, + {"2.501", 0, "3"}, + {"-2.4", 0, "-2"}, + {"-2.6", 0, "-3"}, + {"-2.499", 0, "-2"}, + {"-2.501", 0, "-3"}, + {"0", 0, "0"}, + {"0.0", 0, "0"}, + {"0.5", 0, "0"}, + {"1.5", 0, "2"}, + {"-0.5", 0, "0"}, + {"-1.5", 0, "-2"}, + {"1.5", 1, "1.5"}, + {"1.50", 1, "1.5"}, + {"1.500", 2, "1.50"}, + {"2.345", 2, "2.34"}, + {"2.355", 2, "2.36"}, + {"2.365", 2, "2.36"}, + {"-2.345", 2, "-2.34"}, + {"-2.355", 2, "-2.36"}, + {"-2.365", 2, "-2.36"}, + {"123.456", 0, "123"}, + {"123.5", 0, "124"}, + {"124.5", 0, "124"}, + {"125.5", 0, "126"}, + {"-123.456", 0, "-123"}, + {"-123.5", 0, "-124"}, + {"-124.5", 0, "-124"}, + {"-125.5", 0, "-126"}, + {"0.12345", 4, "0.1234"}, + {"0.12355", 4, "0.1236"}, + {"0.12365", 4, "0.1236"}, + {"-0.12345", 4, "-0.1234"}, + {"-0.12355", 4, "-0.1236"}, + {"-0.12365", 4, "-0.1236"}, + {"545", -1, "540"}, + {"555", -1, "560"}, + {"565", -1, "560"}, + {"-545", -1, "-540"}, + {"-555", -1, "-560"}, + {"-565", -1, "-560"}, + {"1234.567", -2, "1200"}, + {"1250.0", -2, "1200"}, + {"1350.0", -2, "1400"}, + {"-1234.567", -2, "-1200"}, + {"-1250.0", -2, "-1200"}, + {"-1350.0", -2, "-1400"}, + {"3.1415926535", 5, "3.14159"}, + {"3.1415926535", 4, "3.1416"}, + {"3.14155", 4, "3.1416"}, + {"3.14145", 4, "3.1414"}, + {"-3.1415926535", 5, "-3.14159"}, + {"-3.1415926535", 4, "-3.1416"}, + {"-3.14155", 4, "-3.1416"}, + {"-3.14145", 4, "-3.1414"}, + } + + for _, test := range tests { + d, err := NewFromString(test.value) + if err != nil { + t.Fatal(err) + } + + result := d.RoundBank(test.places) + expected, err := NewFromString(test.expected) + if err != nil { + t.Fatal(err) + } + + if !result.Equal(expected) { + t.Errorf("RoundBank(%s, %d): expected %s, got %s", + test.value, test.places, test.expected, result.String()) + } + } +} + +func TestDecimal_Clamp(t *testing.T) { + type testData struct { + value string + min string + max string + expected string + } + + tests := []testData{ + {"5", "0", "10", "5"}, + {"0", "0", "10", "0"}, + {"10", "0", "10", "10"}, + {"-5", "0", "10", "0"}, + {"15", "0", "10", "10"}, + {"5.5", "0", "10", "5.5"}, + {"-1", "-10", "10", "-1"}, + {"-15", "-10", "10", "-10"}, + {"15", "-10", "10", "10"}, + {"0", "-5", "5", "0"}, + {"-5", "-5", "5", "-5"}, + {"5", "-5", "5", "5"}, + {"-10", "-5", "5", "-5"}, + {"10", "-5", "5", "5"}, + {"3.14159", "0", "10", "3.14159"}, + {"-3.14159", "-10", "10", "-3.14159"}, + {"0", "0", "0", "0"}, + {"100", "-1000", "1000", "100"}, + {"-1000", "-1000", "1000", "-1000"}, + {"1000", "-1000", "1000", "1000"}, + {"-1001", "-1000", "1000", "-1000"}, + {"1001", "-1000", "1000", "1000"}, + } + + for _, test := range tests { + d, err := NewFromString(test.value) + if err != nil { + t.Fatal(err) + } + min, err := NewFromString(test.min) + if err != nil { + t.Fatal(err) + } + max, err := NewFromString(test.max) + if err != nil { + t.Fatal(err) + } + + result := d.Clamp(min, max) + expected, err := NewFromString(test.expected) + if err != nil { + t.Fatal(err) + } + + if !result.Equal(expected) { + t.Errorf("Clamp(%s, %s, %s): expected %s, got %s", + test.value, test.min, test.max, test.expected, result.String()) + } + } +} + +func TestDecimal_Clamp_Panic(t *testing.T) { + type testData struct { + min string + max string + expectedPanic bool + } + + tests := []testData{ + {"5", "10", false}, + {"10", "10", false}, + {"0", "0", false}, + {"-10", "10", false}, + {"-5", "-5", false}, + {"10", "5", true}, + {"0", "-10", true}, + {"5", "-5", true}, + } + + for _, test := range tests { + min, err := NewFromString(test.min) + if err != nil { + t.Fatal(err) + } + max, err := NewFromString(test.max) + if err != nil { + t.Fatal(err) + } + + d := NewFromInt(0) + + if test.expectedPanic { + if !didPanic(func() { d.Clamp(min, max) }) { + t.Errorf("expected panic when Clamp with min=%s > max=%s, but no panic occurred", + test.min, test.max) + } + } else { + if didPanic(func() { d.Clamp(min, max) }) { + t.Errorf("unexpected panic when Clamp with min=%s <= max=%s", + test.min, test.max) + } + } + } +} + +func TestDecimal_Clamp_PanicMessage(t *testing.T) { + min := NewFromInt(10) + max := NewFromInt(5) + d := NewFromInt(0) + + defer func() { + if r := recover(); r != nil { + errMsg := fmt.Sprintf("%v", r) + expectedMsg := "decimal: min (10) is greater than max (5)" + if errMsg != expectedMsg { + t.Errorf("expected panic message '%s', got '%s'", expectedMsg, errMsg) + } + } else { + t.Error("expected panic, but no panic occurred") + } + }() + + d.Clamp(min, max) +}