keras回调函数如何使用

发布时间:2023-03-13 14:12:49 作者:iii
来源:亿速云 阅读:108

这篇文章主要介绍了keras回调函数如何使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇keras回调函数如何使用文章都会有所收获,下面我们一起来看看吧。

回调函数

fit()方法中使用callbacks参数

# 这里有两个callback函数:早停和模型检查点
callbacks_list=[
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy",#监控指标
        patience=2 #两轮内不再改善中断训练
    ),
    keras.callbacks.ModelCheckpoint(
        filepath="checkpoint_path",
        monitor="val_loss",
        save_best_only=True
    )
]
#模型获取
model=get_minist_model()
model.compile(optimizer="rmsprop",
             loss="sparse_categorical_crossentropy",
             metrics=["accuracy"])

model.fit(train_images,train_labels,
         epochs=10,callbacks=callbacks_list, #该参数使用回调函数
         validation_data=(val_images,val_labels))

test_metrics=model.evaluate(test_images,test_labels)#计算模型在新数据上的损失和指标
predictions=model.predict(test_images)#计算模型在新数据上的分类概率

keras回调函数如何使用

模型的保存和加载

#也可以在训练完成后手动保存模型,只需调用model.save('my_checkpoint_path')。
#重新加载模型
model_new=keras.models.load_model("checkpoint_path.keras")

通过对Callback类子类化来创建自定义回调函数

on_epoch_begin(epoch, logs) ←----在每轮开始时被调用
on_epoch_end(epoch, logs) ←----在每轮结束时被调用
on_batch_begin(batch, logs) ←----在处理每个批量之前被调用
on_batch_end(batch, logs) ←----在处理每个批量之后被调用
on_train_begin(logs) ←----在训练开始时被调用
on_train_end(logs ←----在训练结束时被调用

from matplotlib import pyplot as plt
# 实现记录每一轮中每个batch训练后的损失,并为每个epoch绘制一个图
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs):
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs):
        self.per_batch_losses.append(logs.get("loss"))

    def on_epoch_end(self, epoch, logs):
        plt.clf()
        plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
                 label="Training loss for each batch")
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f"plot_at_epoch_{epoch}")
        self.per_batch_losses = [] #清空,方便下一轮的技术
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(train_images, train_labels,
          epochs=10,
          callbacks=[LossHistory()],
          validation_data=(val_images, val_labels))

keras回调函数如何使用

【其他】模型的定义 和 数据加载

def get_minist_model():
    inputs=keras.Input(shape=(28*28,))
    features=layers.Dense(512,activation="relu")(inputs)
    features=layers.Dropout(0.5)(features)
    outputs=layers.Dense(10,activation="softmax")(features)
    model=keras.Model(inputs,outputs)
    return model
    
#datset
from tensorflow.keras.datasets import mnist
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
train_images=train_images.reshape((60000,28*28)).astype("float32")/255
test_images=test_images.reshape((10000,28*28)).astype("float32")/255
train_images,val_images=train_images[10000:],train_images[:10000]
train_labels,val_labels=train_labels[10000:],train_labels[:10000]

关于“keras回调函数如何使用”这篇文章的内容就介绍到这里,感谢各位的阅读!相信大家对“keras回调函数如何使用”知识都有一定的了解,大家如果还想学习更多知识,欢迎关注亿速云行业资讯频道。

推荐阅读:
  1. 基于keras中训练数据的几种方式对比有什么不同
  2. keras中get_value运行越来越慢怎么办

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

keras

上一篇:VS Code中怎么安装运行、编写C语言程序

下一篇:docker容器因报错无法启动问题怎么检查及修复

相关阅读

您好,登录后才能下订单哦!

密码登录
登录注册
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》