3
\$\begingroup\$

I have implemented an integer square root function that is branch-free and runs in constant time, using the first variant found in this answer as a base. All possible values for the types byte, ushort, and uint have been exhaustively verified against the Math.Sqrt function. Validating ulong and UInt128 completely is not feasible but I have yet to find any edge cases that fail.

It would be nice to add support for types that are larger than 128 bits, but I was unable to come up with a way to calculate the required constant. I am curious whether anyone has ideas on how to solve that problem or otherwise improve the function.

C#

/// <summary>
/// Per-type constants for binary integer types, computed once per closed generic type.
/// </summary>
/// <typeparam name="T">The binary integer type the constants describe.</typeparam>
public static class BinaryIntegerConstants<T> where T : IBinaryInteger<T>
{
    // Counting the set bits of the all-ones pattern yields the number of bits in T.
    private static readonly T s_size = T.PopCount(T.AllBitsSet);

    /// <summary>Gets the population count of <c>T.AllBitsSet</c>, i.e. the bit width of <typeparamref name="T"/>.</summary>
    public static T Size => s_size;
}

/// <summary>
/// Converts a <see cref="bool"/> to <typeparamref name="T"/> by reinterpreting its storage as a byte
/// and passing that byte to <c>T.CreateTruncating</c>; no conditional is used.
/// </summary>
/// <param name="value">The flag to convert.</param>
/// <returns>The byte representation of <paramref name="value"/> converted to <typeparamref name="T"/>.</returns>
private static T As<T>(this bool value) where T : IBinaryInteger<T> {
    // Read the bool's single byte of storage directly instead of branching on it.
    byte raw = Unsafe.As<bool, byte>(ref value);

    return T.CreateTruncating(raw);
}

