手把手教你使用TensorFlow2实现RNN

目录
  • 概述
  • 权重共享
  • 计算过程:
  • 案例
    • 数据集
    • RNN 层
    • 获取数据
  • 完整代码

概述

RNN (Recurrent Netural Network) 是用于处理序列数据的神经网络. 所谓序列数据, 即前面的输入和后面的输入有一定的联系.

权重共享

传统神经网络:

RNN:

RNN 的权重共享和 CNN 的权重共享类似, 不同时刻共享一个权重, 大大减少了参数数量.

计算过程:

计算状态 (State)

计算输出:

案例

数据集

IBIM 数据集包含了来自互联网的 50000 条关于电影的评论, 分为正面评价和负面评价.

RNN 层

class RNN(tf.keras.Model):

    def __init__(self, units):
        super(RNN, self).__init__()

        # 初始化 [b, 64] (b 表示 batch_size)
        self.state0 = [tf.zeros([batch_size, units])]
        self.state1 = [tf.zeros([batch_size, units])]

        # [b, 80] => [b, 80, 100]
        self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)

        self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
        self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)

        # [b, 80, 100] => [b, 64] => [b, 1]
        self.out_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        """

        :param inputs: [b, 80]
        :param training:
        :return:
        """

        state0 = self.state0
        state1 = self.state1

        x = self.embedding(inputs)

        for word in tf.unstack(x, axis=1):
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # [b, 64] -> [b, 1]
        x = self.out_layer(out1)

        prob = tf.sigmoid(x)

        return prob

获取数据

def get_data():
    # 获取数据
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)

    # 更改句子长度
    X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
    X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)

    # 调试输出
    print(X_train.shape, y_train.shape)  # (25000, 80) (25000,)
    print(X_test.shape, y_test.shape)  # (25000, 80) (25000,)

    # 分割训练集
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)

    # 分割测试集
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
    test_db = test_db.batch(batch_size, drop_remainder=True)

    return train_db, test_db

完整代码

import tensorflow as tf

class RNN(tf.keras.Model):

    def __init__(self, units):
        super(RNN, self).__init__()

        # 初始化 [b, 64]
        self.state0 = [tf.zeros([batch_size, units])]
        self.state1 = [tf.zeros([batch_size, units])]

        # [b, 80] => [b, 80, 100]
        self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)

        self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
        self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)

        # [b, 80, 100] => [b, 64] => [b, 1]
        self.out_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        """

        :param inputs: [b, 80]
        :param training:
        :return:
        """

        state0 = self.state0
        state1 = self.state1

        x = self.embedding(inputs)

        for word in tf.unstack(x, axis=1):
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # [b, 64] -> [b, 1]
        x = self.out_layer(out1)

        prob = tf.sigmoid(x)

        return prob

# 超参数
total_words = 10000  # 文字数量
max_review_len = 80  # 句子长度
embedding_len = 100  # 词维度
batch_size = 1024  # 一次训练的样本数目
learning_rate = 0.0001  # 学习率
iteration_num = 20  # 迭代次数
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 优化器
loss = tf.losses.BinaryCrossentropy(from_logits=True)  # 损失
model = RNN(64)

# 调试输出summary
model.build(input_shape=[None, 64])
print(model.summary())

# 组合
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])

def get_data():
    # 获取数据
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)

    # 更改句子长度
    X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
    X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)

    # 调试输出
    print(X_train.shape, y_train.shape)  # (25000, 80) (25000,)
    print(X_test.shape, y_test.shape)  # (25000, 80) (25000,)

    # 分割训练集
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)

    # 分割测试集
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
    test_db = test_db.batch(batch_size, drop_remainder=True)

    return train_db, test_db

if __name__ == "__main__":
    # 获取分割的数据集
    train_db, test_db = get_data()

    # 拟合
    model.fit(train_db, epochs=iteration_num, validation_data=test_db, validation_freq=1)

输出结果:

