keras 简单 lstm实例(基于one-hot编码)

简单的LSTM问题,能够预测一句话的下一个字词是什么

固定长度的句子,一个句子有3个词。

使用one-hot编码

各种引用

import keras
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import numpy as np

数据预处理

data = 'abcdefghijklmnopqrstuvwxyz'
data_set = set(data)

word_2_int = {b:a for a,b in enumerate(data_set)}
int_2_word = {a:b for a,b in enumerate(data_set)}

word_len = len(data_set)
print(word_2_int)
print(int_2_word)

一些辅助函数

def words_2_ints(words):
 ints = []
 for itmp in words:
  ints.append(word_2_int[itmp])
 return ints

print(words_2_ints('ab'))

def words_2_one_hot(words, num_classes=word_len):
 return keras.utils.to_categorical(words_2_ints(words), num_classes=num_classes)
print(words_2_one_hot('a'))
def get_one_hot_max_idx(one_hot):
 idx_ = 0
 max_ = 0
 for i in range(len(one_hot)):
  if max_ < one_hot[i]:
   max_ = one_hot[i]
   idx_ = i
 return idx_

def one_hot_2_words(one_hot):
 tmp = []
 for itmp in one_hot:
  tmp.append(int_2_word[get_one_hot_max_idx(itmp)])
 return "".join(tmp)

print( one_hot_2_words(words_2_one_hot('adhjlkw')) )

构造样本

time_step = 3 #一个句子有3个词

def genarate_data(batch_size=5, genarate_num=100):
 #genarate_num = -1 表示一直循环下去,genarate_num=1表示生成一个batch的数据,以此类推
 #这里,我也不知道数据有多少,就这么循环的生成下去吧。
 #入参batch_size 控制一个batch 有多少数据,也就是一次要yield进多少个batch_size的数据
 '''
 例如,一个batch有batch_size=5个样本,那么对于这个例子,需要yield进的数据为:
 abc->d
 bcd->e
 cde->f
 def->g
 efg->h
 然后把这些数据都转换成one-hot形式,最终数据,输入x的形式为:

 [第1个batch]
 [第2个batch]
 ...
 [第genarate_num个batch]

 每个batch的形式为:

 [第1句话(如abc)]
 [第2句话(如bcd)]
 ...
 每一句话的形式为:

 [第1个词的one-hot表示]
 [第2个词的one-hot表示]
 ...
 '''
 cnt = 0
 batch_x = []
 batch_y = []
 sample_num = 0
 while(True):
  for i in range(len(data) - time_step):
   batch_x.append(words_2_one_hot(data[i : i+time_step]))
   batch_y.append(words_2_one_hot(data[i+time_step])[0]) #这里数据加[0],是为了符合keras的输出数据格式。 因为不加[0],表示是3维的数据。 你可以自己尝试不加0,看下面的test打印出来是什么
   sample_num += 1
   #print('sample num is :', sample_num)
   if len(batch_x) == batch_size:
    yield (np.array(batch_x), np.array(batch_y))
    batch_x = []
    batch_y = []
    if genarate_num != -1:
     cnt += 1

    if cnt == genarate_num:
     return

for test in genarate_data(batch_size=3, genarate_num=1):
 print('--------x:')
 print(test[0])
 print('--------y:')
 print(test[1])

搭建模型并训练

model = Sequential()

# LSTM输出维度为 128
# input_shape控制输入数据的形态
# time_stemp表示一句话有多少个单词
# word_len 表示一个单词用多少维度表示,这里是26维

model.add(LSTM(128, input_shape=(time_step, word_len)))
model.add(Dense(word_len, activation='softmax')) #输出用一个softmax,来分类,维度就是26,预测是哪一个字母

model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