/// <summary>
/// Returns the bit width of <typeparamref name="T"/> minus the number of leading zero bits in
/// <paramref name="value"/>, i.e. the 1-based position of the highest set bit.
/// </summary>
/// <param name="value">The value to inspect.</param>
/// <returns><c>BinaryIntegerConstants&lt;T&gt;.Size - T.LeadingZeroCount(value)</c>.</returns>
public static T MostSignificantBit<T>(this T value) where T : IBinaryInteger<T> {
    var width = BinaryIntegerConstants<T>.Size;
    var leadingZeros = T.LeadingZeroCount(value);

    return (width - leadingZeros);
}
/// <summary>
/// Computes the integer square root of an unsigned value with a fixed sequence of
/// multiply, shift and add steps. The number of steps depends only on the bit width of
/// <typeparamref name="T"/>, never on <paramref name="value"/>.
/// </summary>
/// <typeparam name="T">An unsigned binary integer type that is 8, 16, 32, 64 or 128 bits wide.</typeparam>
/// <param name="value">The value whose square root is wanted.</param>
/// <returns>The computed integer square root of <paramref name="value"/>.</returns>
/// <exception cref="NotSupportedException">
/// The bit width of <typeparamref name="T"/> is not one of 8, 16, 32, 64 or 128.
/// </exception>
public static T SquareRoot<T>(this T value) where T : IBinaryInteger<T>, IUnsignedNumber<T> {
    var bitWidth = uint.CreateChecked(value: BinaryIntegerConstants<T>.Size);

    // Fail fast for widths that have no scale constant below. Previously the exception was only
    // reached after all refinement work, and for a width above 64 whose (Size >> 3) is not a
    // multiple of 8 the counted loop below would never see its counter reach zero.
    if (bitWidth is not (8U or 16U or 32U or 64U or 128U)) {
        throw new NotSupportedException();
    }

    // msb is the bit length of value; m is half of it, rounded up.
    var msb = int.CreateTruncating(value: value.MostSignificantBit());
    var msbIsOdd = (msb & 1);
    var m = ((msb + 1) >> 1);
    var mMinusOne = (m - 1);
    var mPlusOne = (m + 1);

    // x = 2^(m-1); z = x - (value >> (m + 1 - msbIsOdd)) is the constant term of the recurrence.
    var x = (T.One << mMinusOne);
    var y = (x - (value >> (mPlusOne - msbIsOdd)));
    var z = y;

    // x = 2^m from here on.
    x += x;

    // Refinement: y = ((y * y) >> (m + 1)) + z. Solving y = y^2 / 2^(m+1) + z for y gives
    // y = 2^m - sqrt(value * 2^msbIsOdd) (ignoring truncation), so (x - y) approximates
    // sqrt(value) when msb is even and sqrt(2 * value) when msb is odd.
    // The step counts are cumulative per width: 0 (8-bit), 2 (16), 6 (32), 14 (64), 30 (128).
    if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 8UL)) {
        y = (((y * y) >> mPlusOne) + z);
        y = (((y * y) >> mPlusOne) + z);
    }

    if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 16UL)) {
        y = (((y * y) >> mPlusOne) + z);
        y = (((y * y) >> mPlusOne) + z);
        y = (((y * y) >> mPlusOne) + z);
        y = (((y * y) >> mPlusOne) + z);
    }

    if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 32UL)) {
        y = (((y * y) >> mPlusOne) + z);
        y = (((y * y) >> mPlusOne) + z);
        y = (((y * y) >> mPlusOne) + z);
        y = (((y * y) >> mPlusOne) + z);
        y = (((y * y) >> mPlusOne) + z);
        y = (((y * y) >> mPlusOne) + z);
        y = (((y * y) >> mPlusOne) + z);
        y = (((y * y) >> mPlusOne) + z);
    }

    if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 64UL)) {
        // Counter depends only on the width of T: (Size >> 3) steps in groups of eight.
        var i = (BinaryIntegerConstants<T>.Size >> 3);

        do {
            i -= (T.One << 3);
            y = (((y * y) >> mPlusOne) + z);
            y = (((y * y) >> mPlusOne) + z);
            y = (((y * y) >> mPlusOne) + z);
            y = (((y * y) >> mPlusOne) + z);
            y = (((y * y) >> mPlusOne) + z);
            y = (((y * y) >> mPlusOne) + z);
            y = (((y * y) >> mPlusOne) + z);
            y = (((y * y) >> mPlusOne) + z);
        } while (i != T.Zero);
    }

    // Turn the fixed-point iterate into the root estimate.
    y = (x - y);

    // For odd msb the estimate is for sqrt(2 * value); scale it by 1/sqrt(2) by subtracting
    // y * c, where c is approximately (1 - 1/sqrt(2)) = 0.29289... stored as a fixed-point
    // fraction with Size/2 fractional bits. Multiplying by x (0 or 1) avoids a branch.
    x = T.CreateTruncating(value: msbIsOdd);
    y -= bitWidth switch {
        8U => (x * ((y * T.CreateChecked(value: 5UL)) >> 4)),
        16U => (x * ((y * T.CreateChecked(value: 75UL)) >> 8)),
        32U => (x * ((y * T.CreateChecked(value: 19195UL)) >> 16)),
        64U => (x * ((y * T.CreateChecked(value: 1257966796UL)) >> 32)),
        128U => (x * ((y * T.CreateChecked(value: 5402926248376769403UL)) >> 64)),
        _ => throw new NotSupportedException(), // TODO: Research a way to calculate the proper constant at runtime.
    };

    // Correction: if y * y exceeds value, the unsigned difference wraps to something above
    // 2^(Size-1), and the comparison result (0 or 1) is subtracted from y without branching.
    x = (T.One << (int.CreateTruncating(value: (BinaryIntegerConstants<T>.Size - T.One))));
    y -= ((value - (y * y)) > x).As<T>();

    if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 8UL)) {
        y -= ((value - (y * y)) > x).As<T>();
        y -= ((value - (y * y)) > x).As<T>();
    }

    if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 32UL)) {
        y -= ((value - (y * y)) > x).As<T>();
        y -= ((value - (y * y)) > x).As<T>();
        y -= ((value - (y * y)) > x).As<T>();
    }

    // Clear the top bit. NOTE(review): for value == 0 the shift by (m - 1) == -1 wraps on the
    // built-in types (e.g. uint) and leaves only the top bit set in y; this mask maps that to 0.
    return (y & (T.AllBitsSet >> 1));
}

32-Bit Asm | .NET 7.0.0 (7.0.22.51805), X64 RyuJIT AVX2

