diff --git a/src/math/round.c b/src/math/round.c index aa08d41..683a287 100644 --- a/src/math/round.c +++ b/src/math/round.c @@ -1,35 +1,156 @@ -#define asuint64(x) ((union {double f; uint64_t i;}){x}).i -#define asdouble(x) ((union {double f; uint64_t i;}){x}).f +#define asuint64(x) ((union {f64 f; uint64_t i;}){x}).i +#define asdouble(x) ((union {f64 f; uint64_t i;}){x}).f #if defined(__GNUC__) || defined(__clang__) - #define just_do_it(v) do{__attribute__((unused)) volatile double t = v;}while(0) + #define just_do_it(v) do{__attribute__((unused)) volatile f64 t = v;}while(0) #else - #define just_do_it(v) do{volatile double t = v;}while(0) + #define just_do_it(v) do{volatile f64 t = v;}while(0) #endif -double nearbyint(double x) { +// Round x to an integer in the current rounding mode; only integer arithmetic is used, so no floating-point exception is raised +f64 nearbyint(f64 x) { #pragma STDC FENV_ACCESS ON - int e = fetestexcept(FE_INEXACT); - x = rint(x); - if (!e) feclearexcept(FE_INEXACT); - return x; + u64 bits = F64_BITS(x); + i64 bexp = F64_BEXP(bits); + u64 bmant = F64_MANT(bits); + // 1. Get rid of special cases: inf and nan (exp = 0x7ff) are returned unmodified + if(bexp == 0x7ff) { + return x; + } + int mode = fegetround(); + // Get exponent for (integer_mantissa * 2^exp) representation + i64 exp = bexp - 0x3ff - 52; + int s = F64_SIGN(bits); + // This value is 0 if no increment is required, and 1 if the absolute value + // increases by 1 + int c; + { + // Directed modes move an inexact value away from zero only when the rounding + // direction agrees with the sign; FE_TONEAREST is decided further down + int up = (mode == FE_UPWARD) & ((~s)&1); + int down = (mode == FE_DOWNWARD) & s; + c = up | down; + } + // If the whole mantissa is after a point, such that the first digit is 0, + // then the magnitude is below 0.5: these values are all zeroes, subnormal + // numbers and very small normal numbers + if(exp < -53) { + // Return 0 if exponent and mantissa are zero + if(bexp == 0 && bmant == 0) { + return x; + } + // For subnormal and normal numbers we round them either towards 0 or 1 + // and then 
call it a day + u64 new_bexp = (u64)((-c)&0x3ff) << F64_MANT_BITS; + u64 new_sign = (u64)s << 63; + u64 new_bits = new_sign | new_bexp; + return F64_CONS(new_bits); + } + // 2. Get fractional and whole bits of the mantissa + u64 mant = bmant | ((u64)1 << 52); + if(exp >= 0) { + // Already an integer + return x; + } + // if e.g. mantissa is 0b101.., and exponent is -2, the value is 0b101*2^-2 + // or 0b1.01, meaning there are 2 fractional digits + int nfrac_digs = -exp; + // The mantissas for 1.0 and 0.5 + u64 one = (u64)1 << nfrac_digs; + u64 half = one >> 1; + u64 frac_mask = one - 1; + u64 frac_mant = mant & frac_mask; + // Integral values are exact in every rounding mode + if(frac_mant == 0) { + return x; + } + // Drop fractional bits + u64 new_mant = mant & ~frac_mask; + // 3. Round based on c; round-to-nearest goes up past the halfway point and breaks ties to even + if(mode == FE_TONEAREST) { + c = (frac_mant > half) | ((frac_mant == half) & ((new_mant & one) != 0)); + } + // Add 1 to float if required + new_mant += c ? one : 0; + // A value in [0.5, 1) that lost its whole mantissa rounds to signed zero + if(new_mant == 0) { + return F64_CONS((u64)s << 63); + } + // The increment may carry into bit 53, renormalize in that case + if(new_mant >= ((u64)1 << 53)) { + new_mant >>= 1; + exp += 1; + } + new_mant &= F64_MANT_MASK; + u64 new_bits = new_mant | ((u64)(exp+0x3ff+52) << F64_MANT_BITS); + new_bits |= (u64)s << (F64_MANT_BITS + F64_BEXP_BITS); + return F64_CONS(new_bits); } -float nearbyintf(float x) { +// binary32 counterpart of nearbyint: same algorithm with an 8-bit exponent (bias 0x7f) and a 23-bit mantissa +f32 nearbyintf(f32 x) { #pragma STDC FENV_ACCESS ON - int e = fetestexcept(FE_INEXACT); - x = rintf(x); - if (!e) feclearexcept(FE_INEXACT); - return x; + u64 bits = F32_BITS(x); + i64 bexp = F32_BEXP(bits); + u64 bmant = F32_MANT(bits); + // inf and nan (exp = 0xff) are returned unmodified + if(bexp == 0xff) { + return x; + } + int mode = fegetround(); + i64 exp = bexp - 0x7f - 23; + int s = F32_SIGN(bits); + // 1 if the absolute value increases by 1; directed modes only, FE_TONEAREST is decided further down + int c = ((mode == FE_UPWARD) & ((~s)&1)) | ((mode == FE_DOWNWARD) & s); + // Magnitude below 0.5: zero is returned as is, anything else becomes a signed 0 or 1 + if(exp < -24) { + if(bexp == 0 && bmant == 0) { + return x; + } + u64 new_bexp = (u64)((-c)&0x7f) << 
F32_MANT_BITS; + u64 new_bits = ((u64)s << 31) | new_bexp; + return F32_CONS(new_bits); + } + u64 mant = bmant | ((u64)1 << 23); + if(exp >= 0) { + return x; + } + int nfrac_digs = -exp; + u64 one = (u64)1 << nfrac_digs; + u64 half = one >> 1; + u64 frac_mant = mant & (one - 1); + // Integral values are exact in every rounding mode + if(frac_mant == 0) { + return x; + } + u64 new_mant = mant - frac_mant; + // round-to-nearest goes up past the halfway point and breaks ties to even + if(mode == FE_TONEAREST) { + c = (frac_mant > half) | ((frac_mant == half) & ((new_mant & one) != 0)); + } + new_mant += c ? one : 0; + // A value in [0.5, 1) that lost its whole mantissa rounds to signed zero + if(new_mant == 0) { + return F32_CONS((u64)s << 31); + } + // The increment may carry into bit 24, renormalize in that case + if(new_mant >= ((u64)1 << 24)) { + new_mant >>= 1; + exp += 1; + } + new_mant &= F32_MANT_MASK; + u64 new_bits = new_mant | ((u64)(exp+0x7f+23) << F32_MANT_BITS); + new_bits |= (u64)s << (F32_MANT_BITS + F32_BEXP_BITS); + return F32_CONS(new_bits); } -long double nearbyintl(long double x) { - return nearbyint(x); +fl64 nearbyintl(fl64 x) { + return nearbyint((f64)x); } -double nextafter(double x, double y) { - union {double f; uint64_t i;} ux={x}, uy={y}; +f64 nextafter(f64 x, f64 y) { + union {f64 f; uint64_t i;} ux={x}, uy={y}; uint64_t ax, ay; int e; if (isnan(x) || isnan(y)) return x + y; @@ -45,16 +166,16 @@ double nextafter(double x, double y) { else { ux.i++; } e = ux.i >> 52 & 0x7ff; /* raise overflow if ux.f is infinite and x is finite */ if (e == 0x7ff) just_do_it(x+x); /* raise underflow if ux.f is subnormal or zero */ if (e == 0) just_do_it(x*x + ux.f*ux.f); return ux.f; } -float nextafterf(float x, float y) { - union {float f; uint32_t i;} ux={x}, uy={y}; +f32 nextafterf(f32 x, f32 y) { + union {f32 f; uint32_t i;} ux={x}, uy={y}; uint32_t ax, ay, e; if (isnan(x) || isnan(y)) return x + y; @@ -78,16 +199,16 @@ float nextafterf(float x, float y) { return ux.f; } -long double nextafterl(long double x, long double y) { +fl64 nextafterl(fl64 x, fl64 y) { return nextafter(x, y); 
} -double nexttoward(double x, long double y) { +f64 nexttoward(f64 x, fl64 y) { return nextafter(x, y); } -float nexttowardf(float x, long double y) { - union {float f; uint32_t i;} ux = {x}; +f32 nexttowardf(f32 x, fl64 y) { + union {f32 f; uint32_t i;} ux = {x}; uint32_t e; if (isnan(x) || isnan(y)) return x + y; if (x == y) return y; @@ -109,16 +230,16 @@ float nexttowardf(float x, long double y) { return ux.f; } -long double nexttowardl(long double x, long double y) { +fl64 nexttowardl(fl64 x, fl64 y) { return nextafterl(x, y); } -double rint(double x) { +f64 rint(f64 x) { static const double_t toint = 1/DBL_EPSILON; - union {double f; uint64_t i;} u = {x}; + union {f64 f; uint64_t i;} u = {x}; int e = u.i>>52 & 0x7ff; int s = u.i>>63; - double y; + f64 y; if (e >= 0x3ff+52) return x; if (s) y = x - toint + toint; else y = x + toint - toint; @@ -126,12 +247,12 @@ double rint(double x) { return y; } -float rintf(float x) { - static const float toint = 1/FLT_EPSILON; - union {float f; uint32_t i;} u = {x}; +f32 rintf(f32 x) { + static const f32 toint = 1/FLT_EPSILON; + union {f32 f; uint32_t i;} u = {x}; int e = u.i>>23 & 0xff; int s = u.i>>31; - float y; + f32 y; if (e >= 0x7f+23) return x; if (s) y = x - toint + toint; else y = x + toint - toint; @@ -139,12 +260,12 @@ float rintf(float x) { return y; } -long double rintl(long double x) { +fl64 rintl(fl64 x) { return rint(x); } #if LONG_MAX < 1U<<53 && defined(FE_INEXACT) - static long lrint_slow(double x) + static long lrint_slow(f64 x) { #pragma STDC FENV_ACCESS ON int e; @@ -156,7 +277,7 @@ long double rintl(long double x) { return x; } - long lrint(double x) + long lrint(f64 x) { uint32_t abstop = asuint64(x)>>32 & 0x7fffffff; uint64_t sign = asuint64(x) & (1ULL << 63); @@ -170,34 +291,34 @@ long double rintl(long double x) { return lrint_slow(x); } #else - long lrint(double x) { + long lrint(f64 x) { return rint(x); } #endif -long lrintf(float x) { +long lrintf(f32 x) { return rintf(x); } -long lrintl(long 
double x) { +long lrintl(fl64 x) { return lrint(x); } -long long llrint(double x) { +long long llrint(f64 x) { return rint(x); } -long long llrintf(float x) { +long long llrintf(f32 x) { return rintf(x); } -long long llrintl(long double x) { +long long llrintl(fl64 x) { return llrint(x); } -double round(double x) { +f64 round(f64 x) { static const double_t toint = 1/DBL_EPSILON; - union {double f; uint64_t i;} u = {x}; + union {f64 f; uint64_t i;} u = {x}; int e = u.i >> 52 & 0x7ff; double_t y; if (e >= 0x3ff+52) return x; @@ -215,9 +336,9 @@ double round(double x) { return y; } -float roundf(float x) { +f32 roundf(f32 x) { static const double_t toint = 1/FLT_EPSILON; - union {float f; uint32_t i;} u = {x}; + union {f32 f; uint32_t i;} u = {x}; int e = u.i >> 23 & 0xff; float_t y; if (e >= 0x7f+23) return x; @@ -234,37 +355,37 @@ float roundf(float x) { return y; } -long double roundl(long double x) { +fl64 roundl(fl64 x) { return round(x); } -long lround(double x) { +long lround(f64 x) { return round(x); } -long lroundf(float x) { +long lroundf(f32 x) { return roundf(x); } -long lroundl(long double x) { +long lroundl(fl64 x) { return roundl(x); } -long long llround(double x) { +long long llround(f64 x) { return round(x); } -long long llroundf(float x) { +long long llroundf(f32 x) { return roundf(x); } -long long llroundl(long double x) { +long long llroundl(fl64 x) { return roundl(x); } -double ceil(double x) { +f64 ceil(f64 x) { static const double_t toint = 1/DBL_EPSILON; - union {double f; uint64_t i;} u = {x}; + union {f64 f; uint64_t i;} u = {x}; int e = u.i >> 52 & 0x7ff; double_t y; @@ -285,8 +406,8 @@ double ceil(double x) { return x + y; } -float ceilf(float x) { - union {float f; uint32_t i;} u = {x}; +f32 ceilf(f32 x) { + union {f32 f; uint32_t i;} u = {x}; int e = (int)(u.i >> 23 & 0xff) - 0x7f; uint32_t m; @@ -310,13 +431,13 @@ float ceilf(float x) { return u.f; } -long double ceill(long double x) { +fl64 ceill(fl64 x) { return ceil(x); } -double 
floor(double x) { +f64 floor(f64 x) { static const double_t toint = 1/DBL_EPSILON; - union {double f; uint64_t i;} u = {x}; + union {f64 f; uint64_t i;} u = {x}; int e = u.i >> 52 & 0x7ff; double_t y; if (e >= 0x3ff+52 || x == 0) @@ -336,8 +457,8 @@ double floor(double x) { return x + y; } -float floorf(float x) { - union {float f; uint32_t i;} u = {x}; +f32 floorf(f32 x) { + union {f32 f; uint32_t i;} u = {x}; int e = (int)(u.i >> 23 & 0xff) - 0x7f; uint32_t m; @@ -361,12 +482,12 @@ float floorf(float x) { return u.f; } -long double floorl(long double x) { +fl64 floorl(fl64 x) { return floor(x); } -double trunc(double x) { - union {double f; uint64_t i;} u = {x}; +f64 trunc(f64 x) { + union {f64 f; uint64_t i;} u = {x}; int e = (int)(u.i >> 52 & 0x7ff) - 0x3ff + 12; uint64_t m; @@ -382,8 +503,8 @@ double trunc(double x) { return u.f; } -float truncf(float x) { - union {float f; uint32_t i;} u = {x}; +f32 truncf(f32 x) { + union {f32 f; uint32_t i;} u = {x}; int e = (int)(u.i >> 23 & 0xff) - 0x7f + 9; uint32_t m; @@ -399,6 +520,6 @@ float truncf(float x) { return u.f; } -long double truncl(long double x) { +fl64 truncl(fl64 x) { return trunc(x); } diff --git a/src/util.c b/src/util.c index b517e6e..be6076d 100644 --- a/src/util.c +++ b/src/util.c @@ -44,10 +44,10 @@ typedef wchar_t wchar; #define STR_(a) #a #define STR(a) STR_(a) -#define F64_BITS(x) ((union {f64 f; u64 i;}){x}).i -#define F64_CONS(x) ((union {f64 f; u64 i;}){x}).f -#define F32_BITS(x) ((union {f32 f; u32 i;}){x}).i -#define F32_CONS(x) ((union {f32 f; u32 i;}){x}).f +#define F64_BITS(x) ((union {f64 f; u64 i;}){.f=x}).i +#define F64_CONS(x) ((union {f64 f; u64 i;}){.i=x}).f +#define F32_BITS(x) ((union {f32 f; u32 i;}){.f=x}).i +#define F32_CONS(x) ((union {f32 f; u32 i;}){.i=x}).f #define F64_MANT_MASK UINT64_C(0xfffffffffffff) #define F64_MANT_MAX UINT64_C(0xfffffffffffff) diff --git a/test/math.c b/test/math.c index 0e58bcc..4137bdd 100644 --- a/test/math.c +++ b/test/math.c @@ -28,6 +28,14 @@ 
int main() { printf("-0.0 is %s\n", show_classification(-0.0)); printf("1.0 is %s\n", show_classification(1.0)); + printf("\n\n=== nearbyint === \n"); + double d; + printf("nearbyint(+0.0) = %f\n", d=nearbyint(+0.0)); + printf("nearbyint(-0.0) = %f\n", d=nearbyint(-0.0)); + printf("nearbyint(+16.4) = %f\n", d=nearbyint(+16.4)); + printf("nearbyint(+16.5) = %f\n", d=nearbyint(+16.5)); + printf("nearbyint(+16.8) = %f\n", d=nearbyint(+16.8)); + printf("\n\n=== signbit === \n"); printf("signbit(+0.0) = %d\n", signbit(+0.0)); printf("signbit(-0.0) = %d\n", signbit(-0.0));