首页 > 笔记 > 高斯消元总结

高斯消元总结

最近在做概率dp的题,用到了这种东西。

高斯消元,听起来非常高大上。其实我们小学都用过。

比如我们现在有一堆方程。

$$
\begin{align*}
6x+8y+6z &= 10 \\
4x+7y+9z &= 10 \\
4x+6y+7z &= 4 \\
\end{align*}
$$

该如何解呢?我们的基本思路就是先通过某个式子同时乘某个数,然后减去另一个式子,就减少了一个未知数。

然而我们发现不需要把所有的未知数全部消掉。比如我们可以第一行不变,第二行只消x,第三行消x和y,这样子从第三行解出z,带回到第二行,解出y,再带回第一个解出x。这样就会消出一个斜的三角形矩阵。效率是$O(n^3)$的。

高斯消元就是这个思路,它先把系数写成一个矩阵:

\begin{bmatrix}
6 & 8 & 6 & 10\\
4 & 7 & 9 & 10\\
4 & 6 & 7 & 4\\
\end{bmatrix}

现在我们要消除x。

第二行乘$3 \over 2$再减第一行

\begin{bmatrix}
6 & 8 & 6 & 10\\
0 & 2.5 & 7.5 & 5\\
4 & 6 & 7 & 4\\
\end{bmatrix}

第三行乘$3 \over 2$再减第一行

\begin{bmatrix}
6 & 8 & 6 & 10\\
0 & 2.5 & 7.5 & 5\\
0 & 1 & 4.5 & -4\\
\end{bmatrix}

现在消y。

第三行乘$5 \over 2$再减第二行

\begin{bmatrix}
6 & 8 & 6 & 10\\
0 & 2.5 & 7.5 & 5\\
0 & 0 & 3.75 & -15\\
\end{bmatrix}

这样我们就得出了$z=-4$,带回第二行得到

$$2.5y=35$$

得到$y=14$,带回第一行:

$$6x=-78$$

所以:

$$
\begin{align*}
x &= -13 \\
y &= 14 \\
z &= -4 \\
\end{align*}
$$

还有一点要注意,假如有一行当前的值为0,比如

\begin{bmatrix}
0 & 8 & 6 & 10\\
4 & 7 & 9 & 10\\
4 & 6 & 7 & 4\\
\end{bmatrix}

我们就需要把这行和下面的某行换下,要不就会无解。

在实际操作中,其实是找到下面当前这列值最大的换上来。这样能避免无解,也可以减少精度误差。

代码:

#include <cstdio>
#include <algorithm>
using namespace std;
double f[92][92],ans[92];
int n;
const double eps=1e-12;
int dcmp(double x)
{
    if (x<=eps&&x>=-eps) return 0;
    return (x>0)?1:-1;
}
bool gauss()
{
    for (int i=1;i<=n;i++)
    {
        int num=i;
        for (int j=i+1;j<=n;j++)
            if (dcmp(f[j][i]-f[num][i])>0) num=j;
        if (num!=i)
            for (int j=1;j<=n+1;j++)
                swap(f[i][j],f[num][j]);
        for (int j=i+1;j<=n;j++)
            if (dcmp(f[j][i]))
            {
                double t=f[j][i]/f[i][i];
                for (int k=1;k<=n+1;k++)
                    f[j][k]-=t*f[i][k];
            }
    }
    for (int i=n;i>=1;i--)
    {
        if (dcmp(f[i][i])==0)
            return 0;
        for (int j=i+1;j<=n;j++)
            f[i][n+1]-=f[i][j]*ans[j];
        ans[i]=f[i][n+1]/f[i][i];
    }
    return 1;
}
main()
{
    scanf("%d",&n);
    for (int i=1;i<=n;i++)
        for (int j=1;j<=n+1;j++)
            scanf("%lf",&f[i][j]);
    if (gauss())
        for (int i=1;i<=n;i++)
            printf("%.2lf\n",ans[i]);
    else
        puts("No Solution");
}