Tensorflow中CNN入门的手写数字识别是怎样的

发布时间:2021-12-23 15:39:10 作者:柒染
来源:亿速云 阅读:177

Tensorflow中CNN入门的手写数字识别是怎样的

引言

卷积神经网络(Convolutional Neural Networks, CNN)是深度学习领域中最重要和广泛应用的模型之一,尤其在图像识别任务中表现卓越。TensorFlow 是一个强大的开源机器学习框架,支持构建和训练各种深度学习模型。本文将详细介绍如何使用 TensorFlow 实现一个简单的卷积神经网络(CNN)来进行手写数字识别任务。

1. 环境准备

在开始之前,确保你已经安装了 TensorFlow。如果还没有安装,可以通过以下命令进行安装:

pip install tensorflow

此外,我们还需要一些辅助库,如 NumPy 和 Matplotlib,用于数据处理和可视化:

pip install numpy matplotlib

2. 数据集介绍

我们将使用经典的 MNIST 数据集,它包含 60,000 张训练图像和 10,000 张测试图像,每张图像都是 28x28 像素的灰度图像,表示手写数字 0 到 9。

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

# 归一化像素值到 [0, 1] 范围
train_images, test_images = train_images / 255.0, test_images / 255.0

3. 数据预处理

在将数据输入到 CNN 之前,我们需要对其进行一些预处理。首先,我们需要将图像数据从 (28, 28) 的形状调整为 (28, 28, 1),以便与 CNN 的输入格式兼容。

# 调整图像形状
train_images = train_images[..., tf.newaxis]
test_images = test_images[..., tf.newaxis]

print("训练图像形状:", train_images.shape)
print("测试图像形状:", test_images.shape)

4. 构建 CNN 模型

接下来,我们将构建一个简单的 CNN 模型。这个模型将包含两个卷积层、两个池化层和一个全连接层。

model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10)
])

model.summary()

4.1 模型结构解释

5. 编译模型

在训练模型之前,我们需要编译模型,指定损失函数、优化器和评估指标。

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

6. 训练模型

现在,我们可以开始训练模型了。我们将使用训练数据集进行训练,并在测试数据集上进行验证。

history = model.fit(train_images, train_labels, epochs=5, 
                    validation_data=(test_images, test_labels))

6.1 训练过程可视化

为了更直观地了解训练过程,我们可以绘制训练和验证的准确率和损失曲线。

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()

7. 模型评估

训练完成后,我们可以使用测试数据集来评估模型的性能。

test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"\nTest accuracy: {test_acc}")

8. 预测与可视化

最后,我们可以使用训练好的模型对测试图像进行预测,并可视化预测结果。

predictions = model.predict(test_images)

def plot_image(i, predictions_array, true_label, img):
    true_label, img = true_label[i], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(img[..., 0], cmap=plt.cm.binary)

    predicted_label = tf.argmax(predictions_array[i])
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'

    plt.xlabel(f"{predicted_label} ({100 * tf.reduce_max(predictions_array[i]):.2f}%)", color=color)

def plot_value_array(i, predictions_array, true_label):
    true_label = true_label[i]
    plt.grid(False)
    plt.xticks(range(10))
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array[i], color="#777777")
    plt.ylim([0, 1])
    predicted_label = tf.argmax(predictions_array[i])

    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')

# 绘制前 5 个测试图像的预测结果
num_rows = 5
num_cols = 2
num_images = num_rows * num_cols
plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
    plot_image(i, predictions, test_labels, test_images)
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
    plot_value_array(i, predictions, test_labels)
plt.tight_layout()
plt.show()

9. 总结

通过本文,我们学习了如何使用 TensorFlow 构建一个简单的卷积神经网络(CNN)来进行手写数字识别。我们从数据预处理、模型构建、编译、训练、评估到预测与可视化,一步步完成了整个流程。希望这篇文章能帮助你入门 CNN 和 TensorFlow,并为你在深度学习领域的进一步探索打下坚实的基础。

10. 进一步学习

如果你对 CNN 和 TensorFlow 感兴趣,可以进一步学习以下内容:

11. 参考资源


通过本文的学习,你应该已经掌握了如何使用 TensorFlow 构建和训练一个简单的 CNN 模型来进行手写数字识别。希望你能在此基础上继续探索深度学习的更多可能性!

推荐阅读:
  1. Python+Tensorflow+CNN实现车牌识别
  2. TensorFlow实现CNN的方法

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

tensorflow cnn

上一篇:为什么会有TensorFlow

下一篇:mysql中出现1053错误怎么办

相关阅读

您好,登录后才能下订单哦!

密码登录
登录注册
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》