网格搜索
是一种模型调优的方法,通常与交叉验证
搭配使用。我们先讨论交叉验证
。
交叉验证
交叉验证:为了让被评估的模型更加准确可信。
根据之前的知识,我们知道我们会把数据集分为训练集和测试集。
现在把测试集排除,我们对剩下的训练集再进行划分为训练集和验证集。
通过新的训练集和验证集,我们也可以得到一个准确率。
例如:
把数据分成4等份,这个也被成为4折交叉验证。
第一份 | 第二份 | 第三份 | 第四份 | 结果 |
---|---|---|---|---|
验证集 | 训练集 | 训练集 | 训练集 | 当前情况下,模型的准确率 |
训练集 | 验证集 | 训练集 | 训练集 | 当前情况下,模型的准确率 |
训练集 | 训练集 | 验证集 | 训练集 | 当前情况下,模型的准确率 |
训练集 | 训练集 | 训练集 | 验证集 | 当前情况下,模型的准确率 |
在每一种的训练集和验证集下,我们都可以一个得到模型的准确率,对准确率求平均。
通过这种方法让模型的评估结果更加准确可信。
我们在《深度学习初步及其Python实现:5.过拟合》会再次讨论交叉验证,那时候的交叉验证是为了帮助我们检查到底是欠拟合还是过拟合。在这里,交叉验证的作用多次实验求平均,以帮助我们确定模型最佳的参数。
网格搜索的方法
通常情况下,有些参数是需要指定的。例如,上文kNN中的K。
这种参数叫做超参数
网格搜索的过程为:
- 对模型预设几组超参数组合
- 每组超参数都采用交叉验证进行评估
- 最后选出最优组合
有时候超参数不止一个,比如两个
- a:1 3 5 7 9
- b:0 2 4 6 8
则需要对两个超参数的网格搜索,即两两组合,共有中组合。
网格搜索的实现
我们以上一章的kNN为例。
1 | from sklearn.model_selection import GridSearchCV |
其中
GridSearch
代表网格搜索CV
是Cross Validation
,代表交叉验证
参数有:
- estimator:估计器对象
- param_grid:估计器参数
- dict类型数据,比如:{“n_neighbors”:[1,3,5,7,9]}
- cv:指定几折交叉验证
方法有:
- fit:输入训练数据
- score:在测试集上的准确率
返回有:
- best_score_:最好的结果
- best_estimator_:最好的参数模型
- cv_results_:每次交叉验证后测试集准确率结果和训练集准确率结果
示例代码:
1 | from sklearn.model_selection import train_test_split |
运行结果:
1 | 在测试集上的准确率 |
完整的代码已经PUSH到了我的GitHub上
https://github.com/KakaWanYifan/BaiduPower
本博客所有文章版权为文章作者所有,未经书面许可,任何机构和个人不得以任何形式转载、摘编或复制。
留言板