model.fit_generator(generator=genarate_data(batch_size=5, genarate_num=-1), epochs=50, steps_per_epoch=10)
#steps_per_epoch的意思是,一个epoch中,执行多少个batch
#batch_size是一个batch中,有多少个样本。
#所以,batch_size*steps_per_epoch就等于一个epoch中,训练的样本数量。(这个说法不对!再观察看看吧)
#可以将epochs设置成1,或者2,然后在genarate_data中打印样本序号,观察到样本总数。

使用训练后的模型进行预测:

result = model.predict(np.array([words_2_one_hot('bcd')]))

print(one_hot_2_words(result))

可以看到,预测结果为

e

补充知识:训练集产生的onehot编码特征如何在测试集、预测集复现

数据处理中有时要用到onehot编码,如果使用pandas自带的get_dummies方法,训练集产生的onehot编码特征会跟测试集、预测集不一样,正确的方式是使用sklearn自带的OneHotEncoder。

代码

import pandas as pd
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder(handle_unknown='ignore')
data_train=pd.DataFrame({'职业':['数据挖掘工程师','数据库开发工程师','数据分析师','数据分析师'],
     '籍贯':['福州','厦门','泉州','龙岩']})
ohe.fit(data_train)#训练规则
feature_names=ohe.get_feature_names(data_train.columns)#获取编码后的特征名
data_train_onehot=pd.DataFrame(ohe.transform(data_train).toarray(),columns=feature_names)#应用规则在训练集上

data_new=pd.DataFrame({'职业':['数据挖掘工程师','jave工程师'],
     '籍贯':['福州','莆田']})
data_new_onehot=pd.DataFrame(ohe.transform(data_new).toarray(),columns=feature_names)#应用规则在预测集上

