在之前的BLOG里,我们介绍了通过梯度下降法来解决回归问题。在这篇BLOG中,就让我们学习一种新的解析方法——正规方程法吧。

正规方程法

到之前的BLOG中,我们一直在使用梯度下降法来解决线性回归问题,也就是使用的迭代算法,经过很多步迭代来计算最值所对应的 θ 。但正规方程法(Normal Equation) 提供了一种求 θ 的解析解法,也就是可以一次求解最优时 θ 的值。

在讲正规方程法之前,我们再讲讲另一种解析解法——方程法。先举一个例子来解释这个问题,我们假设,有一个非常简单的代价函数 J(θ) 其是关于参数 θ 的二次函数:

ZG1.png

那么如何求解一个函数的最低点呢?学过数学分析/高等数学的同学就知道,要求解一个函数的最值,其实就是要去找那些导数值等于 0 或者导数值不存在的点。当 θ 只是一个实数的时候,这时很容易实现的,只需要解一个方程就可以了;但如果 θ 不是一个实数,而是一个 n+1 维的参数向量的时候我们又该怎么办呢?此时代价函数 J 是向量 θ 的函数,也就是 θ0 到 θn 的函数,大概是这样:

ZG2.png

我们其实要对每个参数 θ 求 J 的偏导数,然后求出满足所有偏导数都不存在或等于 0 的解,最值就出现在里面,但是这时非常麻烦的。

但我们我们接下来要讲的正规方程法不需要这些数学分析/高等代数的知识,我们只需要一些线性代数的知识就可以了。具体怎么操作,让我们来再次通过卖房子的例子来说明吧。假如说我有 m = 4 的训练样本:

ZG3.png

我们首先要做的就是在训练集中加上一列对应额外特征变量的x,其值恒为 1:

ZG4.png

下来我们要构建一个矩阵 X 这个矩阵基本包含了训练样本的所有特征变量,这个矩阵基本包含了训练样本的所有特征变量,也即是说我们要将表格中的所有特征值放到一个矩阵当中,所以 X 会是一个 m*(n+1) 维矩阵:

ZG5.png

接下来我们还要构造一个向量y,我们要将要预测的值建一个向量,并且称之为向量 y ,所以 y 会是一个 m 维向量:

ZG6.png

接着我们来分析一下,我们最后求出的回归方程应该是 h0(x) = 00 + 01 * x1 + 02 * x2 + 03 * x3 + 04 * x4,我们一样地,构造一个向量 p 为 p = [00, 01, 02, 03, 04],那如果我们的数据是完全在我们的回归直线上,那就应该有 X * p = y,即p = X^T * y

但实际上,我们会遇到两个问题,第一数据大多情况下是均匀地分布在我们的直线两边的,所以我们可以用线性代数中的最小二乘法进行逼近;第二,只有方阵才有逆矩阵,而大多数情况下我们的 X 都是普通的矩阵,所以我们要用到SVQ里面用到过的伪逆的概念。

思路就是上面说的那样,具体的过程我就不进行推到了,其实最优情况下的向量 0 就是满足下面一个式子:

ZG7.png

