
作者 | Francois Chollet
编译 | CDA数据分析师
A ten-minute introduction to sequence-to-sequence learning in Keras
什么是顺序学习?
序列到序列学习(Seq2Seq)是关于将模型从一个域(例如英语中的句子)转换为另一域(例如将相同句子翻译为法语的序列)的训练模型。
“猫坐在垫子上” -> [ Seq2Seq 模型] -> “在小吃中聊天
这可用于机器翻译或免费问答(在给定自然语言问题的情况下生成自然语言答案)-通常,它可在需要生成文本的任何时间使用。
有多种处理此任务的方法,可以使用RNN或使用一维卷积网络。在这里,我们将重点介绍RNN。
普通情况:输入和输出序列的长度相同
当输入序列和输出序列的长度相同时,您可以简单地使用Keras LSTM或GRU层(或其堆栈)来实现此类模型。在此示例脚本 中就是这种情况, 该脚本显示了如何教RNN学习加编码为字符串的数字:
该方法的一个警告是,它假定可以生成target[...t]给定input[...t]。在某些情况下(例如添加数字字符串),该方法有效,但在大多数用例中,则无效。在一般情况下,有关整个输入序列的信息是必需的,以便开始生成目标序列。
一般情况:规范序列间
在一般情况下,输入序列和输出序列具有不同的长度(例如,机器翻译),并且需要整个输入序列才能开始预测目标。这需要更高级的设置,这是人们在没有其他上下文的情况下提到“序列模型的序列”时通常所指的东西。运作方式如下:
RNN层(或其堆栈)充当“编码器”:它处理输入序列并返回其自己的内部状态。请注意,我们放弃了编码器RNN的输出,仅恢复 了状态。在下一步中,此状态将用作解码器的“上下文”或“条件”。
另一个RNN层(或其堆栈)充当“解码器”:在给定目标序列的先前字符的情况下,对其进行训练以预测目标序列的下一个字符。具体而言,它经过训练以将目标序列变成相同序列,但在将来会偏移一个时间步,在这种情况下,该训练过程称为“教师强迫”。重要的是,编码器使用来自编码器的状态向量作为初始状态,这就是解码器如何获取有关应该生成的信息的方式。有效地,解码器学会产生targets[t+1...] 给定的targets[...t],调节所述输入序列。
在推断模式下,即当我们想解码未知的输入序列时,我们会经历一个略有不同的过程:
同样的过程也可以用于训练Seq2Seq网络,而无需 “教师强制”,即通过将解码器的预测重新注入到解码器中。
一个Keras例子
因为训练过程和推理过程(解码句子)有很大的不同,所以我们对两者使用不同的模型,尽管它们都利用相同的内部层。
这是我们的训练模型。它利用Keras RNN的三个关键功能:
return_state构造器参数,配置RNN层返回一个列表,其中,第一项是输出与下一个条目是内部RNN状态。这用于恢复编码器的状态。
inital_state呼叫参数,指定一个RNN的初始状态(S)。这用于将编码器状态作为初始状态传递给解码器。
return_sequences构造函数的参数,配置RNN返回其输出全序列(而不只是最后的输出,其默认行为)。在解码器中使用。
from keras.models import Model from keras.layers import Input, LSTM, Dense # Define an input sequence and process it. encoder_inputs = Input(shape=(None, num_encoder_tokens)) encoder = LSTM(latent_dim, return_state=True) encoder_outputs, state_h, state_c = encoder(encoder_inputs) # We discard `encoder_outputs` and only keep the states. encoder_states = [state_h, state_c] # Set up the decoder, using `encoder_states` as initial state. decoder_inputs = Input(shape=(None, num_decoder_tokens)) # We set up our decoder to return full output sequences, # and to return internal states as well. We don't use the # return states in the training model, but we will use them in inference. decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True) decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states) decoder_dense = Dense(num_decoder_tokens, activation='softmax') decoder_outputs = decoder_dense(decoder_outputs) # Define the model that will turn # `encoder_input_data` & `decoder_input_data` into `decoder_target_data` model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
我们分两行训练我们的模型,同时监视20%的保留样本中的损失。
# Run training model.compile(optimizer='rmsprop', loss='categorical_crossentropy') model.fit([encoder_input_data, decoder_input_data], decoder_target_data, batch_size=batch_size, epochs=epochs, validation_split=0.2)
在MacBook CPU上运行大约一个小时后,我们就可以进行推断了。为了解码测试语句,我们将反复:
这是我们的推理设置:
encoder_model = Model(encoder_inputs, encoder_states) decoder_state_input_h = Input(shape=(latent_dim,)) decoder_state_input_c = Input(shape=(latent_dim,)) decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c] decoder_outputs, state_h, state_c = decoder_lstm( decoder_inputs, initial_state=decoder_states_inputs) decoder_states = [state_h, state_c] decoder_outputs = decoder_dense(decoder_outputs) decoder_model = Model( [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states)
我们使用它来实现上述推理循环:
def decode_sequence(input_seq): # Encode the input as state vectors. states_value = encoder_model.predict(input_seq) # Generate empty target sequence of length 1. target_seq = np.zeros((1, 1, num_decoder_tokens)) # Populate the first character of target sequence with the start character. target_seq[0, 0, target_token_index['\t']] = 1. # Sampling loop for a batch of sequences # (to simplify, here we assume a batch of size 1). stop_condition = False decoded_sentence = '' while not stop_condition: output_tokens, h, c = decoder_model.predict( [target_seq] + states_value) # Sample a token sampled_token_index = np.argmax(output_tokens[0, -1, :]) sampled_char = reverse_target_char_index[sampled_token_index] decoded_sentence += sampled_char # Exit condition: either hit max length # or find stop character. if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length): stop_condition = True # Update the target sequence (of length 1). target_seq = np.zeros((1, 1, num_decoder_tokens)) target_seq[0, 0, sampled_token_index] = 1. # Update states states_value = [h, c] return decoded_sentence
我们得到了一些不错的结果-毫不奇怪,因为我们正在解码从训练测试中提取的样本
Input sentence: Be nice. Decoded sentence: Soyez gentil ! - Input sentence: Drop it! Decoded sentence: Laissez tomber ! - Input sentence: Get out! Decoded sentence: Sortez !
到此,我们结束了对Keras中序列到序列模型的十分钟介绍。提醒:此脚本的完整代码可以在GitHub上找到。
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
Excel 导入数据含缺失值?详解 dropna 函数的功能与实战应用 在用 Python(如 pandas 库)处理 Excel 数据时,“缺失值” 是高频 ...
2025-09-16深入解析卡方检验与 t 检验:差异、适用场景与实践应用 在数据分析与统计学领域,假设检验是验证研究假设、判断数据差异是否 “ ...
2025-09-16CDA 数据分析师:掌控表格结构数据全功能周期的专业操盘手 表格结构数据(以 “行 - 列” 存储的结构化数据,如 Excel 表、数据 ...
2025-09-16MySQL 执行计划中 rows 数量的准确性解析:原理、影响因素与优化 在 MySQL SQL 调优中,EXPLAIN执行计划是核心工具,而其中的row ...
2025-09-15解析 Python 中 Response 对象的 text 与 content:区别、场景与实践指南 在 Python 进行 HTTP 网络请求开发时(如使用requests ...
2025-09-15CDA 数据分析师:激活表格结构数据价值的核心操盘手 表格结构数据(如 Excel 表格、数据库表)是企业最基础、最核心的数据形态 ...
2025-09-15Python HTTP 请求工具对比:urllib.request 与 requests 的核心差异与选择指南 在 Python 处理 HTTP 请求(如接口调用、数据爬取 ...
2025-09-12解决 pd.read_csv 读取长浮点数据的科学计数法问题 为帮助 Python 数据从业者解决pd.read_csv读取长浮点数据时的科学计数法问题 ...
2025-09-12CDA 数据分析师:业务数据分析步骤的落地者与价值优化者 业务数据分析是企业解决日常运营问题、提升执行效率的核心手段,其价值 ...
2025-09-12用 SQL 验证业务逻辑:从规则拆解到数据把关的实战指南 在业务系统落地过程中,“业务逻辑” 是连接 “需求设计” 与 “用户体验 ...
2025-09-11塔吉特百货孕妇营销案例:数据驱动下的精准零售革命与启示 在零售行业 “流量红利见顶” 的当下,精准营销成为企业突围的核心方 ...
2025-09-11CDA 数据分析师与战略 / 业务数据分析:概念辨析与协同价值 在数据驱动决策的体系中,“战略数据分析”“业务数据分析” 是企业 ...
2025-09-11Excel 数据聚类分析:从操作实践到业务价值挖掘 在数据分析场景中,聚类分析作为 “无监督分组” 的核心工具,能从杂乱数据中挖 ...
2025-09-10统计模型的核心目的:从数据解读到决策支撑的价值导向 统计模型作为数据分析的核心工具,并非简单的 “公式堆砌”,而是围绕特定 ...
2025-09-10CDA 数据分析师:商业数据分析实践的落地者与价值创造者 商业数据分析的价值,最终要在 “实践” 中体现 —— 脱离业务场景的分 ...
2025-09-10机器学习解决实际问题的核心关键:从业务到落地的全流程解析 在人工智能技术落地的浪潮中,机器学习作为核心工具,已广泛应用于 ...
2025-09-09SPSS 编码状态区域中 Unicode 的功能与价值解析 在 SPSS(Statistical Product and Service Solutions,统计产品与服务解决方案 ...
2025-09-09CDA 数据分析师:驾驭商业数据分析流程的核心力量 在商业决策从 “经验驱动” 向 “数据驱动” 转型的过程中,商业数据分析总体 ...
2025-09-09R 语言:数据科学与科研领域的核心工具及优势解析 一、引言 在数据驱动决策的时代,无论是科研人员验证实验假设(如前文中的 T ...
2025-09-08T 检验在假设检验中的应用与实践 一、引言 在科研数据分析、医学实验验证、经济指标对比等领域,常常需要判断 “样本间的差异是 ...
2025-09-08