python实现CTC以及案例讲解

在大多数语音识别任务中,我们都缺少文本和音频特征的alignment,Connectionist Temporal Classification作为一个损失函数,用于在序列数据上进行监督式学习,可以不需要对齐输入数据及标签。

对于输入序列 X = [ x 1 , x 2 , . . , x T ] X=[x_1, x_2, .., x_T] X=[x1​,x2​,..,xT​] 和 输出序列 Y = [ y 1 , y 2 , . . . , y U ] Y = [y_1, y_2, ..., y_U ] Y=[y1​,y2​,...,yU​],我们希望训练一个模型使条件概率 P ( Y ∣ X ) P(Y|X) P(Y∣X) 达到最大化,并且给定新的输入序列时我们希望模型可以推测出最优的输出序列, Y ∗ = a r g m a x Y   P ( Y ∣ X ) Y^*=\underset{Y}{argmax}\space P(Y|X) Y∗=Yargmax​ P(Y∣X),而CTC算法刚好可以同时做到训练和解码。

损失函数

语音识别任务中,大多数情况下都是输入序列长度大于文本序列长度,所以CTC算法的alignment方案也是基于将连续的几帧输入合并对应到某一个输出的token,即多对一,同时除了训练数据中所有的token集合,CTC还引入了一个空白token,在这里用 ϵ \epsilon ϵ 指代,他没有实际意义并且在最终输出序列中被移除,但这个token对生成alignment很有帮助。

CTC算法生成最终token输出序列步骤如下:
生成和输入序列长度相同的alignment → 合并相同token → 删除空白token → token序列

上面步骤准确来讲是解码的步骤,解码之前我们要训练模型,训练模型就需要损失函数,或者说需要一个被优化的目标函数:

以下图的普通RNN为例, p t ( a t ∣ X ) p_t(a_t|X) pt​(at​∣X) 是每一帧在token集合(含空白token)上的概率分布

通过每一帧的概率分布我们可以得到所有(有效)alignment的概率,最后所有alignment都可以对应到一个输出序列,进而也就得到所有输出序列的概率分布。我们找到所有能够合并到 label (Y)序列的 alignment,并将他们的概率分数相加,再取负对数就可以得到一对训练数据的Loss

那么对于整个数据集,可以得到目标函数 ∑ ( X , Y ) ∈ 训 练 数 据 集 − l o g   P ( Y ∣ X ) \sum_{(X,Y)\in 训练数据集}-log\space P(Y|X) ∑(X,Y)∈训练数据集​−log P(Y∣X),训练中需要将其最小化。

用暴力的方法找出所有alignment并对其概率求和效率很低,常用的算法是通过动态规划对alignment进行合并,准确来讲是一个动态规划+DFS的算法:

为了实现这个算法,先引入一个中间序列 Z = ( ϵ , y 1 , ϵ , y 2 . . . , ϵ , y U ) Z=(\epsilon,y_1,\epsilon,y_2...,\epsilon,y_U) Z=(ϵ,y1​,ϵ,y2​...,ϵ,yU​),也就是在label序列的起始,中间和终止位置插入空白token,引入这个中间序列可以说是CTC算法的精髓之一,下面我们以简单的 Y = ( a , b ) Y=(a,b) Y=(a,b) 输出序列进行说明:

中间序列 Z = ( ϵ , a , ϵ , b , ϵ ) Z=(\epsilon,a,\epsilon,b,\epsilon) Z=(ϵ,a,ϵ,b,ϵ),长度为 S S S

输入序列 X = ( x 1 , x 2 , x 3 , x 4 , x 5 , x 6 ) X=(x_1, x_2, x_3, x_4,x_5,x_6) X=(x1​,x2​,x3​,x4​,x5​,x6​),长度为 T T T

递归参数 α s , t \alpha_{s,t} αs,t​ 到 t t t 时刻为止中间序列的子序列 Z 1 : s Z_{1:s} Z1:s​获得的概率分数,也就是在 t t t时刻走到中间序列第 s s s个token时的概率分数

算法整体流程如下图所示,和原文中的图比起来加入了具体数值,理解起来更加直观,图中的红色路径表示不能进行跳转,因为如果直接从 t = 2 t=2 t=2 的第一个 ϵ \epsilon ϵ 跳到 t = 3 t=3 t=3 时刻的第3个 ϵ \epsilon ϵ,中间的token a a a 会被忽略,这样后面的路径不管怎么走都得不到正确的token序列。

