Java线性回归基础代码怎么写

发布时间:2021-12-21 10:44:24 作者:iii
来源:亿速云 阅读:127

这篇文章主要介绍“Java线性回归基础代码怎么写”,在日常操作中,相信很多人在Java线性回归基础代码怎么写问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”Java线性回归基础代码怎么写”的疑惑有所帮助!接下来,请跟着小编一起来学习吧!

# Use linear model to model this data.
from sklearn.linear_model import LinearRegression
import numpy as np
lr=LinearRegression()
lr.fit(pga.distance[:,np.newaxis],pga['accuracy']) # Another way is using pga[['distance']]
theta0=lr.intercept_
theta1=lr.coef_
print(theta0)
print(theta1)
#calculating cost-function for each theta1
#计算平均累积误差
def cost(x,y,theta0,theta1):
 J=0
 for i in range(len(x)):
 mse=(x[i]*theta1+theta0-y[i])**2
 J+=mse
 return J/(2*len(x))
theta0=100
theta1s = np.linspace(-3,2,197)
costs=[]
for theta1 in theta1s:
 costs.append(cost(pga['distance'],pga['accuracy'],theta0,theta1))
plt.plot(theta1s,costs)
plt.show()
print(pga.distance)
#调整theta
def partial_cost_theta0(x,y,theta0,theta1):
 #我们的模型是线性拟合函数时:y=theta1*x + theta0,而不是sigmoid函数,当非线性时我们可以用sigmoid
 #直接多整个x series操作,省的一个一个计算,最终求sum 再平均
 h=theta1*x+theta0 
 diff=(h-y)
 partial=diff.sum()/len(diff)
 return partial
partial0=partial_cost_theta0(pga.distance,pga.accuracy,1,1)
def partial_cost_theta1(x,y,theta0,theta1):
 #我们的模型是线性拟合函数:y=theta1*x + theta0,而不是sigmoid函数,当非线性时我们可以用sigmoid
 h=theta1*x+theta0
 diff=(h-y)*x
 partial=diff.sum()/len(diff)
 return partial
partial1=partial_cost_theta1(pga.distance,pga.accuracy,0,5)
print(partial0)
print(partial1)
def gradient_descent(x,y,alpha=0.1,theta0=0,theta1=0): #设置默认参数
 #计算成本
 #调整权值
 #计算错误代价,判断是否收敛或者达到最大迭代次数
 most_iterations=1000
 convergence_thres=0.000001 
 
 c=cost(x,y,theta0,theta1)
 costs=[c]
 cost_pre=c+convergence_thres+1.0
 
 counter=0
 while( (np.abs(c-cost_pre)>convergence_thres) & (counter<most_iterations) ):
 update0=alpha*partial_cost_theta0(x,y,theta0,theta1)
 update1=alpha*partial_cost_theta1(x,y,theta0,theta1)
 
 theta0-=update0
 theta1-=update1
 cost_pre=c
 c=cost(x,y,theta0,theta1)
 costs.append(c)
 counter+=1
 return {'theta0': theta0, 'theta1': theta1, "costs": costs}
print("Theta1 =", gradient_descent(pga.distance, pga.accuracy)['theta1'])
costs=gradient_descent(pga.distance,pga.accuracy,alpha=.01)['cost']
print(gradient_descent(pga.distance, pga.accuracy,alpha=.01)['theta1'])
plt.scatter(range(len(costs)),costs)
plt.show()
预览

数据集 :

复制下面数据,保存为: pga.csv

