京公网安备 11010802034615号
经营许可证编号:京B2-20210330
作者 | Francois Chollet
编译 | CDA数据分析师
A ten-minute introduction to sequence-to-sequence learning in Keras
什么是顺序学习?
序列到序列学习(Seq2Seq)是关于将模型从一个域(例如英语中的句子)转换为另一域(例如将相同句子翻译为法语的序列)的训练模型。
“猫坐在垫子上” -> [ Seq2Seq 模型] -> “在小吃中聊天
这可用于机器翻译或免费问答(在给定自然语言问题的情况下生成自然语言答案)-通常,它可在需要生成文本的任何时间使用。
有多种处理此任务的方法,可以使用RNN或使用一维卷积网络。在这里,我们将重点介绍RNN。
普通情况:输入和输出序列的长度相同
当输入序列和输出序列的长度相同时,您可以简单地使用Keras LSTM或GRU层(或其堆栈)来实现此类模型。在此示例脚本 中就是这种情况, 该脚本显示了如何教RNN学习加编码为字符串的数字:
该方法的一个警告是,它假定可以生成target[...t]给定input[...t]。在某些情况下(例如添加数字字符串),该方法有效,但在大多数用例中,则无效。在一般情况下,有关整个输入序列的信息是必需的,以便开始生成目标序列。
一般情况:规范序列间
在一般情况下,输入序列和输出序列具有不同的长度(例如,机器翻译),并且需要整个输入序列才能开始预测目标。这需要更高级的设置,这是人们在没有其他上下文的情况下提到“序列模型的序列”时通常所指的东西。运作方式如下:
RNN层(或其堆栈)充当“编码器”:它处理输入序列并返回其自己的内部状态。请注意,我们放弃了编码器RNN的输出,仅恢复 了状态。在下一步中,此状态将用作解码器的“上下文”或“条件”。
另一个RNN层(或其堆栈)充当“解码器”:在给定目标序列的先前字符的情况下,对其进行训练以预测目标序列的下一个字符。具体而言,它经过训练以将目标序列变成相同序列,但在将来会偏移一个时间步,在这种情况下,该训练过程称为“教师强迫”。重要的是,编码器使用来自编码器的状态向量作为初始状态,这就是解码器如何获取有关应该生成的信息的方式。有效地,解码器学会产生targets[t+1...] 给定的targets[...t],调节所述输入序列。
在推断模式下,即当我们想解码未知的输入序列时,我们会经历一个略有不同的过程:
同样的过程也可以用于训练Seq2Seq网络,而无需 “教师强制”,即通过将解码器的预测重新注入到解码器中。
一个Keras例子
因为训练过程和推理过程(解码句子)有很大的不同,所以我们对两者使用不同的模型,尽管它们都利用相同的内部层。
这是我们的训练模型。它利用Keras RNN的三个关键功能:
return_state构造器参数,配置RNN层返回一个列表,其中,第一项是输出与下一个条目是内部RNN状态。这用于恢复编码器的状态。
inital_state呼叫参数,指定一个RNN的初始状态(S)。这用于将编码器状态作为初始状态传递给解码器。
return_sequences构造函数的参数,配置RNN返回其输出全序列(而不只是最后的输出,其默认行为)。在解码器中使用。
from keras.models import Model from keras.layers import Input, LSTM, Dense # Define an input sequence and process it. encoder_inputs = Input(shape=(None, num_encoder_tokens)) encoder = LSTM(latent_dim, return_state=True) encoder_outputs, state_h, state_c = encoder(encoder_inputs) # We discard `encoder_outputs` and only keep the states. encoder_states = [state_h, state_c] # Set up the decoder, using `encoder_states` as initial state. decoder_inputs = Input(shape=(None, num_decoder_tokens)) # We set up our decoder to return full output sequences, # and to return internal states as well. We don't use the # return states in the training model, but we will use them in inference. decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True) decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states) decoder_dense = Dense(num_decoder_tokens, activation='softmax') decoder_outputs = decoder_dense(decoder_outputs) # Define the model that will turn # `encoder_input_data` & `decoder_input_data` into `decoder_target_data` model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
我们分两行训练我们的模型,同时监视20%的保留样本中的损失。
# Run training model.compile(optimizer='rmsprop', loss='categorical_crossentropy') model.fit([encoder_input_data, decoder_input_data], decoder_target_data, batch_size=batch_size, epochs=epochs, validation_split=0.2)
在MacBook CPU上运行大约一个小时后,我们就可以进行推断了。为了解码测试语句,我们将反复:
这是我们的推理设置:
encoder_model = Model(encoder_inputs, encoder_states) decoder_state_input_h = Input(shape=(latent_dim,)) decoder_state_input_c = Input(shape=(latent_dim,)) decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c] decoder_outputs, state_h, state_c = decoder_lstm( decoder_inputs, initial_state=decoder_states_inputs) decoder_states = [state_h, state_c] decoder_outputs = decoder_dense(decoder_outputs) decoder_model = Model( [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states)
我们使用它来实现上述推理循环:
def decode_sequence(input_seq): # Encode the input as state vectors. states_value = encoder_model.predict(input_seq) # Generate empty target sequence of length 1. target_seq = np.zeros((1, 1, num_decoder_tokens)) # Populate the first character of target sequence with the start character. target_seq[0, 0, target_token_index['\t']] = 1. # Sampling loop for a batch of sequences # (to simplify, here we assume a batch of size 1). stop_condition = False decoded_sentence = '' while not stop_condition: output_tokens, h, c = decoder_model.predict( [target_seq] + states_value) # Sample a token sampled_token_index = np.argmax(output_tokens[0, -1, :]) sampled_char = reverse_target_char_index[sampled_token_index] decoded_sentence += sampled_char # Exit condition: either hit max length # or find stop character. if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length): stop_condition = True # Update the target sequence (of length 1). target_seq = np.zeros((1, 1, num_decoder_tokens)) target_seq[0, 0, sampled_token_index] = 1. # Update states states_value = [h, c] return decoded_sentence
我们得到了一些不错的结果-毫不奇怪,因为我们正在解码从训练测试中提取的样本
Input sentence: Be nice. Decoded sentence: Soyez gentil ! - Input sentence: Drop it! Decoded sentence: Laissez tomber ! - Input sentence: Get out! Decoded sentence: Sortez !
到此,我们结束了对Keras中序列到序列模型的十分钟介绍。提醒:此脚本的完整代码可以在GitHub上找到。
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
当沃尔玛数据分析师首次发现 “啤酒与尿布” 的高频共现规律时,他们揭开了数据挖掘最迷人的面纱 —— 那些隐藏在消费行为背后 ...
2025-11-03这个问题精准切中了配对样本统计检验的核心差异点,理解二者区别是避免统计方法误用的关键。核心结论是:stats.ttest_rel(配对 ...
2025-11-03在 CDA(Certified Data Analyst)数据分析师的工作中,“高维数据的潜在规律挖掘” 是进阶需求 —— 例如用户行为包含 “浏览次 ...
2025-11-03在 MySQL 数据查询中,“按顺序计数” 是高频需求 —— 例如 “统计近 7 天每日订单量”“按用户 ID 顺序展示消费记录”“按产品 ...
2025-10-31在数据分析中,“累计百分比” 是衡量 “部分与整体关系” 的核心指标 —— 它通过 “逐步累加的占比”,直观呈现数据的分布特征 ...
2025-10-31在 CDA(Certified Data Analyst)数据分析师的工作中,“二分类预测” 是高频需求 —— 例如 “预测用户是否会流失”“判断客户 ...
2025-10-31在 MySQL 实际应用中,“频繁写入同一表” 是常见场景 —— 如实时日志存储(用户操作日志、系统运行日志)、高频交易记录(支付 ...
2025-10-30为帮助教育工作者、研究者科学分析 “班级规模” 与 “平均成绩” 的关联关系,我将从相关系数的核心定义与类型切入,详解 “数 ...
2025-10-30对 CDA(Certified Data Analyst)数据分析师而言,“相关系数” 不是简单的数字计算,而是 “从业务问题出发,量化变量间关联强 ...
2025-10-30在构建前向神经网络(Feedforward Neural Network,简称 FNN)时,“隐藏层数目设多少?每个隐藏层该放多少个神经元?” 是每个 ...
2025-10-29这个问题切中了 Excel 用户的常见困惑 —— 将 “数据可视化工具” 与 “数据挖掘算法” 的功能边界混淆。核心结论是:Excel 透 ...
2025-10-29在 CDA(Certified Data Analyst)数据分析师的工作中,“多组数据差异验证” 是高频需求 —— 例如 “3 家门店的销售额是否有显 ...
2025-10-29在数据分析中,“正态分布” 是许多统计方法(如 t 检验、方差分析、线性回归)的核心假设 —— 数据符合正态分布时,统计检验的 ...
2025-10-28箱线图(Box Plot)作为展示数据分布的核心统计图表,能直观呈现数据的中位数、四分位数、离散程度与异常值,是质量控制、实验分 ...
2025-10-28在 CDA(Certified Data Analyst)数据分析师的工作中,“分类变量关联分析” 是高频需求 —— 例如 “用户性别是否影响支付方式 ...
2025-10-28在数据可视化领域,单一图表往往难以承载多维度信息 —— 力导向图擅长展现节点间的关联结构与空间分布,却无法直观呈现 “流量 ...
2025-10-27这个问题问到了 Tableau 中两个核心行级函数的经典组合,理解它能帮你快速实现 “相对位置占比” 的分析需求。“index ()/size ( ...
2025-10-27对 CDA(Certified Data Analyst)数据分析师而言,“假设检验” 绝非 “套用统计公式的机械操作”,而是 “将模糊的业务猜想转 ...
2025-10-27在数字化运营中,“凭感觉做决策” 早已成为过去式 —— 运营指标作为业务增长的 “晴雨表” 与 “导航仪”,直接决定了运营动作 ...
2025-10-24在卷积神经网络(CNN)的训练中,“卷积层(Conv)后是否添加归一化(如 BN、LN)和激活函数(如 ReLU、GELU)” 是每个开发者都 ...
2025-10-24