其他情况下都可以接受来自上一个时刻的第 s − 2 , s − 1 , s s-2,s-1,s s−2,s−1,s个token的跳转,再对图中的节点做进一步解释,以绿色节点为例,该节点就是 α 4 , 4 \alpha_{4,4} α4,4​ (下标从1开始),表示前面不管怎么走,在 t = 4 t=4 t=4时刻落到第4个token时获得的概率分数,也就是把这个时刻能走到 b b b 的所有alignment 概率分数加起来。那么把最后一帧的2个节点的概率分数相加就是所有alignment的概率分数,即 P ( Y ∣ X ) = α S , T + α S − 1 , T P(Y|X)=\alpha_{S,T}+\alpha_{S-1, T} P(Y∣X)=αS,T​+αS−1,T​

下面直接给出dp的状态转换公式, p t ( z s ∣ X ) p_t(z_s|X) pt​(zs​∣X) 表示 t t t 时刻第 s s s 个字符的概率:

α s , t = ( α s , t − 1 + α s − 1 , t − 1 ) × p t ( z s ∣ X ) \alpha_{s,t}=(\alpha_{s,t-1}+\alpha_{s-1, t-1})\times p_t(z_s|X) αs,t​=(αs,t−1​+αs−1,t−1​)×pt​(zs​∣X), ( a , ϵ , a ) (a,\epsilon, a) (a,ϵ,a)或者 ( ϵ , a , ϵ ) (\epsilon,a,\epsilon) (ϵ,a,ϵ) 模式

α s , t = ( α s − 2 , t − 1 + α s − 1 , t − 1 + α s , t − 1 ) × p t ( z s ∣ X ) \alpha_{s,t}=(\alpha_{s-2,t-1}+\alpha_{s-1,t-1}+\alpha_{s,t-1})\times p_t(z_s|X) αs,t​=(αs−2,t−1​+αs−1,t−1​+αs,t−1​)×pt​(zs​∣X),其他情况

解码

解码问题就是已经有训练好的模型,需要通过输入序列推测出最优的token序列,实际上就是解决 Y ∗ = a r g m a x Y   P ( Y ∣ X ) Y^*=\underset{Y}{argmax}\space P(Y|X) Y∗=Yargmax​ P(Y∣X) 这个问题,那么能想到最直接的方法就是取每一帧概率分数最高的token,连接起来去掉 ϵ \epsilon ϵ 组成输出序列,也就是贪婪解码:

这样做虽然很高效但有时并不是最优解,比如几个概率分数较小的alignment序列最后都能转换为相同的token序列,那么将这些较小的alignment概率分数加起来可能会大于贪婪解码的概率分数。

常用的算法是改进版的beam search,常规的beam search是在每一帧都会保存概率分数最大的前几个路径并舍弃其他的,最后会给出最优的 b e a m beam beam 个路径,在此基础上,我们在路径搜索的过程中,需要对能映射到相同输出的alignment进行合并,合并之后再进行beam的枝剪。

和语言模型结合

CTC最明显的特点就是前后帧之间的条件独立假设

缺点:不适合包括语音识别在内的大多数seq2seq任务,上下文之间的相关性会被忽略,因此经常需要额外引入语言模型。

优点:不考虑上下文的相关性可以使模型泛化能力更强,比如如果不考虑文本之间的相关性,用于识别日常会话的声学模型可以直接用在会议内容转录的场景中。

由于语言模型分数和CTC的条件概率分数相互独立,因此最终的解码序列可以写成
Y ∗ = a r g m a x Y   P ( Y ∣ X ) × P ( Y ) α Y^*=\underset{Y}{argmax} \space P(Y|X)\times P(Y)^\alpha Y∗=Yargmax​ P(Y∣X)×P(Y)α, P ( Y ) P(Y) P(Y)表示语言模型的概率分数,可以是bigram也可以是3gram,以bigram为例的话,如果当前时刻序列是 ( a , b , c ) (a,b,c) (a,b,c),计算下一帧跳到 d d d 的概率分数时,不仅要考虑下一时刻的token概率分布,还要考虑训练文本中 ( c , d ) (c,d) (c,d) 出现的频次,即 c o u n t ( c , d ) / c o u n t ( c , ∗ ) count(c,d) / count(c,*) count(c,d)/count(c,∗),将这个概率和 d d d出现的概率相乘才是最终的概率分数, α \alpha α 是语言模型因子,需要fine tuning。

