OpenCV python sklearn随机超参数搜索的实现

yipeiwu_com6年前Python基础

本文介绍了OpenCV python sklearn随机超参数搜索的实现,分享给大家,具体如下:

"""
房价预测数据集 使用sklearn执行超参数搜索
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import tensorflow as tf
from tensorflow_core.python.keras.api._v2 import keras # 不能使用 python
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from scipy.stats import reciprocal

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')

# 0.打印导入模块的版本
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, sklearn, pd, tf, keras:
  print("%s version:%s" % (module.__name__, module.__version__))


# 显示学习曲线
def plot_learning_curves(his):
  pd.DataFrame(his.history).plot(figsize=(8, 5))
  plt.grid(True)
  plt.gca().set_ylim(0, 1)
  plt.show()


# 1.加载数据集 california 房价
housing = fetch_california_housing()

print(housing.DESCR)
print(housing.data.shape)
print(housing.target.shape)

# 2.拆分数据集 训练集 验证集 测试集
x_train_all, x_test, y_train_all, y_test = train_test_split(
  housing.data, housing.target, random_state=7)
x_train, x_valid, y_train, y_valid = train_test_split(
  x_train_all, y_train_all, random_state=11)

print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)
print(x_test.shape, y_test.shape)

# 3.数据集归一化
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.fit_transform(x_valid)
x_test_scaled = scaler.fit_transform(x_test)


# 创建keras模型
def build_model(hidden_layers=1, # 中间层的参数
        layer_size=30,
        learning_rate=3e-3):
  # 创建网络层
  model = keras.models.Sequential()
  model.add(keras.layers.Dense(layer_size, activation="relu",
                 input_shape=x_train.shape[1:]))
 # 隐藏层设置
  for _ in range(hidden_layers - 1):
    model.add(keras.layers.Dense(layer_size,
                   activation="relu"))
  model.add(keras.layers.Dense(1))

  # 优化器学习率
  optimizer = keras.optimizers.SGD(lr=learning_rate)
  model.compile(loss="mse", optimizer=optimizer)

  return model


def main():
  # RandomizedSearchCV

  # 1.转化为sklearn的model
  sk_learn_model = keras.wrappers.scikit_learn.KerasRegressor(build_model)

  callbacks = [keras.callbacks.EarlyStopping(patience=5, min_delta=1e-2)]

  history = sk_learn_model.fit(x_train_scaled, y_train, epochs=100,
                 validation_data=(x_valid_scaled, y_valid),
                 callbacks=callbacks)
  # 2.定义超参数集合
  # f(x) = 1/(x*log(b/a)) a <= x <= b
  param_distribution = {
    "hidden_layers": [1, 2, 3, 4],
    "layer_size": np.arange(1, 100),
    "learning_rate": reciprocal(1e-4, 1e-2),
  }

  # 3.执行超搜索参数
  # cross_validation:训练集分成n份, n-1训练, 最后一份验证.
  random_search_cv = RandomizedSearchCV(sk_learn_model, param_distribution,
                     n_iter=10,
                     cv=3,
                     n_jobs=1)
  random_search_cv.fit(x_train_scaled, y_train, epochs=100,
             validation_data=(x_valid_scaled, y_valid),
             callbacks=callbacks)
  # 4.显示超参数
  print(random_search_cv.best_params_)
  print(random_search_cv.best_score_)
  print(random_search_cv.best_estimator_)

  model = random_search_cv.best_estimator_.model
  print(model.evaluate(x_test_scaled, y_test))

  # 5.打印模型训练过程
  plot_learning_curves(history)


if __name__ == '__main__':
  main()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

使用Python中的reduce()函数求积的实例

使用Python中的reduce()函数求积的实例

编写一个prod()函数,可以接受一个list并利用reduce()求积。 from functools import reduce def prod(x,y): return x...

python实现封装得到virustotal扫描结果

本文实例讲述了python实现封装得到virustotal扫描结果的方法。分享给大家供大家参考。具体方法如下: import simplejson import urllib i...

python 实现将文件或文件夹用相对路径打包为 tar.gz 文件的方法

默认情况下,tarfile 打包成的 tar.gz 文件会带绝对路径,而很多情况下,我们需要的是相对打包文件夹的路径。 代码: <pre name="code" class="...

实例解析Python设计模式编程之桥接模式的运用

实例解析Python设计模式编程之桥接模式的运用

我们先来看一个例子: #encoding=utf-8 # #by panda #桥接模式 def printInfo(info): print unicode(i...

python执行shell获取硬件参数写入mysql的方法

本文实例讲述了python执行shell获取硬件参数写入mysql的方法。分享给大家供大家参考。具体分析如下: 最近要获取服务器各种参数,包括cpu、内存、磁盘、型号等信息。试用了Hyp...