其中(X^TX)^-1代表X^TX的逆矩阵。再Octave中,这可以写成很简洁的一行代码pinv(x' * x) * x' * y。我没有去证明这个式子,因为证明过程可以在线性代数的树上找到,但你可以相信这个式子会给出最优的 θ 值。

最后我们来谈一谈什么时候可以使用这个正规方程法,其实最重要的就是这个特征值的数目n的大小。因为我们要计算(X^T*X)^-1,这大概是个O(n^3)的工作,所以当n≥10000时,实际上花费的时间可能就不如梯度下降法了。但相反地,在 n 不太大的时候,正规方程法还是一种很不错的算法。

X^T*X的可逆性

可能有的同学会好奇,θ = inv(X'X )*X'y ,那对于矩阵X'X的结果是不可逆的情况咋办呢? 确实如果你懂一点线性代数的知识 ,有的有些矩阵可逆,而有些矩阵不可逆。其实X'X的不可逆的问题很少发生,在Octave里如果你用pinv来实现逆矩阵的计算,就算该矩阵不可逆,你也会得到一个正常的解,因为这时计算机帮你找到了一个类似于逆矩阵的矩阵,此时你的不可逆矩阵乘以这个矩阵是最接近单位矩阵的。

其实在Octave里,有两个函数可以求解矩阵的逆,一个被称为pinv()另一个是inv(),这两者之间的差异是些许计算过程上的,前者是所谓的伪逆,后者被称为严格的逆,使用pinv() 函数计算出比较正常的解,便矩阵X'X是不可逆的。

尽管如此,我们还是要尽量避免X'X结果是不可逆的情况的出现。如果矩阵X'X结果是不可逆的,通常有两种最常见的原因:

第一个原因是可能出现了多余的特征值,它和其他特征值始终保持线性关系。例如在预测住房价格时,如果x1是以英尺为尺寸规格计算的房子,x2是以平方米为尺寸规格计算的房子,同时我们找到1米恒等于3.28英尺,那 x1 就恒等于 x2 的 3.28平方倍,即它们是线性相关的,所以此时 X'X 就会不可逆了。

第二个原因是可能在你的特征值太大了,这也可能会导致矩阵X'X的结果是不可逆的。具体地说,在训练样本数 m 小于或等于特征数 n 的时候,你的矩阵 X'X 就很可能是不可逆的了,这时你就要填充更多的数据或者重新审视你特征值的选取,是否有必要选取这么多之类的。

其实不可逆的情况在实际的操作中还是少之又少的,再加上Octave能够自动处理,所以只要知道大概的解决方案就可以了,不必倾注太多时间。

代码实现

思路在上面已经写的比较清楚了,下面我用Cpp进行了实现,供大家学习,其中求逆矩阵用了高斯消元法,不了解的同学可以点击这里进行学习

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;
int n, m;
const double EPS = 1e-6;
double cpu[1100][2200];
double temp[1100][2200];
double x[1100][2200], y[1100];
double a[1100][2200],ans[2200];
const double Alpha = 3 * 1e-3;
void read() {
    printf("请输入样本大小m:\n");
    scanf("%d", &m);

    printf("请输入特征值的个数n:\n");
    scanf("%d", &n);

    for(int i = 1; i <= m; i++) x[i][1] = 1;
    n++;

    printf("请依次输入m个特征值xi和结果y:\n");
    for(int i = 1; i <= m ; i++) { 
        for(int j = 2; j <= n; j++) {
            scanf("%lf", &x[i][j]);
        }
        scanf("%lf", &y[i]);
    }
    return;
}

void gauss() {
    double multi;
    for(int i = 1; i <= n; i++) {//从左到右枚举第i列(即枚举主元列)

        if(fabs(a[i][i]) < EPS) {
            printf("Error!\n");
            exit(0);
        }

        multi = a[i][i];

        for(int j = i; j <= 2 * n; j++) a[i][j] /= multi;

        for(int j = 1; j <= n; j++) {
            if(j != i) {
                if(a[j][i] == 0)continue;
                else multi= a[j][i] / a[i][i];
                for(int k = i; k <= 2 * n; k++) {
                    a[j][k] -= a[i][k] * multi;
                }
            }
        }
        //通过初等行变化(即减去第i行的某个倍数),我们将除第i行的第i列都变成0;
    }
}

void work() {
    memset(ans, 0, sizeof(ans));
    memset(cpu, 0, sizeof(cpu));
    memset(a, 0, sizeof(a));
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= n; j++) {
            for(int k = 1; k <= m; k++) {
                a[i][j] += x[k][i] * x[k][j];
            }
        }
    }
    for(int i = 1; i <= n; i++) {
        a[i][i + n] = 1;
    }

    gauss();

    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= n; j++) {
            temp[i][j] = a[i][j + n];
        }
    }

    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            for(int k = 1; k <= n; k++) {
                cpu[i][j] += temp[i][k] * x[j][k];
            }
        }
    }

    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            ans[i] += cpu[i][j] * y[j];
        }
    }

    return ;
}

void write() {
    printf("拟合出的直线为:");
    printf("y = %0.1lf", ans[1]);
    for(int i = 2; i <= n; i++) {
        if(ans[i] >= 0)printf(" + %0.1f * x%d", ans[i], i);
        else printf(" - %0.1f * x%d", -ans[i], i);
    }
    printf("\n");
}
int main() {

    read();
    work();
    write();
    return 0;
}  

效果也是极为不错的:

ZG8.png

结语

通过这篇BLOG,相信你已经掌握了机器学习的第二个算法正规方程法,之后我们还会继续学习更多更高效的算法。最后希望你喜欢这篇BLOG!

Last modification:March 4th, 2021 at 03:43 am
If you think my article is useful to you, please feel free to appreciate