Scikit-learn

Scikit-learn中怎么实现网格搜索

小亿
104
2024-05-10 17:18:55
栏目: 编程语言

在Scikit-learn中,可以使用GridSearchCV类实现网格搜索。GridSearchCV类可以用来选择最优的参数组合,从而优化模型的性能。

下面是一个简单的示例代码,演示如何使用GridSearchCV进行网格搜索:

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.datasets import load_iris

# 加载数据
iris = load_iris()
X = iris.data
y = iris.target

# 定义要搜索的参数网格
param_grid = {'C': [0.1, 1, 10, 100], 'gamma': [0.001, 0.01, 0.1, 1]}

# 创建模型
svm = SVC()

# 创建GridSearchCV对象
grid_search = GridSearchCV(svm, param_grid, cv=5)

# 进行网格搜索
grid_search.fit(X, y)

# 输出最优参数组合和对应的评分
print("Best parameters: {}".format(grid_search.best_params_))
print("Best cross-validation score: {:.2f}".format(grid_search.best_score_))

在上面的代码中,首先加载了Iris数据集,并定义了要搜索的参数网格。然后创建了一个SVC模型,并使用GridSearchCV类进行网格搜索。最后输出了最优的参数组合和对应的评分。

通过使用GridSearchCV类,可以方便地进行参数调优,从而提高模型的性能。

0
看了该问题的人还看了