文章目录
- 1. 目的
- 2 二分法求开根号
- 2.1 数学原理:单调函数
- 2.2 代码实现:注意事项
- 2.3 代码实现: 完整代码
- 2.4 验证结果
- 3. 牛顿法
- 3.1 数学原理:迭代求解
- 3.2 代码实现
- 3.3 结果
- 4. 卡马克快速法
- 4.1 原理
- 4.2 代码实现
- 4.3 结果
- 5. 完整代码
- 6. References
1. 目的
训练 lenet 需要初始化 kernel 的 weight 和 bias,而使用 Xavier Glorot 初始化则需要计算 sqrt ( 6.0 f a n i n + f a n o u t ) \text{sqrt}(\frac{6.0}{fan_{in} + fan_{out}}) sqrt(fanin+fanout6.0)(均匀分布) 或 sqrt ( 2.0 f a n i n + f a n o u t ) \text{sqrt}(\frac{2.0}{fan_{in} + fan_{out}}) sqrt(fanin+fanout2.0)(高斯分布).(参考[1]). 为了完全用 C 语言实现 lenet 的训练, 避免依赖 C 标准库的数学库函数 sqrt()
, 考虑弄清楚 sqrt()
的原理, 手动实现一个"低配版": 精度有轻微误差,实现简单。
实现开根号的方法,粗略说有三种:
- 二分法
- 牛顿法
- 卡马克公式快速法
本文只考虑 n , n ∈ R + \sqrt{n}, n \in \R^+ n,n∈R+.
2 二分法求开根号
2.1 数学原理:单调函数
对于正实数 n ∈ R + n \in \R^+ n∈R+, 它的二次方根为 x = n x=\sqrt{n} x=n, 也就是使得 x 2 = n x^2=n x2=n 成立的数字。考察方程 f ( x ) = x 2 − n = 0 f(x)=x^2-n=0 f(x)=x2−n=0 的解:
- 如果 n > 1 n > 1 n>1, 则 n ∈ ( 1 , n ) \sqrt{n} \in (1, n) n∈(1,n), 是一个单调递增区间, s.t. f ( x ) \text{s.t.} f(x) s.t.f(x)有解
- 如果 0 < n < 1 0 < n < 1 0<n<1, 则 n ∈ ( n , 1 ) \sqrt{n} \in (n, 1) n∈(n,1), 也是一个单调递增区间, s.t. f ( x ) \text{s.t.} f(x) s.t.f(x)有解
单调性使得我们可以用二分法求解 f ( x ) = x 2 − n = 0 f(x)=x^2-n=0 f(x)=x2−n=0, 从而得到答案 n = x \sqrt{n}=x n=x
2.2 代码实现:注意事项
比较相等
C语言使用 IEEE-754 标准来表示浮点数, 表示的数字可能和理论数字有误差, 因此判断浮点数相等时往往做差值的绝对值然后和 eps 比较, 小于eps就认为相等。
迭代求解
二分法是一个迭代求解算法, 可以手动设置迭代次数, 也可以设置比较精度 eps,迭代过程中精度误差小于 eps 就停止。本文的实现选择设置 eps 的方式。
防止溢出
本文给出的实现,是用 double 类型计算的。 计算两个数字中点时,有可能超过 double 类型最大值, 因此用先求差值的一半,再加到左端点的方式来计算中点。
特殊数字处理
n < 0 n < 0 n<0, 直接返回。
n = 0 n = 0 n=0 和 n = 1 n = 1 n=1, 直接返回。
2.3 代码实现: 完整代码
#include <stdio.h>
#include <stdbool.h>double m_fabs(double n)
{return n >= 0.0 ? n : -n;
}double m_sqrt(double n)
{if (n == 0.0 || n == 1.0){return n;}if (n < 0.f){printf("Error: not supported n: %f\n", n);return -1;}double left, right;if (n > 1.0){left = 1.0;right = n;}else{left = n;right = 1.0;}double left_v = left * left - n;double right_v = right * right - n;if (left_v * right_v > 0){printf("Error: not exist sqrt for n=%f\n", n);return -2;}const double eps = 1e-5;while (left <= right){printf("left=%f, right=%f\n", left, right);double mid = left + (right - left) / 2.0;double value = mid * mid;if (value - n > eps){right = mid;}else if (value - n < -eps){left = mid;}else{return mid;}}return 233;
}int main()
{double n;while (true){printf(">>> Please input an double number: ");scanf("%lf", &n);double ans = m_sqrt(n);printf("sqrt(%lf) = %lf\n", n, ans);}return 0;
}
2.4 验证结果
base) zz@Legion-R7000P% gcc sqrt.c
(base) zz@Legion-R7000P% ./a.out
>>> Please input an double number: 9.0
left=1.000000, right=9.000000
left=1.000000, right=5.000000
sqrt(9.000000) = 3.000000
>>> Please input an double number: 0.04
left=0.040000, right=1.000000
left=0.040000, right=0.520000
left=0.040000, right=0.280000
left=0.160000, right=0.280000
left=0.160000, right=0.220000
left=0.190000, right=0.220000
left=0.190000, right=0.205000
left=0.197500, right=0.205000
left=0.197500, right=0.201250
left=0.199375, right=0.201250
left=0.199375, right=0.200313
left=0.199844, right=0.200313
left=0.199844, right=0.200078
left=0.199961, right=0.200078
sqrt(0.040000) = 0.200020
>>> Please input an double number: 0.01
left=0.010000, right=1.000000
left=0.010000, right=0.505000
left=0.010000, right=0.257500
left=0.010000, right=0.133750
left=0.071875, right=0.133750
left=0.071875, right=0.102813
left=0.087344, right=0.102813
left=0.095078, right=0.102813
left=0.098945, right=0.102813
left=0.098945, right=0.100879
left=0.099912, right=0.100879
left=0.099912, right=0.100396
left=0.099912, right=0.100154
sqrt(0.010000) = 0.100033
>>> Please input an double number: ^C
3. 牛顿法
3.1 数学原理:迭代求解
给定数字 a a a, 求 a \sqrt{a} a. 等价于求方程 f ( x ) = x 2 − a = 0 f(x)=x^2-a = 0 f(x)=x2−a=0 的解。
这个方程在 x 0 x_0 x0 点处的切线 L ( x 0 ) L(x_0) L(x0)方程为 f ( x ) − f ( x 0 ) = f ′ ( x 0 ) ( x − x 0 ) f(x)-f(x_0)=f'(x_0)(x-x_0) f(x)−f(x0)=f′(x0)(x−x0).
切线与 x x x 轴有交点, 也就是当 f ( x ) = 0 f(x)=0 f(x)=0, f ′ ( x 0 ) ( x − x 0 ) + f ( x 0 ) = 0 f'(x_0)(x-x_0) + f(x_0) = 0 f′(x0)(x−x0)+f(x0)=0
⇒ x − x 0 = − f ( x 0 ) / f ′ ( x 0 ) \Rightarrow x-x_0 = -f(x_0)/f'(x_0) ⇒x−x0=−f(x0)/f′(x0)
$\Rightarrow x = x_0 - f(x_0)/f’(x_0) = x_0 - (x_0^2-n)/2x_0 = (x_0 + a/x_0)/2 $
得到 x \sqrt{x} x 的第一个近似解 x 1 = ( x 0 + a x 0 ) / 2 x_1=(x_0+\frac{a}{x_0})/2 x1=(x0+x0a)/2.
通常 x 1 x_1 x1 的精度不足,也就是 x 1 2 x_1 ^ 2 x12 和 a a a 相差比较多,因此还需要继续迭代。迭代到第 n n n 次时:
$\Rightarrow x_{n+1} = x_{n} - \frac{f_n}{f’(x_n)} = \frac{1}{2} (x_n + \frac{a}{x_n}) $
只要此时 x n 2 {x_n}^2 xn2 和 a a a 足够接近, 或者迭代次数 n n n 足够大, 都可以停止迭代, 用 x n x_n xn 作为 a \sqrt{a} a.
3.2 代码实现
#include <stdio.h>
#include <stdbool.h>double m_fabs(double n)
{return n >= 0.0 ? n : -n;
}double m_sqrt_newton(double a)
{// x_{n+1} = \frac{1}{2} (x_n + \frac{a}{x_n})double x = 1.0; // why?double eps = 1e-5;while (m_fabs(x * x - a) > eps){printf("x = %lf\n", x);x = (x + a / x) / 2.0;}return x;
}int main()
{double n;while (true){printf(">>> Please input an double number: ");scanf("%lf", &n);//double ans = m_sqrt(n);//printf("sqrt(%lf) = %lf\n", n, ans);double ans_newton = m_sqrt_newton(n);printf("sqrt_newton(%lf) = %lf\n", n, ans_newton);}return 0;
}
3.3 结果
zz@Legion-R7000P% gcc sqrt.c
zz@Legion-R7000P% ./a.out
>>> Please input an double number: 9.0
x = 1.000000
x = 5.000000
x = 3.400000
x = 3.023529
x = 3.000092
sqrt_newton(9.000000) = 3.000000
>>> Please input an double number: 0.04
x = 1.000000
x = 0.520000
x = 0.298462
x = 0.216241
x = 0.200610
sqrt_newton(0.040000) = 0.200001
>>> Please input an double number: 0.01
x = 1.000000
x = 0.505000
x = 0.262401
x = 0.150255
x = 0.108404
x = 0.100326
sqrt_newton(0.010000) = 0.100001
>>> Please input an double number: ^C
4. 卡马克快速法
4.1 原理
卡马克在雷神之锤游戏中给出了求平方根倒数的一种非常trick的代码实现。把它再求倒数, 就得到开根号结果。
它其实是一种混合方法: 一部分是牛顿法, 另一部分是对数函数的近似。其中牛顿迭代部分用于提升精度, 对数函数的逼近则和 IEEE-754 浮点数表示法紧密结合。
使用的近似公式是 l o g 2 ( 1 + x ) ≈ x + k log_2(1+x) \approx x + k log2(1+x)≈x+k. 见参考[4].
4.2 代码实现
由于 Carmack 快速求平方根的倒数法, 本身目的就是要尽可能快, 因此使用 float 类型而不是 double 类型。
#include <stdio.h>double m_sqrt_carmack(double n)
{int i;float x2, y;const float threehalfs = 1.5f;x2 = n * 0.5f;y = (float)n;i = *(int*)&y;i = 0x5f3759df - (i >> 1);y = *(float *)&i;y = y * (threehalfs - (x2 * y * y)); // 1st iterationy = y * (threehalfs - (x2 * y * y)); // 2nd iterationreturn 1.0 / y;
}int main()
{double n;while (true){printf(">>> Please input an double number: ");scanf("%lf", &n);double ans_carmack = m_sqrt_carmack(n);printf("sqrt_carmack(%lf) = %lf\n", n, ans_carmack);}return 0;
}
4.3 结果
zz@Legion-R7000P% gcc sqrt.c
zz@Legion-R7000P% ./a.out
>>> Please input an double number: 9.0
sqrt_carmack(9.000000) = 3.000006
>>> Please input an double number: 0.04
sqrt_carmack(0.040000) = 0.200001
>>> Please input an double number: 0.01
sqrt_carmack(0.010000) = 0.100000
>>> Please input an double number: ^C
5. 完整代码
// Author: Zhuo Zhang <imzhuo@foxmail.com>
// Homepage: https://github.com/zchrissirhcz
#include <stdio.h>
#include <stdbool.h>double m_fabs(double n)
{return n >= 0.0 ? n : -n;
}double m_sqrt(double n)
{if (n == 0.0 || n == 1.0){return n;}if (n < 0.f){printf("Error: not supported n: %f\n", n);return -1;}double left, right;if (n > 1.0){left = 1.0;right = n;}else{left = n;right = 1.0;}double left_v = left * left - n;double right_v = right * right - n;if (left_v * right_v > 0){printf("Error: not exist sqrt for n=%f\n", n);return -2;}const double eps = 1e-5;while (left <= right){printf("left=%f, right=%f\n", left, right);double mid = left + (right - left) / 2.0;double value = mid * mid;if (value - n > eps){right = mid;}else if (value - n < -eps){left = mid;}else{return mid;}}return 233;
}double m_sqrt_newton(double a)
{// x_{n+1} = \frac{1}{2} (x_n + \frac{a}{x_n})double x = 1.0; // why?double eps = 1e-5;while (m_fabs(x * x - a) > eps){printf("x = %lf\n", x);x = (x + a / x) / 2.0;}return x;
}double m_sqrt_carmack(double n)
{int i;float x2, y;const float threehalfs = 1.5f;x2 = n * 0.5f;y = (float)n;i = *(int*)&y;i = 0x5f3759df - (i >> 1);y = *(float *)&i;y = y * (threehalfs - (x2 * y * y)); // 1st iterationy = y * (threehalfs - (x2 * y * y)); // 2nd iterationreturn 1.0 / y;
}int main()
{double n;while (true){printf(">>> Please input an double number: ");scanf("%lf", &n);double ans = m_sqrt(n);printf("sqrt(%lf) = %lf\n", n, ans);double ans_newton = m_sqrt_newton(n);printf("sqrt_newton(%lf) = %lf\n", n, ans_newton);double ans_carmack = m_sqrt_carmack(n);printf("sqrt_carmack(%lf) = %lf\n", n, ans_carmack);}return 0;
}
6. References
- [1] https://www.bookstack.cn/read/paddlepaddle-1.6/3f4d0d9266a7a5c8.md
- [2] https://www.cnblogs.com/wangkundentisy/p/8118007.html
- [3] https://blog.csdn.net/plm199513100/article/details/124072422
- [4] 【回归本源】番外1-雷神之锤3的那段代码