; SquareRoot[[System.UInt32, System.Private.CoreLib]](UInt32)
; JIT listing for the uint instantiation. All immediates are hexadecimal.
; value is copied from ecx into esi and stays there for the whole method.
       push      rsi
       sub       rsp,20
       mov       esi,ecx
       mov       ecx,esi
       ; msb = value.MostSignificantBit(); emitted as a call here, result in eax.
       call      qword ptr [MostSignificantBit[[System.UInt32, System.Private.CoreLib]](UInt32)]
       ; edx = msbIsOdd, eax = m, ecx = m - 1, then eax = m + 1.
       mov       edx,eax
       and       edx,1
       inc       eax
       shr       eax,1
       lea       ecx,[rax-1]
       inc       eax
       ; ecx = x = 1 << (m - 1)
       mov       r8d,1
       shlx      ecx,r8d,ecx
       ; r9d = z = x - (value >> (mPlusOne - msbIsOdd))
       mov       r8d,eax
       sub       r8d,edx
       shrx      r8d,esi,r8d
       mov       r9d,ecx
       sub       r9d,r8d
       ; x += x
       add       ecx,ecx
       ; Six unrolled refinements y = ((y * y) >> mPlusOne) + z with y in r8d.
       ; The shift count in eax is masked to 5 bits once and reused.
       mov       r8d,r9d
       imul      r8d,r9d
       and       eax,1F
       shrx      r8d,r8d,eax
       add       r8d,r9d
       imul      r8d,r8d
       shrx      r8d,r8d,eax
       add       r8d,r9d
       imul      r8d,r8d
       shrx      r8d,r8d,eax
       add       r8d,r9d
       imul      r8d,r8d
       shrx      r8d,r8d,eax
       add       r8d,r9d
       imul      r8d,r8d
       shrx      r8d,r8d,eax
       add       r8d,r9d
       imul      r8d,r8d
       shrx      r8d,r8d,eax
       add       r8d,r9d
       ; y = x - y
       mov       eax,ecx
       sub       eax,r8d
       mov       r8d,eax
       ; y -= msbIsOdd * ((y * 0x4AFB) >> 16); 0x4AFB is 19195, the 32-bit constant.
       imul      eax,r8d,4AFB
       shr       eax,10
       imul      eax,edx
       sub       r8d,eax
       ; Three unrolled corrections: y -= ((value - y * y) > 0x80000000), using seta instead of a branch.
       mov       eax,r8d
       imul      eax,r8d
       mov       edx,esi
       sub       edx,eax
       xor       eax,eax
       cmp       edx,80000000
       seta      al
       sub       r8d,eax
       mov       eax,r8d
       imul      eax,r8d
       mov       edx,esi
       sub       edx,eax
       xor       eax,eax
       cmp       edx,80000000
       seta      al
       sub       r8d,eax
       mov       eax,r8d
       imul      eax,r8d
       sub       esi,eax
       xor       eax,eax
       cmp       esi,80000000
       seta      al
       sub       r8d,eax
       ; return y & 0x7FFFFFFF
       mov       eax,r8d
       and       eax,7FFFFFFF
       add       rsp,20
       pop       rsi
       ret
; Total bytes of code 248

64-Bit Asm | .NET 7.0.0 (7.0.22.51805), X64 RyuJIT AVX2

