diff options
-rw-r--r-- | .pick_status.json | 2 | ||||
-rw-r--r-- | src/compiler/spirv/vtn_glsl450.c | 27 |
2 files changed, 14 insertions, 15 deletions
diff --git a/.pick_status.json b/.pick_status.json index ee733be8156..69a14295cb8 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -688,7 +688,7 @@ "description": "Revert \"spirv: Use a simpler and more correct implementaiton of tanh()\"", "nominated": true, "nomination_type": 2, - "resolution": 0, + "resolution": 1, "master_sha": null, "because_sha": "da1c49171d0df185545cfbbd600e287f7c6160fa" }, diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c index 2d66512bc42..5ae75ff4a15 100644 --- a/src/compiler/spirv/vtn_glsl450.c +++ b/src/compiler/spirv/vtn_glsl450.c @@ -458,25 +458,24 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint, return; case GLSLstd450Tanh: { - /* tanh(x) := (0.5 * (e^x - e^(-x))) / (0.5 * (e^x + e^(-x))) + /* tanh(x) := (e^x - e^(-x)) / (e^x + e^(-x)) * - * With a little algebra this reduces to (e^2x - 1) / (e^2x + 1) + * We clamp x to [-10, +10] to avoid precision problems. When x > 10, + * e^x dominates the sum, e^(-x) is lost and tanh(x) is 1.0 for 32 bit + * floating point. * - * We clamp x to (-inf, +10] to avoid precision problems. When x > 10, - * e^2x is so much larger than 1.0 that 1.0 gets flushed to zero in the - * computation e^2x +/- 1 so it can be ignored. - * - * For 16-bit precision we clamp x to (-inf, +4.2] since the maximum - * representable number is only 65,504 and e^(2*6) exceeds that. Also, - * if x > 4.2, tanh(x) will return 1.0 in fp16. + * For 16-bit precision we clamp x to [-4.2, +4.2]. */ const uint32_t bit_size = src[0]->bit_size; const double clamped_x = bit_size > 16 ? 
10.0 : 4.2; - nir_ssa_def *x = nir_fmin(nb, src[0], - nir_imm_floatN_t(nb, clamped_x, bit_size)); - nir_ssa_def *exp2x = build_exp(nb, nir_fmul_imm(nb, x, 2.0)); - val->ssa->def = nir_fdiv(nb, nir_fadd_imm(nb, exp2x, -1.0), - nir_fadd_imm(nb, exp2x, 1.0)); + nir_ssa_def *x = nir_fclamp(nb, src[0], + nir_imm_floatN_t(nb, -clamped_x, bit_size), + nir_imm_floatN_t(nb, clamped_x, bit_size)); + val->ssa->def = + nir_fdiv(nb, nir_fsub(nb, build_exp(nb, x), + build_exp(nb, nir_fneg(nb, x))), + nir_fadd(nb, build_exp(nb, x), + build_exp(nb, nir_fneg(nb, x)))); return; } |