Python机器学习之随机梯度下降法如何实现

发布时间:2023-02-27 10:52:38 作者:iii
来源:亿速云 阅读:107

本文小编为大家详细介绍“Python机器学习之随机梯度下降法如何实现”,内容详细,步骤清晰,细节处理妥当,希望这篇“Python机器学习之随机梯度下降法如何实现”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。

随机梯度下降法

为什么使用随机梯度下降法?

如果当我们数据量和样本量非常大时,每一项都要参与到梯度下降,那么它的计算量时非常大的,所以我们可以采用随机梯度下降法。

Python机器学习之随机梯度下降法如何实现

随机梯度下降法中的学习率必须是随着循环的次数增加而递减的。如果eta取一样的话有可能在非常接近我们的最优值时会跳过,所以随着迭代次数的增加,学习率eta要随之减小,我们可以用模拟退火的思想实现(如下图所示),t0和t1是一个常数,定值,其通常是根据经验取得一些值。

Python机器学习之随机梯度下降法如何实现

随机梯度下降法的实现

随机梯度下降法的公式如下图所示,其中挑出一个样本出来计算。

Python机器学习之随机梯度下降法如何实现

先创建x,y,以下取10000个样本

import numpy as np

m = 10000

x = np.random.random(size=m)
y = x*3 + 4 + np.random.normal(size=m)

写入函数

def dj_sgd(theta, x_i, y_i): # 传入一个样本,获取对应的梯度
    return x_i.T.dot(x_i.dot(theta)-y_i)*2 # MSE

def sgd(X_b, y, initial_theta, n_iters): # 求出整个theta的函数
    def learning_rate(i_iter):
        t0 = 5
        t1 = 50
        return t0/(i_iter+t1)
    theta = initial_theta
    i_iter = 1
    
    while i_iter <= n_iters:
        index = np.random.randint(0, len(X_b))
        x_i = X_b[index]
        y_i = y[index]
        gradient = dj_sgd(theta, x_i, y_i) # 求导数
        theta = theta - gradient*learning_rate(i_iter) # 求步长
        i_iter += 1
    return theta

调用函数,求出截距和系数

Python机器学习之随机梯度下降法如何实现

以上随机梯度的缺点是不能照顾到每一点,因此需要进行改进。

以下对其中的函数进行修改。

def dj_sgd(theta, x_i, y_i): # 传入一个样本,获取对应的梯度
    return x_i.T.dot(x_i.dot(theta)-y_i)*2 # MSE

def sgd(X_b, y, initial_theta, n_iters): # 求出整个theta的函数
    def learning_rate(i_iter):
        t0 = 5
        t1 = 50
        return t0/(i_iter+t1)
    theta = initial_theta
    m = len(X_b)
    
    for cur_iter in range(n_iters): # 每一次循环都把样本打乱,n_iters的代表整个样本看几轮
        random_indexs = np.random.permutation(m)
        X_random = X_b[random_indexs]
        y_random = y[random_indexs]
        for i in range(m):
            theta = theta - learning_rate(cur_iter*m+i) * (dj_sgd(theta, X_random[i], y_random[i]))
        return theta

与前边运算结果进行对比,其耗时更长。

Python机器学习之随机梯度下降法如何实现

读到这里,这篇“Python机器学习之随机梯度下降法如何实现”文章已经介绍完毕,想要掌握这篇文章的知识点还需要大家自己动手实践使用过才能领会,如果想了解更多相关内容的文章,欢迎关注亿速云行业资讯频道。

推荐阅读:
  1. python学习感言
  2. python处理http请求中特殊字符($、#、?)的问题

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

python

上一篇:Go中的channel怎么声明和使用

下一篇:pandas中按行或列的值对数据排序如何实现

相关阅读

您好,登录后才能下订单哦!

密码登录
登录注册
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》