关于 TensorFlow 2.0 Keras 子类化（subclassing）API 的多输入多输出模型

  1.关键代码

  在模型类的 `__init__` 中定义好各层后，于 `call` 方法里以列表（数组）形式接收多个输入、返回多个输出；调用 `fit`/`evaluate` 时，x 和 y 也按相同顺序以列表形式传入。

  网络模型搭建

  class WideDeepModel(tf.keras.models.Model):
      """Wide & Deep regression model with two inputs and two outputs.

      ``call`` expects ``inputs`` to be a two-element sequence:
      ``inputs[0]`` is the "wide" input, concatenated directly with the deep
      features; ``inputs[1]`` is the "deep" input, passed through two
      30-unit ReLU ``Dense`` layers. It returns ``[output1, output2]``:
      ``output1`` is computed from the concatenation of the wide input and
      the deep features, ``output2`` from the deep features alone.
      """

      def __init__(self, **kwargs):
          # Forward keyword arguments (e.g. ``name``) to the base Model.
          super().__init__(**kwargs)
          self.hidden1_layer = tf.keras.layers.Dense(30, activation='relu')
          self.hidden2_layer = tf.keras.layers.Dense(30, activation='relu')
          # Create the concatenation layer once here instead of calling the
          # functional ``tf.keras.layers.concatenate`` helper inside ``call``,
          # which instantiates a brand-new layer on every forward pass.
          self.concat_layer = tf.keras.layers.Concatenate()
          self.output_layer1 = tf.keras.layers.Dense(1)
          self.output_layer2 = tf.keras.layers.Dense(1)

      def call(self, inputs, training=None, mask=None):
          """Run the forward pass.

          Args:
              inputs: sequence ``[input_wide, input_deep]``.
              training: unused; accepted for Keras API compatibility.
              mask: unused; accepted for Keras API compatibility.

          Returns:
              ``[output1, output2]``, each produced by a ``Dense(1)`` layer.
          """
          input_wide = inputs[0]  # input 1: fed straight into the concatenation
          input_deep = inputs[1]  # input 2: goes through the hidden layers
          hidden1 = self.hidden1_layer(input_deep)
          hidden2 = self.hidden2_layer(hidden1)
          concat = self.concat_layer([input_wide, hidden2])
          output1 = self.output_layer1(concat)  # output 1: wide + deep
          output2 = self.output_layer2(hidden2)  # output 2: deep only
          return [output1, output2]  # outputs returned as a list

  # Build the network: the wide input has 5 features, the deep input 6.
  model = WideDeepModel()
  wide_and_deep_shapes = [(None, 5), (None, 6)]
  model.build(input_shape=wide_and_deep_shapes)
  print(model.layers)
  model.summary()

  完整代码:

  import pprint

  import sys

  import matplotlib as mpl

  import matplotlib.pyplot as plt

  import numpy as np

  import pandas as pd

  import sklearn

  import tensorflow as tf

  from tensorflow import keras

  # Print the interpreter and library versions.
  print(tf.__version__)
  print(sys.version_info)
  for lib in (mpl, np, pd, sklearn, keras, tf):
      print(f"{lib.__name__} {lib.__version__}")

  from sklearn.datasets import fetch_california_housing

  # 1. Load the California housing dataset (the comment in earlier versions
  #    of this script wrongly called it the Boston housing dataset).
  housing = fetch_california_housing()

  # Inspect the dataset description, array shapes and the first 5 samples.
  print(housing.DESCR)
  print(housing.data.shape)
  print(housing.target.shape)
  pprint.pprint(housing.data[:5])
  pprint.pprint(housing.target[:5])

  from sklearn.model_selection import train_test_split

  # 2. Split the data.
  # Hold out 20% of all samples as the test set.
  x_train_all, x_test, y_train_all, y_test = train_test_split(
      housing.data, housing.target, test_size=0.20, random_state=7
  )
  # Hold out 20% of the remaining samples as the validation set.
  x_train, x_valid, y_train, y_valid = train_test_split(
      x_train_all, y_train_all, test_size=0.20, random_state=11
  )

  for features, targets in ((x_train, y_train),
                            (x_valid, y_valid),
                            (x_test, y_test)):
      print(features.shape, targets.shape)

  from sklearn.preprocessing import StandardScaler

  # 3. Preprocessing: standardize the features. The scaler is fitted on the
  #    training split only; the validation and test splits reuse its statistics.
  scaler = StandardScaler()
  x_train_scaled = scaler.fit_transform(x_train)
  x_valid_scaled, x_test_scaled = map(scaler.transform, (x_valid, x_test))

  # 4. Build the network model.
  # Subclassing API
  class WideDeepModel(tf.keras.models.Model):
      """Wide & Deep regression model with two inputs and two outputs.

      ``call`` expects ``inputs`` to be a two-element sequence:
      ``inputs[0]`` is the "wide" input, concatenated directly with the deep
      features; ``inputs[1]`` is the "deep" input, passed through two
      30-unit ReLU ``Dense`` layers. It returns ``[output1, output2]``:
      ``output1`` is computed from the concatenation of the wide input and
      the deep features, ``output2`` from the deep features alone.
      """

      def __init__(self, **kwargs):
          # Forward keyword arguments (e.g. ``name``) to the base Model.
          super().__init__(**kwargs)
          self.hidden1_layer = tf.keras.layers.Dense(30, activation='relu')
          self.hidden2_layer = tf.keras.layers.Dense(30, activation='relu')
          # Create the concatenation layer once here instead of calling the
          # functional ``tf.keras.layers.concatenate`` helper inside ``call``,
          # which instantiates a brand-new layer on every forward pass.
          self.concat_layer = tf.keras.layers.Concatenate()
          self.output_layer1 = tf.keras.layers.Dense(1)
          self.output_layer2 = tf.keras.layers.Dense(1)

      def call(self, inputs, training=None, mask=None):
          """Run the forward pass.

          Args:
              inputs: sequence ``[input_wide, input_deep]``.
              training: unused; accepted for Keras API compatibility.
              mask: unused; accepted for Keras API compatibility.

          Returns:
              ``[output1, output2]``, each produced by a ``Dense(1)`` layer.
          """
          input_wide = inputs[0]  # fed straight into the concatenation
          input_deep = inputs[1]  # goes through the hidden layers
          hidden1 = self.hidden1_layer(input_deep)
          hidden2 = self.hidden2_layer(hidden1)
          concat = self.concat_layer([input_wide, hidden2])
          output1 = self.output_layer1(concat)  # wide + deep
          output2 = self.output_layer2(hidden2)  # deep only
          return [output1, output2]

  # Build the network: the wide input has 5 features, the deep input 6.
  # (An injected advertisement/URL that had been pasted into this comment
  # has been removed.)
  model = WideDeepModel()
  model.build(input_shape=[(None, 5), (None, 6)])
  print(model.layers)
  model.summary()

  # 5. Compile the model: a single MSE loss string shared by both outputs,
  #    optimized with Adam.
  model.compile(optimizer='adam', loss='mean_squared_error')

  # 6. Callbacks: stop training early once the monitored quantity fails to
  #    improve by at least 1e-3 for 5 consecutive epochs.
  early_stopping = tf.keras.callbacks.EarlyStopping(min_delta=1e-3, patience=5)
  callbacks = [early_stopping]

  # 7. Train the network.
  def _split_wide_deep(features):
      """Return the ``(wide, deep)`` column slices of a 2-D feature array.

      The wide part is columns 0-4 and the deep part is column 2 onward; these
      widths must agree with the ``input_shape`` given to ``model.build``.
      """
      return features[:, :5], features[:, 2:]


  x_train_scaled_wide, x_train_scaled_deep = _split_wide_deep(x_train_scaled)
  x_valid_scaled_wide, x_valid_scaled_deep = _split_wide_deep(x_valid_scaled)
  x_test_scaled_wide, x_test_scaled_deep = _split_wide_deep(x_test_scaled)

  history = model.fit(
      [x_train_scaled_wide, x_train_scaled_deep],
      [y_train, y_train],  # the same target supervises both outputs
      validation_data=(
          [x_valid_scaled_wide, x_valid_scaled_deep],
          [y_valid, y_valid],
      ),
      epochs=20,
      callbacks=callbacks,
  )

  # 8. Plot the training history.
  def plot_learning_curves(hst, ylim=(0, 1)):
      """Plot every series recorded in a Keras ``History`` object.

      Args:
          hst: the object returned by ``model.fit``; each key of
              ``hst.history`` becomes one line in the plot.
          ylim: ``(bottom, top)`` y-axis limits. Defaults to ``(0, 1)``, the
              range that used to be hard-coded. Pass ``None`` to let
              matplotlib autoscale, e.g. when loss values exceed 1 and would
              otherwise be clipped out of view.
      """
      pd.DataFrame(hst.history).plot()
      plt.grid(True)
      if ylim is not None:
          plt.gca().set_ylim(*ylim)
      plt.show()


  plot_learning_curves(history)

  # 9. Evaluate on the held-out test set; both outputs are scored against y_test.
  test_inputs = [x_test_scaled_wide, x_test_scaled_deep]
  test_targets = [y_test] * 2
  model.evaluate(test_inputs, test_targets)

