Keras之注意力模型实现
习的一个github上的代码,分析了一下实现过程。代码下载链接:https://github.com/Choco31415/Attention_Network_With_Keras
代码的主要目标是通过一个描述时间的字符串,预测为数字形式的字符串。如“ten before ten o'clock a.m”预测为09:50
在jupyter上运行,代码如下:
1,导入模块,好像并没有全部使用到,如Permute,Multiply,Reshape,LearningRateScheduler等,这些应该是优化的时候使用的
1 from keras.layers import Bidirectional, Concatenate, Permute, Dot, Input, LSTM, Multiply, Reshape 2 from keras.layers import RepeatVector, Dense, Activation, Lambda 3 from keras.optimizers import Adam 4 #from keras.utils import to_categorical 5 from keras.models import load_model, Model 6 #from keras.callbacks import LearningRateScheduler 7 import keras.backend as K 8 9 import matplotlib.pyplot as plt 10 %matplotlib inline 11 12 import random 13 #import math14 15 import json 16 import numpy as np
2,加载数据集,以及翻译前和翻译后的词典
1 with open('data/Time Dataset.json','r') as f: 2 dataset = json.loads(f.read()) 3 with open('data/Time Vocabs.json','r') as f: 4 human_vocab, machine_vocab = json.loads(f.read()) 5 6 human_vocab_size = len(human_vocab) 7 machine_vocab_size = len(machine_vocab)
这里human_vocab词典是将每个字符映射到索引,machine_vocab也是将翻译后的字符映射到索引,因为翻译后的时间只包含0-9以及:
3,定义数据处理方法
tokenize为将字符映射到索引,one-hot为对每个映射后的索引做了个one-hot编码处理
1 def preprocess_data(dataset, human_vocab, machine_vocab, Tx, Ty): 2 """ 3 A method for tokenizing data. 4 5 Inputs: 6