以上这篇keras 简单 lstm实例(基于one-hot编码)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Tensorflow实现将标签变为one-hot形式

    将数据标签变为类似MNIST的one-hot编码形式 def one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None): """Returns a one-hot tensor. The locations represented by indices in `indices` take value `on_value`, while all other l

  • 对python sklearn one-hot编码详解

    one-hot编码的作用 使用one-hot编码,将离散特征的取值扩展到了欧式空间,离散特征的某个取值就对应欧式空间的某个点 将离散特征通过one-hot编码映射到欧式空间,是因为,在回归,分类,聚类等机器学习算法中,特征之间距离的计算或相似度的计算是非常重要的,而我们常用的距离或相似度的计算都是在欧式空间的相似度计算,计算余弦相似性,基于的就是欧式空间. sklearn的一个例子 from sklearn import preprocessing enc = preprocessing.One

  • pandas使用get_dummies进行one-hot编码的方法

    离散特征的编码分为两种情况: 1.离散特征的取值之间没有大小的意义,比如color:[red,blue],那么就使用one-hot编码 2.离散特征的取值有大小的意义,比如size:[X,XL,XXL],那么就使用数值的映射{X:1,XL:2,XXL:3} 使用pandas可以很方便的对离散型特征进行one-hot编码 import pandas as pd df = pd.DataFrame([ ['green', 'M', 10.1, 'class1'], ['red', 'L', 13.5

  • python对离散变量的one-hot编码方法

    我们在进行建模时,变量中经常会有一些变量为离散型变量,例如性别.这些变量我们一般无法直接放到模型中去训练模型.因此在使用之前,我们往往会对此类变量进行处理.一般是对离散变量进行one-hot编码.下面具体介绍通过python对离散变量进行one-hot的方法. 注意:这里提供两种哑编码的实现方法,pandas和sklearn.它们最大的区别是,pandas默认只处理字符串类别变量,sklearn默认只处理数值型类别变量(需要先 LabelEncoder ) ① pd.get_dummies(pr

  • keras 简单 lstm实例(基于one-hot编码)

    简单的LSTM问题,能够预测一句话的下一个字词是什么 固定长度的句子,一个句子有3个词. 使用one-hot编码 各种引用 import keras from keras.models import Sequential from keras.layers import LSTM, Dense, Dropout import numpy as np 数据预处理 data = 'abcdefghijklmnopqrstuvwxyz' data_set = set(data) word_2_int

  • Java 详解单向加密--MD5、SHA和HMAC及简单实现实例

    Java 详解单向加密--MD5.SHA和HMAC及简单实现实例 概要: MD5.SHA.HMAC这三种加密算法,可谓是非可逆加密,就是不可解密的加密方法. MD5 MD5即Message-Digest Algorithm 5(信息-摘要算法5),用于确保信息传输完整一致.MD5是输入不定长度信息,输出固定长度128-bits的算法. MD5算法具有以下特点: 1.压缩性:任意长度的数据,算出的MD5值长度都是固定的. 2.容易计算:从原数据计算出MD5值很容易. 3.抗修改性:对原数据进行任何

  • smarty简单应用实例

    本文讲述了smarty简单应用实例.分享给大家供大家参考,具体如下: <?php require 'smarty/libs/Smarty.class.php'; $smarty = new Smarty; $smarty->template_dir="smarty/templates/templates"; $smarty->compile_dir="smarty/templates/templates_c"; $smarty->config

  • C# 中SharpMap的简单使用实例详解

    本文是利用ShapMap实现GIS的简单应用的小例子,以供学习分享使用.关于SharpMap的说明,网上大多是以ShapeFile为例进行简单的说明,就连官网上的例子也不多.本文是自己参考了源代码进行整理的,主要是WinForm的例子.原理方面本文也不过多论述,主要是实例演示,需要的朋友还是以SharpMap源码进行深入研究. 什么是SharpMap ? SharpMap是一个基于.net 2.0使用C#开发的Map渲染类库,可以渲染各类GIS数据(目前支持ESRI Shape和PostGIS格

  • IOS文件的简单读写实例详解

    IOS文件的简单读写实例详解 数组(可变与不可变)和字典(可变与不可变)中元素对象的类型,必须是NSString,NSArray,NSDictionary,NSData,否则不能直接写入文件 #pragma mark---NSString的写入与读取--- //1:获取路径 NSString *docunments = [NSSearchPathForDirectoriesInDomains(NSDocumentDirectory, NSUserDomainMask, YES)firstObje

  • 关于Android高德地图的简单开发实例代码(DEMO)

    废话不多说了,直接给大家上干货了. 以下为初次接触时 ,练手的DEMO import android.app.Activity; import android.app.ProgressDialog; import android.content.ContentValues; import android.database.Cursor; import android.database.SQLException; import android.database.sqlite.SQLiteDatab

  • vue中的非父子间的通讯问题简单的实例代码

    官网上的例子好晦涩,看了一个头两个大,关于非父子间的通讯问题,经过查阅得到了下面的例子, <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8"> <title>兄弟之间的通讯问题</title> <script src="vue.js"></script> </head>

  • AJAX简单应用实例-弹出层

    function createobj() {  if (window.ActiveXObject) {          return(new ActiveXObject("Microsoft.XMLHTTP"));      }      else if (window.XMLHttpRequest) {          return(new XMLHttpRequest());      } } function personalInfo() {   var oBao=creat

  • bootstrap导航栏、下拉菜单、表单的简单应用实例解析

    制作效果图如下: 代码如下(以后做东西可以改改就能直接用): <!DOCTYPE html> <html lang="zh-cn"> <head> <meta charset="utf-8"> <meta http-equiv="X-UA-Compatible" content="IE=edge"> <meta name="viewport"

  • spring和quartz整合,并简单调用(实例讲解)

    工作中会定时任务~简单学习一下. 第0步: 工欲善其事必先利其器,首先要做的自然是导包了. 在spring配置包扫描以及在 pom导入包 spring.xml: pom.xml 1.在spring-quartz.xml(和spring.xml同一个位置)配置相关属性 xml的头部每个人都可能不一样,这个自己要用的时候注意. quartz表达式根据自己需求去写,不列举了,这里的是1秒一次的. 2.Task包下配置类 我们这边将定时任务存放到一个包中,命名为task.用spring的自动注解serv

随机推荐