代码实现

损失函数(动态规划+DFS)
常规beam search解码
合并alignment的beam search解码
加入语言模型的 beam search解码

到此这篇关于python实现CTC以及案例讲解的文章就介绍到这了,更多相关python实现CTC内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • 使用keras框架cnn+ctc_loss识别不定长字符图片操作

    我就废话不多说了,大家还是直接看代码吧~ # -*- coding: utf-8 -*- #keras==2.0.5 #tensorflow==1.1.0 import os,sys,string import sys import logging import multiprocessing import time import json import cv2 import numpy as np from sklearn.model_selection import train_test_s

  • IOS ObjectC与javascript交互详解及实现代码

    IOS OC与js交互详解 JS注入 : 把JS代码有OC注入到网页 JS注入又叫做OC和JS的交互 OC和JS的交互需要一个桥梁(中介),这个桥梁就是UIWebView的代理方法 网页加载初始内容 #import "ViewController.h" @interface ViewController ()<UIWebViewDelegate> @property (weak, nonatomic) IBOutlet UIWebView *webView; @end -

  • Asp.net Core 3.1基于AspectCore实现AOP实现事务、缓存拦截器功能

    最近想给我的框架加一种功能,就是比如给一个方法加一个事务的特性Attribute,那这个方法就会启用事务处理.给一个方法加一个缓存特性,那这个方法就会进行缓存. 这个也是网上说的面向切面编程AOP. AOP的概念也很好理解,跟中间件差不多,说白了,就是我可以任意地在方法的前面或后面添加代码,这很适合用于缓存.日志等处理. 在net core2.2时,我当时就尝试过用autofac实现aop,但这次我不想用autofac,我用了一个更轻量级的框架,AspectCore. 用起来非常非常的简单,但一

  • 解决Keras中循环使用K.ctc_decode内存不释放的问题

    如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢. data = generator(...) model = init_model(...) for i in range(NUM): x, y = next(data) _y = model.predict(x) shape = _y.shape input_length = np.ones(shape[0]) * shape[1] ctc_decode = K.ctc_decode(_y,

  • Asp.Net Core轻量级Aop解决方案:AspectCore

    什么是AspectCore Project ? AspectCore Project 是适用于Asp.Net Core 平台的轻量级 Aop(Aspect-oriented programming) 解决方案,它更好的遵循Asp.Net Core的模块化开发理念,使用AspectCore可以更容易构建低耦合.易扩展的Web应用程序.AspectCore使用Emit实现高效的动态代理从而不依赖任何第三方Aop库. 开使使用AspectCore 启动 Visual Studio.从 File 菜单,

  • Kotlin基础教程之dataclass,objectclass,use函数,类扩展,socket

    Kotlin基础教程之dataclass,objectclass,use函数,类扩展,socket Kotlin提供了一些机制来扩展已有的类,如下: 还记得我们之前写过的Point3D类吗?(将其略作修改,将成员变量改为Double类型) 让我们为其扩展一个length函数 扩展的方法很简单,只要在函数名前面加上类名就行了. 这样Point3D的对象就有了一个名为length的方法. 运行的结果不出所料: 除此之外,在Kotlin中还有一些特殊的类,比如Data Class: 有些类只包含数据,

  • asp内置对象 ObjectContext 事务管理 详解

    asp内置对象 ObjectContext 详解 您可以使用 ObjectContext 对象提交或放弃一项由 Microsoft Transaction Server (MTS) 管理的事务,它由 ASP 页包含的脚本初始化. ASP 包含 @TRANSACTION 指令时,该页会在事务中运行,直到事务成功或失败后才会终止. 语法 ObjectContext.method 方法 SetComplete SetComplete 方法声明脚本不了解事务未完成的原因.如果事务中的所有组件都调用 Se

  • python实现CTC以及案例讲解

    在大多数语音识别任务中,我们都缺少文本和音频特征的alignment,Connectionist Temporal Classification作为一个损失函数,用于在序列数据上进行监督式学习,可以不需要对齐输入数据及标签. 对于输入序列 X = [ x 1 , x 2 , . . , x T ] X=[x_1, x_2, .., x_T] X=[x1​,x2​,..,xT​] 和 输出序列 Y = [ y 1 , y 2 , . . . , y U ] Y = [y_1, y_2, ...,

  • python代码实现备忘录案例讲解

    文件操作 TXT文件 读取txt文件 读取txt文件全部内容: def read_all(txt): ...: with open(txt,'r') as f: ...: return f.read() ...: read_all('test.txt') Out[23]: 'a,b,c,d\ne,f,g,h\ni,j,k,l\n' 按行读取txt文件内容 def read_line(txt): ...: line_list = [] ...: with open(txt,'r') as f: .

  • Python之根据输入参数计算结果案例讲解

    一.问题描述 define function,calculate the input parameters and return the result. 数据存放在 txt 里,为 10 行 10 列的矩阵. 编写一个函数,传入参数:文件路径.第一个数据行列索引.第二个数据行列索引和运算符. 返回计算结果 如果没有传入文件路径,随机生成 10*10 的值的范围在 [6, 66] 之间的随机整数数组存入 txt 以供后续读取数据和测试. 二.Python程序 导入需要的依赖库和日志输出配置 # -

  • Python进行区间取值案例讲解

    需求背景: 进行分值计算.如下图,如果只是一两个还好说,写写判断,但是如果有几十个,几百个,会不会惨不忍睹.而且,下面的还是三种情况. 例如: 解决: # 根据值.比较list, 值list,返回区间值, other_value 即不在的情况 def get_value_by_between(self, compare_value, compare_list, value_list, other_value, type="compare", left=False, right=True

  • python之json文件转xml文件案例讲解

    json文件格式 这是yolov4模型跑出来的检测结果result.json 下面是截取的一张图的检测结果 { "frame_id":1, #图片的序号 "filename":"/media/wuzhou/Gap/rgb-piglet/test/00000000.jpg", #图片的路径 "objects": [ #该图中所有的目标:目标类别.目标名称.归一化的框的坐标(xywh格式).置信度 {"class_id&

  • Python之urlencode和urldecode案例讲解

    python中的urlencode和urldecode python将字符串转化成urlencode ,或者将url编码字符串decode的方法: 方法1: urlencode:urllib中的quote方法 >>> from urllib import quote >>> quote(':') '%3A' >>> quote('http://www.baidu.com') 'http%3A//www.baidu.com' urldecode:urll

  • Python之进行URL编码案例讲解

    为什么要对URL进行encode 在写网络爬虫时,发现提交表单中的中文字符都变成了TextBox1=%B8%C5%C2%CA%C2%DB这种样子,观察这是中文对应的GB2312编码,实际上是进行了GB2312编码和urlencode. 那么为什么要对URL进行encode? 因为在标准的url规范中中文和很多的字符是不允许出现在url中的.为了字符编码(gbk.utf-8)和特殊字符不出现在url中,url转义是为了符合url的规范. 具体代码 urlencode编码:urllib中的quote

  • python之多种方式传递函数方法案例讲解

    这篇文章主要介绍了python进阶教程之函数参数的多种传递方法,包括关键字传递.默认值传递.包裹位置传递.包裹关键字混合传递等,需要的朋友可以参考下 我们已经接触过函数(function)的参数(arguments)传递.当时我们根据位置,传递对应的参数.我们将接触更多的参数传递方式. 回忆一下位置传递: def f(a,b,c): return a+b+c print(f(1,2,3)) 在调用f时,1,2,3根据位置分别传递给了a,b,c. 关键字传递 有些情况下,用位置传递会感觉比较死板.

  • JavaWeb案例讲解Servlet常用对象

    概述 本次文章基于第三章的ServletConfig,ServletContext,HttpServletRequest,HttpServletResponse对象完成一个图书订阅系统的购买图书和查看图书购买记录功能. 搭建项目主页面 创建一个动态网站项目,在src中新建包com.book.servlet. 在包中,新建HomeServlet作为主页.效果图如下: 为了让一访问项目根路径地址就默认进入HomeServlet,这里需要将 HomeServlet的虚拟地址写入web.xml文件中作为

  • python正则表达式用法超详细讲解大全

    目录 一.re.compile 函数 二.正则表达式 表示字符 表示数字 匹配边界 三.re模块的高级用法 1.findall:pattern在string里所有的非重复匹配,返回一个迭代器iterator保存了匹配对象 2.sub:将匹配到的字符串,再次进行操作 3.split:切割匹配成功的字符串 四.贪婪和非贪婪模式 总结 一.re.compile 函数 作用:compile 函数用于编译正则表达式,生成一个正则表达式( Pattern )对象,供 match() 和 search() 这

随机推荐