; SquareRoot[[System.UInt64, System.Private.CoreLib]](UInt64)
; JIT listing for the ulong instantiation. All immediates are hexadecimal.
; value is copied from rcx into rsi and stays there for the whole method.
       push      rsi
       sub       rsp,20
       mov       rsi,rcx
       mov       rcx,rsi
       ; msb = value.MostSignificantBit(); emitted as a call here, result in rax.
       call      qword ptr [MostSignificantBit[[System.UInt64, System.Private.CoreLib]](UInt64)]
       ; rdx = msbIsOdd, rax = m, rcx = m - 1, then rax = m + 1.
       mov       rdx,rax
       and       rdx,1
       inc       rax
       shr       rax,1
       lea       rcx,[rax-1]
       inc       rax
       ; rcx = x = 1 << (m - 1)
       mov       r8d,1
       shlx      rcx,r8,rcx
       ; r9 = z = x - (value >> (mPlusOne - msbIsOdd))
       mov       r8d,eax
       sub       r8d,edx
       shrx      r8,rsi,r8
       mov       r9,rcx
       sub       r9,r8
       ; x += x
       add       rcx,rcx
       ; Fourteen unrolled refinements y = ((y * y) >> mPlusOne) + z with y in r8.
       ; The shift count in eax is masked to 6 bits once and reused.
       mov       r8,r9
       imul      r8,r9
       and       eax,3F
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       imul      r8,r8
       shrx      r8,r8,rax
       add       r8,r9
       ; y = x - y
       mov       rax,rcx
       sub       rax,r8
       mov       r8,rax
       ; y -= msbIsOdd * ((y * 0x4AFB0CCC) >> 32); 0x4AFB0CCC is 1257966796, the 64-bit constant.
       movsxd    rax,edx
       imul      rdx,r8,4AFB0CCC
       shr       rdx,20
       imul      rax,rdx
       sub       r8,rax
       ; Six unrolled corrections: y -= ((value - y * y) > 0x8000000000000000), using seta instead of a branch.
       mov       rax,r8
       imul      rax,r8
       mov       rdx,rsi
       sub       rdx,rax
       mov       rax,8000000000000000
       cmp       rdx,rax
       seta      al
       movzx     eax,al
       sub       r8,rax
       mov       rax,r8
       imul      rax,r8
       mov       rdx,rsi
       sub       rdx,rax
       mov       rax,8000000000000000
       cmp       rdx,rax
       seta      al
       movzx     eax,al
       sub       r8,rax
       mov       rax,r8
       imul      rax,r8
       mov       rdx,rsi
       sub       rdx,rax
       mov       rax,8000000000000000
       cmp       rdx,rax
       seta      al
       movzx     eax,al
       sub       r8,rax
       mov       rax,r8
       imul      rax,r8
       mov       rdx,rsi
       sub       rdx,rax
       mov       rax,8000000000000000
       cmp       rdx,rax
       seta      al
       movzx     eax,al
       sub       r8,rax
       mov       rax,r8
       imul      rax,r8
       mov       rdx,rsi
       sub       rdx,rax
       mov       rax,8000000000000000
       cmp       rdx,rax
       seta      al
       movzx     eax,al
       sub       r8,rax
       mov       rax,r8
       imul      rax,r8
       sub       rsi,rax
       mov       rax,8000000000000000
       cmp       rsi,rax
       seta      al
       movzx     eax,al
       sub       r8,rax
       ; return y & 0x7FFFFFFFFFFFFFFF
       mov       rax,7FFFFFFFFFFFFFFF
       and       rax,r8
       add       rsp,20
       pop       rsi
       ret
; Total bytes of code 498
\$\endgroup\$

1 Answer 1

1
\$\begingroup\$

how one could solve that problem .... (add support for types that are larger than 128 bits)

Found magic number 1257966796 in Implementation of binary floating-point arithmetic on embedded integer processors. Might help with this goal or just coincidental.


or otherwise improve the function.

Just some minor stuff:

Documentation

Comments in code explaining the algorithm are warranted.

Use hex

Constants 75, 19195, 1257966796, 5402926248376769403 certainly look magical.

At least 0x4B, 0x4AFB, 0x4AFB0CCC, 0x4AFB0CCC06219B7B looks like a pattern.

Let x = 5402926248376769403/2^64 --> 0.29289321881345247556389585485981.

Notice x is very close to (2 − √2)/2 = 1 − √0.5 ≈ 0.29289321881345247560, so the next value may be

(2 − √2)/2 * 2^128
99666397752933951918340834954143154528.885... or rounded
99666397752933951918340834954143154529
0x4AFB0CCC06219B7BA682764C8AB54161

"Validating ulong and UInt128 completely is not feasible but I have yet to find any edge cases that fail." --> This also implies OP's 5402926248376769403 may be off-by-1.

Runs in constant time?

Does below run in constant time?

    do {
        i -= (T.One << 3);
        y = (((y * y) >> mPlusOne) + z);
        ...
        y = (((y * y) >> mPlusOne) + z);
    } while (i != T.Zero);

Format uniformity

SquareRoot<T>() ... lacks a preceding blank line.

Simplification?

var mMinusOne = (m - 1);

// var x = (T.One << mMinusOne);
// x += x;

var x = (T.One << m);
\$\endgroup\$
2
  • \$\begingroup\$ Notes: The constants look magical, but aren't (they're clearer in base 10 once one understands that they're a fixed-point representation of sqrt(0.5)). The entire function is essentially a loop that has been unrolled in order to allow the compiler to optimize for commonly supported word sizes, meaning that the final loop truly is constant time. \$\endgroup\$ Commented Jan 15, 2023 at 6:40
  • 1
    \$\begingroup\$ @Kittoes0124 comment would have been useful as comments in code. Further, it does not look like a fixed-point representation of sqrt(0.5), but a fixed-point representation of (2-sqrt(2))/2, i.e. 1-sqrt(0.5). \$\endgroup\$ Commented Jan 15, 2023 at 6:50

Not the answer you're looking for? Browse other questions tagged or ask your own question.