distance,accuracy
290.3,59.5
302.1,54.7
287.1,62.4
282.7,65.4
299.1,52.8
300.2,51.1
300.9,58.3
279.5,73.9
287.8,67.6
284.7,67.2
296.7,60
283.3,59.4
284,72.2
292,62.1
282.6,66.5
287.9,60.9
279.2,67.3
291.7,64.8
289.9,58.1
289.8,61.7
298.8,56.4
280.8,60.5
294.9,57.5
287.5,61.8
282.7,56
277.7,72.5
270.5,71.7
285.2,66
315.1,55.2
281.9,67.6
293.3,58.2
286,59.9
285.6,58.2
289.9,65.7
277.5,59
293.6,56.8
301.1,65.4
300.8,63.4
287.4,67.3
281.8,72.6
277.4,63.1
279.1,66.5
287.4,66.4
280.9,62.3
287.8,57.2
261.4,69.2
272.6,69.4
291.3,65.3
294.2,52.8
285.5,49
287.9,61.1
282.2,65.6
301.3,58.2
276.2,61.7
281.6,68.1
275.5,61.2
309.7,53.1
287.7,56.4
291.6,56.9
284.1,65
299.6,57.5
282.7,60
271.5,72
292.1,58.2
295,59.4
274.9,69
273.6,68.7
299.9,60.1
279.9,74
289.9,66
283.6,59.8
310.3,52.4
291.7,65.6
284.2,63.2
295,53.5
298.6,55.1
297.4,60.4
299.7,67.7
284.4,69.7
286.4,72.4
285.9,66.9
297.6,54.3
272.5,62
277,66.2
287.6,60.9
280.4,69.4
280,63.7
295.4,52.8
274.4,68.8
286.5,73.1
287.7,65.2
291.5,65.9
279,69.4
299,65.2
290.1,69.1
288.9,67.9
288.8,68.2
283.2,61
293.2,58.4
285.3,67.3
284.1,65.7
281.4,67.7
286.1,61.4
284.9,62.3
284.8,68.1
296,62
282.9,71.8
280.9,67.8
291.2,62
292.8,62.2
291,61.9
285.7,62.4
283.9,62.9
298.4,61.5
285.1,65.3
286.1,60.1
283.1,65.4
289.4,58.3
284.6,70.7
296.6,62.3
295.9,64.9
295.2,62.8
293.9,54.5
275,65.5
286.8,69.5
291.1,64.4
284.8,62.5
283.7,59.5
295.4,66.9
291.8,62.7
274.9,72.3
302.9,61.2
272.1,80.4
274.9,74.9
296.3,59.4
286.2,58.8
294.2,63.3
284.1,66.5
299.2,62.4
275.4,71
273.2,70.9
281.6,65.9
295.7,55.3
287.1,56.8
287.7,66.9
296.7,53.7
282.2,64.2
291.7,65.6
281.6,73.4
311,56.2
278.6,64.7
288,65.7
276.7,72.1
292,62
286.4,69.9
292.7,65.7
294.2,62.9
278.6,59.6
283.1,69.2
284.1,66
278.6,73.6
291.1,60.4
294.6,59.4
274.3,70.5
274,57.1
283.8,62.7
272.7,66.9
303.2,58.3
282,70.4
281.9,61
287,59.9
293.5,63.8
283.6,56.3
296.9,55.3
290.9,58.2
303,58.1
292.8,61.1
281.1,65
293,61.1
284,66.5
279.8,66.7
292.9,65.4
284,66.9
282,64.5
280.6,64
287.7,63.4
287.7,63.4
298.3,59.5
299.6,53.4
291.3,62.5
295.2,61.4
288,62.4
297.8,59.5
286,62.6
285.3,66.2
286.9,63.4
275.1,73.7

到此,关于“Java线性回归基础代码怎么写”的学习就结束了,希望能够解决大家的疑惑。理论与实践的搭配能更好的帮助大家学习,快去试试吧!若想继续学习更多相关知识,请继续关注亿速云网站,小编会继续努力为大家带来更多实用的文章!

推荐阅读:
  1. 关于一些基础代码的实现
  2. java输入语句怎么写

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

java

上一篇:基于Arrays.sort()和lambda表达式如何实现

下一篇:java如何将list字符串用逗号隔开拼接字符串

相关阅读

您好,登录后才能下订单哦!

密码登录
登录注册
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》