nearbyint rewrite

This commit is contained in:
bumbread 2022-08-06 16:31:44 +11:00
parent 4dd8816167
commit 38da83bfe0
3 changed files with 202 additions and 73 deletions

View File

@ -1,35 +1,156 @@
#define asuint64(x) ((union {double f; uint64_t i;}){x}).i #define asuint64(x) ((union {f64 f; uint64_t i;}){x}).i
#define asdouble(x) ((union {double f; uint64_t i;}){x}).f #define asdouble(x) ((union {f64 f; uint64_t i;}){x}).f
#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
#define just_do_it(v) do{__attribute__((unused)) volatile double t = v;}while(0) #define just_do_it(v) do{__attribute__((unused)) volatile f64 t = v;}while(0)
#else #else
#define just_do_it(v) do{volatile double t = v;}while(0) #define just_do_it(v) do{volatile f64 t = v;}while(0)
#endif #endif
double nearbyint(double x) { f64 nearbyint(f64 x) {
#pragma STDC FENV_ACCESS ON #pragma STDC FENV_ACCESS ON
int e = fetestexcept(FE_INEXACT); u64 bits = F64_BITS(x);
x = rint(x); i64 bexp = F64_BEXP(bits);
if (!e) feclearexcept(FE_INEXACT); u64 bmant = F64_MANT(bits);
// 1. Get rid of special cases, exp = 0x7ff, and exp < 0x3ff
// Return x unmodified if inf, nan
if(bexp == 0x7ff) {
return x; return x;
}
int mode = fegetround();
// Get exponent for (integer_mantissa * 2^exp) representation
i64 exp = bexp - 0x3ff - 52;
int s = F64_SIGN(bits);
// This value is 0 if no increment is required, and 1 if the absolute value
// increases by 1
int c;
{
// Check if we need to round towards 0 or towards 1
// (assumes specific values in rounding modes in fenv.h)
int a = (mode&2)>>1;
int b = mode&1;
int mask = ((a^b)<<1)|(a^b);
int d = 2 - mode&mask;
c = s ^ d;
}
// If the whole mantissa is after a point, such that the first digit is 0,
// then the value is closer to 0 these values are all zeroes, subnormal
// numbers and very small normal numbers
if(exp < -53) {
// Return 0 if exponent and mantissa are zero
if(bexp == 0 && bmant == 0) {
return x;
}
// For subnormal and normal numbers we round them either towards 0 or 1
// and then call it a day
u64 new_bexp = (u64)((1-c)&0x3ff) << F64_MANT_BITS;
u64 new_sign = (u64)s << 63;
u64 new_bits = new_sign | new_bexp;
return F64_CONS(new_bits);
}
// 2. Get fractional and whole bits of the mantissa
u64 mant = bmant | ((u64)1 << 52);
if(exp >= 0) {
// Already an integer
return x;
}
// if e.g. mantissa is 0b101.., and exponent is -2, the value is 0b101*2^-2
// or 0b1.01, meaning there are 2 fractional digits
int nfrac_digs = -exp;
// The rest of the digits are whole
int nwhole_digs = F64_MANT_BITS - nfrac_digs;
u64 frac_mask = (((u64)1<<(nfrac_digs))-1);
u64 frac_mant = mant & frac_mask;
// The mantissas for 1.0 and 0.5
u64 one = (((u64)1<<(nfrac_digs)));
u64 half = one >> 1;
// 3. Round the float based on the value of c
// we'll first fix up c to include other rounding modes
c |= (mode == FE_UPWARD) & ((~s)&1);
c |= (mode == FE_DOWNWARD) & s;
c |= (mode == FE_TONEAREST) & (frac_mant >= half);
// Drop fractional bits
u64 new_mant = mant & ~frac_mant;
// Add 1 to float if required
if(c) {
new_mant += one;
if(new_mant > ((u64)1 << 53)) {
new_mant >>= 1;
exp += 1;
}
}
new_mant &= F64_MANT_MASK;
u64 new_bits = new_mant;
new_bits |= (exp+0x3ff+52) << F64_MANT_BITS;
new_bits |= (u64)s << (F64_MANT_BITS + F64_BEXP_BITS);
f64 result = F64_CONS(new_bits);
return result;
} }
float nearbyintf(float x) { f32 nearbyintf(f32 x) {
#pragma STDC FENV_ACCESS ON #pragma STDC FENV_ACCESS ON
int e = fetestexcept(FE_INEXACT); u64 bits = F32_BITS(x);
x = rintf(x); i64 bexp = F32_BEXP(bits);
if (!e) feclearexcept(FE_INEXACT); u64 bmant = F32_MANT(bits);
if(bexp == 0x7f) {
return x; return x;
}
int mode = fegetround();
i64 exp = bexp - 0x3f - 52;
int s = F32_SIGN(bits);
int c;
{
int a = (mode&2)>>1;
int b = mode&1;
int mask = ((a^b)<<1)|(a^b);
int d = 2 - mode&mask;
c = s ^ d;
}
if(exp < -24) {
if(bexp == 0 && bmant == 0) {
return x;
}
u64 new_bexp = (u64)((1-c)&0x3f) << F32_MANT_BITS;
u64 new_sign = (u64)s << 63;
u64 new_bits = new_sign | new_bexp;
return F32_CONS(new_bits);
}
u64 mant = bmant | ((u64)1 << 23);
if(exp >= 0) {
return x;
}
int nfrac_digs = -exp;
int nwhole_digs = F32_MANT_BITS - nfrac_digs;
u64 frac_mask = (((u64)1<<(nfrac_digs))-1);
u64 frac_mant = mant & frac_mask;
u64 one = (((u64)1<<(nfrac_digs)));
u64 half = one >> 1;
c |= (mode == FE_UPWARD) & ((~s)&1);
c |= (mode == FE_DOWNWARD) & s;
c |= (mode == FE_TONEAREST) & (frac_mant >= half);
u64 new_mant = mant & ~frac_mant;
if(c) {
new_mant += one;
if(new_mant > ((u64)1 << 24)) {
new_mant >>= 1;
exp += 1;
}
}
new_mant &= F32_MANT_MASK;
u64 new_bits = new_mant;
new_bits |= (exp+0x3f+23) << F32_MANT_BITS;
new_bits |= (u64)s << (F32_MANT_BITS + F32_BEXP_BITS);
f64 result = F32_CONS(new_bits);
return result;
} }
long double nearbyintl(long double x) { fl64 nearbyintl(fl64 x) {
return nearbyint(x); return nearbyint((f64)x);
} }
double nextafter(double x, double y) { f64 nextafter(f64 x, f64 y) {
union {double f; uint64_t i;} ux={x}, uy={y}; union {f64 f; uint64_t i;} ux={x}, uy={y};
uint64_t ax, ay; uint64_t ax, ay;
int e; int e;
if (isnan(x) || isnan(y)) return x + y; if (isnan(x) || isnan(y)) return x + y;
@ -45,16 +166,16 @@ double nextafter(double x, double y) {
else { else {
ux.i++; ux.i++;
} }
e = ux.i >> 52 & 0x7ff; e = ux.i >> 52 & 0x7f;
/* raise overflow if ux.f is infinite and x is finite */ /* raise overflow if ux.f is infinite and x is finite */
if (e == 0x7ff) just_do_it(x+x); if (e == 0x7f) just_do_it(x+x);
/* raise underflow if ux.f is subnormal or zero */ /* raise underflow if ux.f is subnormal or zero */
if (e == 0) just_do_it(x*x + ux.f*ux.f); if (e == 0) just_do_it(x*x + ux.f*ux.f);
return ux.f; return ux.f;
} }
float nextafterf(float x, float y) { f32 nextafterf(f32 x, f32 y) {
union {float f; uint32_t i;} ux={x}, uy={y}; union {f32 f; uint32_t i;} ux={x}, uy={y};
uint32_t ax, ay, e; uint32_t ax, ay, e;
if (isnan(x) || isnan(y)) return x + y; if (isnan(x) || isnan(y)) return x + y;
@ -78,16 +199,16 @@ float nextafterf(float x, float y) {
return ux.f; return ux.f;
} }
long double nextafterl(long double x, long double y) { fl64 nextafterl(fl64 x, fl64 y) {
return nextafter(x, y); return nextafter(x, y);
} }
double nexttoward(double x, long double y) { f64 nexttoward(f64 x, fl64 y) {
return nextafter(x, y); return nextafter(x, y);
} }
float nexttowardf(float x, long double y) { f32 nexttowardf(f32 x, fl64 y) {
union {float f; uint32_t i;} ux = {x}; union {f32 f; uint32_t i;} ux = {x};
uint32_t e; uint32_t e;
if (isnan(x) || isnan(y)) return x + y; if (isnan(x) || isnan(y)) return x + y;
if (x == y) return y; if (x == y) return y;
@ -109,16 +230,16 @@ float nexttowardf(float x, long double y) {
return ux.f; return ux.f;
} }
long double nexttowardl(long double x, long double y) { fl64 nexttowardl(fl64 x, fl64 y) {
return nextafterl(x, y); return nextafterl(x, y);
} }
double rint(double x) { f64 rint(f64 x) {
static const double_t toint = 1/DBL_EPSILON; static const double_t toint = 1/DBL_EPSILON;
union {double f; uint64_t i;} u = {x}; union {f64 f; uint64_t i;} u = {x};
int e = u.i>>52 & 0x7ff; int e = u.i>>52 & 0x7ff;
int s = u.i>>63; int s = u.i>>63;
double y; f64 y;
if (e >= 0x3ff+52) return x; if (e >= 0x3ff+52) return x;
if (s) y = x - toint + toint; if (s) y = x - toint + toint;
else y = x + toint - toint; else y = x + toint - toint;
@ -126,12 +247,12 @@ double rint(double x) {
return y; return y;
} }
float rintf(float x) { f32 rintf(f32 x) {
static const float toint = 1/FLT_EPSILON; static const f32 toint = 1/FLT_EPSILON;
union {float f; uint32_t i;} u = {x}; union {f32 f; uint32_t i;} u = {x};
int e = u.i>>23 & 0xff; int e = u.i>>23 & 0xff;
int s = u.i>>31; int s = u.i>>31;
float y; f32 y;
if (e >= 0x7f+23) return x; if (e >= 0x7f+23) return x;
if (s) y = x - toint + toint; if (s) y = x - toint + toint;
else y = x + toint - toint; else y = x + toint - toint;
@ -139,12 +260,12 @@ float rintf(float x) {
return y; return y;
} }
long double rintl(long double x) { fl64 rintl(fl64 x) {
return rint(x); return rint(x);
} }
#if LONG_MAX < 1U<<53 && defined(FE_INEXACT) #if LONG_MAX < 1U<<53 && defined(FE_INEXACT)
static long lrint_slow(double x) static long lrint_slow(f64 x)
{ {
#pragma STDC FENV_ACCESS ON #pragma STDC FENV_ACCESS ON
int e; int e;
@ -156,7 +277,7 @@ long double rintl(long double x) {
return x; return x;
} }
long lrint(double x) long lrint(f64 x)
{ {
uint32_t abstop = asuint64(x)>>32 & 0x7fffffff; uint32_t abstop = asuint64(x)>>32 & 0x7fffffff;
uint64_t sign = asuint64(x) & (1ULL << 63); uint64_t sign = asuint64(x) & (1ULL << 63);
@ -170,34 +291,34 @@ long double rintl(long double x) {
return lrint_slow(x); return lrint_slow(x);
} }
#else #else
long lrint(double x) { long lrint(f64 x) {
return rint(x); return rint(x);
} }
#endif #endif
long lrintf(float x) { long lrintf(f32 x) {
return rintf(x); return rintf(x);
} }
long lrintl(long double x) { long lrintl(fl64 x) {
return lrint(x); return lrint(x);
} }
long long llrint(double x) { long long llrint(f64 x) {
return rint(x); return rint(x);
} }
long long llrintf(float x) { long long llrintf(f32 x) {
return rintf(x); return rintf(x);
} }
long long llrintl(long double x) { long long llrintl(fl64 x) {
return llrint(x); return llrint(x);
} }
double round(double x) { f64 round(f64 x) {
static const double_t toint = 1/DBL_EPSILON; static const double_t toint = 1/DBL_EPSILON;
union {double f; uint64_t i;} u = {x}; union {f64 f; uint64_t i;} u = {x};
int e = u.i >> 52 & 0x7ff; int e = u.i >> 52 & 0x7ff;
double_t y; double_t y;
if (e >= 0x3ff+52) return x; if (e >= 0x3ff+52) return x;
@ -215,9 +336,9 @@ double round(double x) {
return y; return y;
} }
float roundf(float x) { f32 roundf(f32 x) {
static const double_t toint = 1/FLT_EPSILON; static const double_t toint = 1/FLT_EPSILON;
union {float f; uint32_t i;} u = {x}; union {f32 f; uint32_t i;} u = {x};
int e = u.i >> 23 & 0xff; int e = u.i >> 23 & 0xff;
float_t y; float_t y;
if (e >= 0x7f+23) return x; if (e >= 0x7f+23) return x;
@ -234,37 +355,37 @@ float roundf(float x) {
return y; return y;
} }
long double roundl(long double x) { fl64 roundl(fl64 x) {
return round(x); return round(x);
} }
long lround(double x) { long lround(f64 x) {
return round(x); return round(x);
} }
long lroundf(float x) { long lroundf(f32 x) {
return roundf(x); return roundf(x);
} }
long lroundl(long double x) { long lroundl(fl64 x) {
return roundl(x); return roundl(x);
} }
long long llround(double x) { long long llround(f64 x) {
return round(x); return round(x);
} }
long long llroundf(float x) { long long llroundf(f32 x) {
return roundf(x); return roundf(x);
} }
long long llroundl(long double x) { long long llroundl(fl64 x) {
return roundl(x); return roundl(x);
} }
double ceil(double x) { f64 ceil(f64 x) {
static const double_t toint = 1/DBL_EPSILON; static const double_t toint = 1/DBL_EPSILON;
union {double f; uint64_t i;} u = {x}; union {f64 f; uint64_t i;} u = {x};
int e = u.i >> 52 & 0x7ff; int e = u.i >> 52 & 0x7ff;
double_t y; double_t y;
@ -285,8 +406,8 @@ double ceil(double x) {
return x + y; return x + y;
} }
float ceilf(float x) { f32 ceilf(f32 x) {
union {float f; uint32_t i;} u = {x}; union {f32 f; uint32_t i;} u = {x};
int e = (int)(u.i >> 23 & 0xff) - 0x7f; int e = (int)(u.i >> 23 & 0xff) - 0x7f;
uint32_t m; uint32_t m;
@ -310,13 +431,13 @@ float ceilf(float x) {
return u.f; return u.f;
} }
long double ceill(long double x) { fl64 ceill(fl64 x) {
return ceil(x); return ceil(x);
} }
double floor(double x) { f64 floor(f64 x) {
static const double_t toint = 1/DBL_EPSILON; static const double_t toint = 1/DBL_EPSILON;
union {double f; uint64_t i;} u = {x}; union {f64 f; uint64_t i;} u = {x};
int e = u.i >> 52 & 0x7ff; int e = u.i >> 52 & 0x7ff;
double_t y; double_t y;
if (e >= 0x3ff+52 || x == 0) if (e >= 0x3ff+52 || x == 0)
@ -336,8 +457,8 @@ double floor(double x) {
return x + y; return x + y;
} }
float floorf(float x) { f32 floorf(f32 x) {
union {float f; uint32_t i;} u = {x}; union {f32 f; uint32_t i;} u = {x};
int e = (int)(u.i >> 23 & 0xff) - 0x7f; int e = (int)(u.i >> 23 & 0xff) - 0x7f;
uint32_t m; uint32_t m;
@ -361,12 +482,12 @@ float floorf(float x) {
return u.f; return u.f;
} }
long double floorl(long double x) { fl64 floorl(fl64 x) {
return floor(x); return floor(x);
} }
double trunc(double x) { f64 trunc(f64 x) {
union {double f; uint64_t i;} u = {x}; union {f64 f; uint64_t i;} u = {x};
int e = (int)(u.i >> 52 & 0x7ff) - 0x3ff + 12; int e = (int)(u.i >> 52 & 0x7ff) - 0x3ff + 12;
uint64_t m; uint64_t m;
@ -382,8 +503,8 @@ double trunc(double x) {
return u.f; return u.f;
} }
float truncf(float x) { f32 truncf(f32 x) {
union {float f; uint32_t i;} u = {x}; union {f32 f; uint32_t i;} u = {x};
int e = (int)(u.i >> 23 & 0xff) - 0x7f + 9; int e = (int)(u.i >> 23 & 0xff) - 0x7f + 9;
uint32_t m; uint32_t m;
@ -399,6 +520,6 @@ float truncf(float x) {
return u.f; return u.f;
} }
long double truncl(long double x) { fl64 truncl(fl64 x) {
return trunc(x); return trunc(x);
} }

View File

@ -44,10 +44,10 @@ typedef wchar_t wchar;
#define STR_(a) #a #define STR_(a) #a
#define STR(a) STR_(a) #define STR(a) STR_(a)
#define F64_BITS(x) ((union {f64 f; u64 i;}){x}).i #define F64_BITS(x) ((union {f64 f; u64 i;}){.f=x}).i
#define F64_CONS(x) ((union {f64 f; u64 i;}){x}).f #define F64_CONS(x) ((union {f64 f; u64 i;}){.i=x}).f
#define F32_BITS(x) ((union {f32 f; u32 i;}){x}).i #define F32_BITS(x) ((union {f32 f; u32 i;}){.f=x}).i
#define F32_CONS(x) ((union {f32 f; u32 i;}){x}).f #define F32_CONS(x) ((union {f32 f; u32 i;}){.i=x}).f
#define F64_MANT_MASK UINT64_C(0xfffffffffffff) #define F64_MANT_MASK UINT64_C(0xfffffffffffff)
#define F64_MANT_MAX UINT64_C(0xfffffffffffff) #define F64_MANT_MAX UINT64_C(0xfffffffffffff)

View File

@ -28,6 +28,14 @@ int main() {
printf("-0.0 is %s\n", show_classification(-0.0)); printf("-0.0 is %s\n", show_classification(-0.0));
printf("1.0 is %s\n", show_classification(1.0)); printf("1.0 is %s\n", show_classification(1.0));
printf("\n\n=== nearbyint === \n");
double d;
printf("nearbyint(+0.0) = %f\n", d=nearbyint(+0.0));
printf("nearbyint(-0.0) = %f\n", d=nearbyint(-0.0));
printf("nearbyint(+16.4) = %f\n", d=nearbyint(+16.4));
printf("nearbyint(+16.5) = %f\n", d=nearbyint(+16.5));
printf("nearbyint(+16.8) = %f\n", d=nearbyint(+16.8));
printf("\n\n=== signbit === \n"); printf("\n\n=== signbit === \n");
printf("signbit(+0.0) = %d\n", signbit(+0.0)); printf("signbit(+0.0) = %d\n", signbit(+0.0));
printf("signbit(-0.0) = %d\n", signbit(-0.0)); printf("signbit(-0.0) = %d\n", signbit(-0.0));