Model: "rnn"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) multiple 1000000
_________________________________________________________________
simple_rnn_cell (SimpleRNNCe multiple 10560
_________________________________________________________________
simple_rnn_cell_1 (SimpleRNN multiple 8256
_________________________________________________________________
dense (Dense) multiple 65
=================================================================
Total params: 1,018,881
Trainable params: 1,018,881
Non-trainable params: 0
_________________________________________________________________
None

(25000, 80) (25000,)
(25000, 80) (25000,)
Epoch 1/20
2021-07-10 17:59:45.150639: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
24/24 [==============================] - 12s 294ms/step - loss: 0.7113 - accuracy: 0.5033 - val_loss: 0.6968 - val_accuracy: 0.4994
Epoch 2/20
24/24 [==============================] - 7s 292ms/step - loss: 0.6951 - accuracy: 0.5005 - val_loss: 0.6939 - val_accuracy: 0.4994
Epoch 3/20
24/24 [==============================] - 7s 297ms/step - loss: 0.6937 - accuracy: 0.5000 - val_loss: 0.6935 - val_accuracy: 0.4994
Epoch 4/20
24/24 [==============================] - 8s 316ms/step - loss: 0.6934 - accuracy: 0.5001 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 5/20
24/24 [==============================] - 7s 301ms/step - loss: 0.6934 - accuracy: 0.4996 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 6/20
24/24 [==============================] - 8s 334ms/step - loss: 0.6932 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 7/20
24/24 [==============================] - 10s 398ms/step - loss: 0.6931 - accuracy: 0.5006 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 8/20
24/24 [==============================] - 9s 382ms/step - loss: 0.6930 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.4994
Epoch 9/20
24/24 [==============================] - 8s 322ms/step - loss: 0.6924 - accuracy: 0.4995 - val_loss: 0.6913 - val_accuracy: 0.5240
Epoch 10/20
24/24 [==============================] - 8s 321ms/step - loss: 0.6812 - accuracy: 0.5501 - val_loss: 0.6655 - val_accuracy: 0.5767
Epoch 11/20
24/24 [==============================] - 8s 318ms/step - loss: 0.6381 - accuracy: 0.6896 - val_loss: 0.6235 - val_accuracy: 0.7399
Epoch 12/20
24/24 [==============================] - 8s 323ms/step - loss: 0.6088 - accuracy: 0.7655 - val_loss: 0.6110 - val_accuracy: 0.7533
Epoch 13/20
24/24 [==============================] - 8s 321ms/step - loss: 0.5949 - accuracy: 0.7956 - val_loss: 0.6111 - val_accuracy: 0.7878
Epoch 14/20
24/24 [==============================] - 8s 324ms/step - loss: 0.5859 - accuracy: 0.8142 - val_loss: 0.5993 - val_accuracy: 0.7904
Epoch 15/20
24/24 [==============================] - 8s 330ms/step - loss: 0.5791 - accuracy: 0.8318 - val_loss: 0.5961 - val_accuracy: 0.7907
Epoch 16/20
24/24 [==============================] - 8s 340ms/step - loss: 0.5739 - accuracy: 0.8421 - val_loss: 0.5942 - val_accuracy: 0.7961
Epoch 17/20
24/24 [==============================] - 9s 378ms/step - loss: 0.5701 - accuracy: 0.8497 - val_loss: 0.5933 - val_accuracy: 0.8014
Epoch 18/20
24/24 [==============================] - 9s 361ms/step - loss: 0.5665 - accuracy: 0.8589 - val_loss: 0.5958 - val_accuracy: 0.8082
Epoch 19/20
24/24 [==============================] - 8s 353ms/step - loss: 0.5630 - accuracy: 0.8681 - val_loss: 0.5931 - val_accuracy: 0.7966
Epoch 20/20
24/24 [==============================] - 8s 314ms/step - loss: 0.5614 - accuracy: 0.8702 - val_loss: 0.5925 - val_accuracy: 0.7959

Process finished with exit code 0

到此这篇关于手把手教你使用TensorFlow2实现RNN的文章就介绍到这了,更多相关TensorFlow2实现RNN内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • TensorFlow实现RNN循环神经网络

    RNN(recurrent neural Network)循环神经网络 主要用于自然语言处理(nature language processing,NLP) RNN主要用途是处理和预测序列数据 RNN广泛的用于 语音识别.语言模型.机器翻译 RNN的来源就是为了刻画一个序列当前的输出与之前的信息影响后面节点的输出 RNN 是包含循环的网络,允许信息的持久化. RNN会记忆之前的信息,并利用之前的信息影响后面节点的输出. RNN的隐藏层之间的节点是有相连的,隐藏层的输入不仅仅包括输入层的输出,还包

  • 浅谈Tensorflow 动态双向RNN的输出问题

    tf.nn.bidirectional_dynamic_rnn() 函数: def bidirectional_dynamic_rnn( cell_fw, # 前向RNN cell_bw, # 后向RNN inputs, # 输入 sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度) initial_state_fw=None, # 前向的初始化状态(可选) initial_state_bw=None, # 后向的初始化状态(可选) dtype=No

  • Tensorflow与RNN、双向LSTM等的踩坑记录及解决

    1.tensorflow(不定长)文本序列读取与解析 tensorflow读取csv时需要指定各列的数据类型. 但是对于RNN这种接受序列输入的模型来说,一条序列的长度是不固定.这时如果使用csv存储序列数据,应当首先将特征序列拼接成一列. 例如两条数据序列,第一项是标签,之后是特征序列 [0, 1.1, 1.2, 2.3] 转换成 [0, '1.1_1.2_2.3'] [1, 1.0, 2.5, 1.6, 3.2, 4.5] 转换成 [1, '1.0_2.5_1.6_3.2_4.5'] 这样每

  • 手把手教你使用TensorFlow2实现RNN

    目录 概述 权重共享 计算过程: 案例 数据集 RNN 层 获取数据 完整代码 概述 RNN (Recurrent Netural Network) 是用于处理序列数据的神经网络. 所谓序列数据, 即前面的输入和后面的输入有一定的联系. 权重共享 传统神经网络: RNN: RNN 的权重共享和 CNN 的权重共享类似, 不同时刻共享一个权重, 大大减少了参数数量. 计算过程: 计算状态 (State) 计算输出: 案例 数据集 IBIM 数据集包含了来自互联网的 50000 条关于电影的评论,

  • 手把手教你用Hexo+Github搭建属于自己的博客(详细图文)

    在大三的时候,一直就想搭建属于自己的一个博客,但由于各种原因,最终都不了了之,恰好最近比较有空,于是就自己参照网上的教程,搭建了属于自己的博客. 至于为什么要搭建自己的博客了? 哈哈,大概是为了装逼吧,同时自己搭建博客的话,样式的选择也比较自由,可以自己选择,不需要受限于各大平台. 转载请注明原博客地址:手把手教你用Hexo+Github 搭建属于自己的博客 大概可以分为以下几个步骤 搭建环境准备(包括node.js和git环境,gitHub账户的配置) 安装Hexo 配置Hexo 怎样将Hex

  • 手把手教你使用 virtualBox 让虚拟机连接网络的教程

    1 设置 virtualBox 打开设置->网络 采用桥接模式连接网络,并选择对应的物理网卡. 2 设置虚拟机(centos7) 1.使用 nmcli 命令,查看当前虚拟机的所有网络基本信息: nmcli connection show 具体参数说明如下: 参数名称 说明 NAME 连网代号,通常与 DEVICE 一样 UUID 识别码 TYPE 网卡的类型:802-3-ethernet 就是以太网 DEVICE 网卡名称 * 这里的 enp0s3 是 centos7 自动生成的带随机数的网卡名

  • Android消息推送:手把手教你集成小米推送(附demo)

    前言 在Android开发中,消息推送功能的使用非常常见. 为了降低开发成本,使用第三方推送是现今较为流行的解决方案. 今天,我将手把手教大家如何在你的应用里集成小米推送 目录 1. 官方Demo解析 首先,我们先对小米官方的推送Demo进行解析. 请先到官网下载官方Demo和SDK说明文档 1.1 Demo概况 目录说明: DemoApplication类 继承自Application类,其作用主要是:设置App的ID & Key.注册推送服务 DemoMessageReceiver类 继承自

  • 比较详细的手把手教你写批处理(willsort题注版)第1/5页

    另,建议Climbing兄取文不用拘泥于国内,此类技术文章,内外水平相差极大:与其修正国内只言片语,不如翻译国外优秀著述. -------------------------------------------------------- 标题:手把手教你写批处理-批处理的介绍 作者:佚名 编者:Climbing 题注:willsort 日期:2004-09-21 -------------------------------------------------------- 批处理的介绍 扩展名

  • 手把手教你配置一台Linux虚拟机

    手把手教你配置一台Linux虚拟机 前言: Linux distribution 越来越多,也越来越成熟,所以安装起来也是比较简单,但是要理解安装的每一个步骤还是需要对Linux的基础知识有一定的了解,不过不用很深入,如果很深入我也不会.这里我选择的安装方式都是最简单的,在磁盘分区最重要的步骤也是以最简单的方式分区. 本次Linux配置的目的不是作为商业用途,而是在于新手熟悉Linux的操作系统,使新手能自己在本地配置Linux系统. 选择distrubution版本,因为我们是把Linux作为

  • 手把手教你用python抢票回家过年(代码简单)

    首先看看如何快速查看剩余火车票? 当你想查询一下火车票信息的时候,你还在上12306官网吗?或是打开你手机里的APP?下面让我们来用Python写一个命令行版的火车票查看器, 只要在命令行敲一行命令就能获得你想要的火车票信息!如果你刚掌握了Python基础,这将是个不错的小练习. 接口设计 一个应用写出来最终是要给人使用的,哪怕只是给你自己使用.所以,首先应该想想你希望怎么使用它?让我们先给这个小应用起个名字吧,既然及查询票务信息,那就叫它tickets好了.我们希望用户只要输入出发站,到达站以

  • C语言手把手教你实现贪吃蛇AI(下)

    本文实例为大家分享了C语言实现贪吃蛇AI的具体代码,供大家参考,具体内容如下 1. 目标 这一部分的目标是把之前写的贪吃蛇加入AI功能,即自动的去寻找食物并吃掉. 2. 控制策略 为了保证蛇不会走入"死地",所以蛇每前进一步都需要检查,移动到新的位置后,能否找到走到蛇尾的路径,如果可以,才可以走到新的位置:否则在当前的位置寻找走到蛇尾的路径,并按照路径向前走一步,开始循环之前的操作,如下图所示.这个策略可以工作,但是并不高效,也可以尝试其他的控制策略,比如易水寒的贪吃蛇AI 运行效果如

  • C语言手把手教你实现贪吃蛇AI(中)

    手把手教你实现贪吃蛇AI,具体内容如下 1. 目标 这一部分主要是讲解编写贪吃蛇AI所需要用到的算法基础. 2. 问题分析 贪吃蛇AI说白了就是寻找一条从蛇头到食物的一条最短路径,同时这条路径需要避开障碍物,这里仅有的障碍就是蛇身.而A star 算法就是专门针对这一个问题的.在A star 算法中需要用到排序算法,这里采用堆排序(当然其他排序也可以),如果对堆排序不熟悉的朋友,请移步到这里--堆排序,先看看堆排序的内容. 3. A*算法 A star(也称A*)搜寻算法俗称A星算法.这是一种在

  • C语言手把手教你实现贪吃蛇AI(上)

    本文实例为大家分享了手把手教你实现贪吃蛇AI的具体步骤,供大家参考,具体内容如下 1. 目标 编写一个贪吃蛇AI,也就是自动绕过障碍,去寻找最优路径吃食物. 2. 问题分析 为了达到这一目的,其实很容易,总共只需要两步,第一步抓一条蛇,第二步给蛇装一个脑子.具体来说就是,首先我们需要有一条普通的贪吃蛇,也就是我们常玩儿的,手动控制去吃食物的贪吃蛇:然后给这条蛇加入AI,也就是通过算法控制,告诉蛇怎么最方便的绕开障碍去吃食物.为了讲清楚这个问题,文章将分为三部分:上,写一个贪吃蛇程序:中,算法基础

随机推荐