RNN入门(4)利用LSTM实现整数加法运算

 本文将介绍LSTM模型在实现整数加法方面的应用。   我们以0-255之间的整数加法为例,生成的结果在0到510之间。为了能利用深度学习模型模拟整数的加法运算,我们需要将输入的两个加数和输出的结果用二进制表示,这样就能得到向量,如加数在0-255内,可以用8位0-1向量来表示,前面的空位用0填充;结果在0-510内,可以用9位0-1向量来表示,前面的空位用0填充。因为两个加数均在0-255内变化,所以共有256*256=65536个输入向量以及65536个输出向量,输入向量为两个加数的二进制向量的拼接结果,因而是个16为的输入向量。用以下的Python代码可以模拟以上过程: import numpy as np # 最多8位二进制 BINARY_DIM = 8 # 将整数表示成为binary_dim位的二进制数,高位用0补齐 def int_2_binary(number, binary_dim): binary_list = list(map(lambda x: int(x), bin(number)[2:])) number_dim = len(binary_list) result_list = [0]*(binary_dim-number_dim)+binary_list return result_list # 将一个二进制数组转为整数 def binary2int(binary_array): out = 0 for index, x in enumerate(reversed(binary_array)): out += x * pow(2, index) return out # 将[0,2**BINARY_DIM)所有数表示成二进制 binary = np.array([int_2_binary(x, BINARY_DIM) for x in range(2**BINARY_DIM)]) # print(binary) # 样本的输入向量和输出向量 dataX = [] dataY = [] for i in range(binary.shape[0]): for j in range(binary.shape[0]): dataX.append(np.append(binary[i], binary[j])) dataY.append(int_2_binary(i+j, BINARY_DIM+1)) # print(dataX) # print(dataY) # 重新特征X和目标变量Y数组,适应LSTM模型的输入和输出 X = np.reshape(dataX, (len(dataX), 2*BINARY_DIM, 1)) # print(X.shape) Y = np.array(dataY) # print(dataY.shape) 在以上代码中,得到的dataX和dataY以满足要求,但为了能让LSTM模型处理,需要改变这两个数据集的形状。   我们采用LSTM模型来训练上述数据,LSTM模型的结构很简单,就是简单的一层LSTM层,然后加上Dropout层,最后是全连接层,激活函数采用sigmoid函数,采用的损失函数为平均平方误差。整个结构的示意图如下: LSTM模型的结构示意图 模型训练的代码如下: from keras.models import Sequential from keras.layers import Dense from keras.layers import Dropout from keras.layers import LSTM from keras import losses from keras.utils import plot_model # 定义LSTM模型 model = Sequential() model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2]))) model.add(Dropout(0.2)) model.add(Dense(Y.shape[1], activation='sigmoid')) model.compile(loss=losses.mean_squared_error, optimizer='adam') # print(model.summary()) # plot model plot_model(model, to_file=r'./model.png', show_shapes=True) # train model epochs = 100 model.fit(X, Y, epochs=epochs, batch_size=128) # save model mp = r'./LSTM_Operation.h5' model.save(mp) 该LSTM模型每批训练128个样本,共训练100次,采用Adam优化器减少损失值。   对这个模型进行训练,训练100次,损失值为0.0045。接下来我们就要用这个训练好的模型来预测。我们预测的方法为,虽然挑两个在0-255内的加数,转化为二进制向量作为输入向量,然后由LSTM模型输出结果,将该结果取整作为输出向量中的元素,最后将这个输出向量转化为整数,就是预测的两个加数的和。模型预测的代码如下: # use LSTM model to predict for _ in range(100): start = np.random.randint(0, len(dataX)-1) # print(dataX[start]) number1 = dataX[start][0:BINARY_DIM] number2 = dataX[start][BINARY_DIM:] print('='*30) print('%s: %s'%(number1, binary2int(number1))) print('%s: %s'%(number2, binary2int(number2))) sample = np.reshape(X[start], (1, 2*BINARY_DIM, 1)) predict = np.round(model.predict(sample), 0).astype(np.int32)[0] print('%s: %s'%(predict, binary2int(predict))) 预测的100组样本的输出结果如下: ============================== [1 0 0 1 1 1 0 1]: 157 [0 1 1 1 0 0 0 1]: 113 [1 0 0 0 0 1 1 1 0]: 270 ============================== [1 1 1 0 1 0 1 0]: 234 [0 1 0 0 1 1 0 0]: 76 [1 0 0 1 1 0 1 1 0]: 310 ============================== [1 1 0 0 0 1 0 0]: 196 [1 1 0 1 1 0 1 1]: 219 [1 1 0 0 1 1 1 1 1]: 415 ============================== [0 0 1 1 1 0 1 0]: 58 [0 0 1 0 0 0 1 1]: 35 [0 0 1 0 1 1 1 0 1]: 93 ============================== [1 0 0 0 0 0 0 0]: 128 [0 1 1 1 1 0 0 1]: 121 [0 1 1 1 1 1 0 0 1]: 249 ============================== [1 1 1 1 0 1 1 0]: 246 [1 1 0 1 0 1 0 1]: 213 [1 1 1 0 0 1 0 1 1]: 459 ============================== [1 1 1 0 0 1 1 0]: 230 [1 0 0 0 0 0 0 0]: 128 [1 0 1 1 0 0 1 1 0]: 358 ============================== [1 0 1 0 0 0 1 1]: 163 [0 1 1 0 0 1 0 1]: 101 [1 0 0 0 0 1 0 0 0]: 264 ============================== [1 0 1 0 0 1 1 0]: 166 [0 1 0 1 0 0 0 0]: 80 [0 1 1 1 1 0 1 1 0]: 246 ============================== [0 0 0 0 1 0 1 1]: 11 [0 1 0 0 0 1 0 1]: 69 [0 0 1 0 1 0 0 0 0]: 80 ============================== [1 1 1 1 0 1 1 1]: 247 [0 1 1 1 0 0 0 0]: 112 [1 0 1 1 0 0 1 1 1]: 359 ============================== [1 0 1 0 1 0 0 1]: 169 [1 1 0 0 0 0 0 0]: 192 [1 0 1 1 0 1 0 0 1]: 361 ============================== [1 0 1 1 0 0 0 1]: 177 [1 0 0 0 1 0 1 1]: 139 [1 0 0 1 1 1 1 0 0]: 316 ============================== [0 1 0 0 0 1 1 0]: 70 [0 0 1 0 1 1 1 0]: 46 [0 0 1 1 1 0 1 0 0]: 116 ============================== [1 0 0 1 1 0 1 1]: 155 [1 1 0 0 0 0 0 1]: 193 [1 0 1 0 1 1 1 0 0]: 348 ============================== [1 0 1 1 0 0 1 0]: 178 [1 0 0 0 1 1 1 1]: 143 [1 0 1 0 0 0 0 0 1]: 321 ============================== [0 1 0 1 1 1 1 1]: 95 [1 1 1 0 0 1 0 0]: 228 [1 0 1 0 0 0 0 1 1]: 323 ============================== [1 0 0 1 1 1 1 0]: 158 [0 0 0 1 1 0 0 1]: 25 [0 1 0 1 1 0 1 1 1]: 183 ============================== [1 1 1 0 1 0 1 1]: 235 [1 1 0 0 0 0 0 1]: 193 [1 1 0 1 0 1 1 0 0]: 428 ============================== [0 1 0 1 1 1 0 1]: 93 [0 1 1 1 0 1 1 0]: 118 [0 1 1 0 1 0 0 1 1]: 211 ============================== [1 1 1 1 1 1 1 1]: 255 [1 1 1 1 1 1 1 0]: 254 [1 1 1 1 1 1 1 0 1]: 509 ============================== [0 1 0 1 1 0 0 1]: 89 [0 1 0 1 1 1 1 0]: 94 [0 1 0 1 1 0 1 1 1]: 183 ============================== [0 1 1 1 0 0 0 0]: 112 [0 0 1 1 0 1 0 0]: 52 [0 1 0 1 0 0 1 0 0]: 164 ============================== [1 0 0 0 0 0 0 0]: 128 [1 1 0 1 1 0 1 0]: 218 [1 0 1 0 1 1 0 1 0]: 346 ============================== [0 0 1 1 0 1 0 1]: 53 [1 0 1 1 1 1 1 0]: 190 [0 1 1 1 1 0 0 1 1]: 243 ============================== [0 1 1 1 1 0 0 0]: 120 [1 1 0 1 0 1 0 1]: 213 [1 0 1 0 0 1 1 0 1]: 333 ============================== [0 1 1 1 1 0 1 1]: 123 [1 1 1 0 1 1 0 1]: 237 [1 0 1 1 0 1 0 0 0]: 360 ============================== [1 0 0 1 1 0 1 0]: 154 [0 1 1 0 1 0 0 1]: 105 [1 0 0 0 0 0 0 1 1]: 259 ============================== [0 0 0 1 1 0 0 1]: 25 [0 1 0 1 1 0 1 0]: 90 [0 0 1 1 1 0 0 1 1]: 115 ============================== [1 1 1 1 0 0 0 1]: 241 [0 0 0 1 1 1 1 1]: 31 [1 0 0 0 1 0 0 0 0]: 272 ============================== [0 1 0 0 0 1 1 0]: 70 [1 1 1 0 1 0 0 1]: 233 [1 0 0 1 0 1 1 1 1]: 303 ============================== [1 0 1 0 1 1 0 1]: 173 [0 1 1 1 0 1 0 0]: 116 [1 0 0 1 0 0 0 0 1]: 289 ============================== [0 1 0 0 1 0 0 0]: 72 [1 1 1 1 1 0 1 0]: 250 [1 0 1 0 0 0 0 1 0]: 322 ============================== [1 1 1 1 0 0 0 0]: 240 [0 1 0 0 0 0 1 0]: 66 [1 0 0 1 1 0 0 1 0]: 306 ============================== [0 1 0 0 0 1 1 1]: 71 [1 0 0 1 0 1 1 0]: 150 [0 1 1 0 1 1 1 0 1]: 221 ============================== [0 1 1 0 1 1 0 1]: 109 [0 0 1 0 0 1 0 1]: 37 [0 1 0 0 1 0 0 1 0]: 146 ============================== [1 1 0 0 0 0 0 0]: 192 [1 1 1 0 0 0 0 1]: 225 [1 1 0 1 0 0 0 0 1]: 417 ============================== [1 0 0 0 0 0 1 1]: 131 [1 1 0 1 1 1 1 0]: 222 [1 0 1 1 0 0 0 0 1]: 353 ============================== [0 0 0 0 0 1 0 0]: 4 [1 1 1 0 0 0 1 0]: 226 [0 1 1 1 0 0 1 1 0]: 230 ============================== [1 1 1 0 1 1 1 1]: 239 [1 1 0 1 1 0 1 1]: 219 [1 1 1 0 0 1 0 1 0]: 458 ============================== [0 0 1 1 0 1 0 1]: 53 [1 1 1 1 0 0 1 0]: 242 [1 0 0 1 0 0 1 1 1]: 295 ============================== [1 0 0 1 0 0 0 1]: 145 [0 1 0 0 0 1 0 0]: 68 [0 1 1 0 1 0 1 0 1]: 213 ============================== [0 0 1 1 0 0 0 0]: 48 [1 0 1 1 0 1 1 1]: 183 [0 1 1 1 0 0 1 1 1]: 231 ============================== [0 1 1 0 0 1 1 1]: 103 [0 0 0 1 1 1 1 0]: 30 [0 1 0 0 0 0 1 0 1]: 133 ============================== [0 1 0 1 1 1 0 1]: 93 [1 1 0 1 0 0 1 0]: 210 [1 0 0 1 0 1 1 1 1]: 303 ============================== [1 0 0 0 1 0 1 0]: 138 [0 1 1 1 1 0 0 1]: 121 [1 0 0 0 0 0 0 1 1]: 259 ============================== [0 0 0 0 0 0 1 1]: 3 [0 0 1 1 0 0 0 1]: 49 [0 0 0 1 1 0 1 0 0]: 52 ============================== [1 0 0 0 0 0 1 0]: 130 [0 0 0 1 0 0 0 0]: 16 [0 1 0 0 1 0 0 1 0]: 146 ============================== [0 0 0 1 0 0 0 0]: 16 [1 0 0 1 0 0 1 0]: 146 [0 1 0 1 0 0 0 1 0]: 162 ============================== [0 1 0 1 0 1 0 0]: 84 [0 0 0 0 1 1 0 0]: 12 [0 0 1 1 0 0 0 0 0]: 96 ============================== [1 0 1 0 1 0 1 1]: 171 [1 1 0 1 1 0 1 1]: 219 [1 1 0 0 0 0 1 1 0]: 390 ============================== [1 1 1 1 1 1 1 0]: 254 [0 1 1 0 1 0 1 0]: 106 [1 0 1 1 0 1 0 0 0]: 360 ============================== [1 0 0 0 0 0 1 0]: 130 [0 0 0 0 1 1 1 0]: 14 [0 1 0 0 1 0 0 0 0]: 144 ============================== [1 0 1 0 0 1 0 1]: 165 [0 0 1 1 1 0 1 1]: 59 [0 1 1 1 0 0 0 0 0]: 224 ============================== [0 0 1 1 1 0 1 0]: 58 [1 1 1 1 0 0 1 0]: 242 [1 0 0 1 0 1 1 0 0]: 300 ============================== [0 1 0 0 1 1 0 1]: 77 [0 0 0 1 1 1 1 1]: 31 [0 0 1 1 0 1 1 0 0]: 108 ============================== [1 0 0 1 1 0 1 0]: 154 [0 1 0 1 0 1 0 1]: 85 [0 1 1 1 0 1 1 1 1]: 239 ============================== [0 1 1 0 1 1 0 1]: 109 [0 1 1 0 1 0 0 1]: 105 [0 1 1 0 1 0 1 1 0]: 214 ============================== [0 1 1 1 1 1 1 1]: 127 [0 1 1 1 0 0 1 0]: 114 [0 1 1 1 1 0 0 0 1]: 241 ============================== [0 1 1 0 0 1 0 1]: 101 [0 1 0 1 0 0 0 0]: 80 [0 1 0 1 1 0 1 0 1]: 181 ============================== [0 1 1 0 1 1 1 0]: 110 [0 1 0 1 0 1 1 0]: 86 [0 1 1 0 0 0 1 0 0]: 196 ============================== [0 0 0 1 0 0 1 1]: 19 [1 0 0 1 0 0 0 0]: 144 [0 1 0 1 0 0 0 1 1]: 163 ============================== [1 1 1 1 0 1 0 0]: 244 [1 1 0 1 0 0 1 1]: 211 [1 1 1 0 0 0 1 1 1]: 455 ============================== [0 0 0 0 1 1 1 0]: 14 [1 0 1 1 0 0 1 0]: 178 [0 1 1 0 0 0 0 0 0]: 192 ============================== [0 1 1 0 0 0 0 0]: 96 [1 0 0 1 1 1 0 0]: 156 [0 1 1 1 1 1 1 0 0]: 252 ============================== [0 0 1 1 0 1 0 0]: 52 [0 1 1 1 1 1 0 1]: 125 [0 1 0 1 1 0 0 0 1]: 177 ============================== [0 0 0 0 1 1 0 0]: 12 [0 1 0 1 1 1 0 1]: 93 [0 0 1 1 0 1 0 0 1]: 105 ============================== [0 1 1 0 0 1 0 1]: 101 [1 1 0 1 0 1 0 0]: 212 [1 0 0 1 1 1 0 0 1]: 313 ============================== [1 1 0 0 0 0 0 1]: 193 [1 1 0 0 1 1 0 1]: 205 [1 1 0 0 0 1 1 1 0]: 398 ============================== [0 1 1 1 0 0 1 0]: 114 [0 0 0 0 0 0 0 0]: 0 [0 0 1 1 1 0 0 1 0]: 114 ============================== [1 0 0 0 1 1 1 0]: 142 [1 0 1 1 1 1 0 1]: 189 [1 0 1 0 0 1 0 1 1]: 331 ============================== [1 0 1 1 0 1 1 1]: 183 [0 1 0 1 0 1 1 0]: 86 [1 0 0 0 0 1 1 0 1]: 269 ============================== [1 0 1 0 0 0 1 1]: 163 [1 1 1 0 0 1 0 1]: 229 [1 1 0 0 0 1 0 0 0]: 392 ============================== [0 0 1 1 0 0 0 1]: 49 [1 1 1 0 0 1 1 1]: 231 [1 0 0 0 1 1 0 0 0]: 280 ============================== [1 0 0 0 1 1 1 1]: 143 [1 0 1 0 1 0 0 0]: 168 [1 0 0 1 1 0 1 1 1]: 311 ============================== [0 1 0 0 0 0 0 0]: 64 [0 0 0 0 0 1 0 1]: 5 [0 0 1 0 0 0 1 0 1]: 69 ============================== [1 1 1 1 1 0 1 1]: 251 [1 0 1 1 1 0 0 1]: 185 [1 1 0 1 1 0 1 0 0]: 436 ============================== [1 1 1 0 1 1 1 0]: 238 [1 1 0 0 0 0 1 0]: 194 [1 1 0 1 1 0 0 0 0]: 432 ============================== [0 0 1 1 1 1 0 0]: 60 [0 0 0 1 0 1 1 1]: 23 [0 0 1 0 1 0 0 1 1]: 83 ============================== [0 1 1 1 0 1 0 0]: 116 [1 1 1 1 1 1 0 0]: 252 [1 0 1 1 1 0 0 0 0]: 368 ============================== [1 1 0 1 0 1 1 0]: 214 [1 1 1 1 0 1 0 0]: 244 [1 1 1 0 0 1 0 1 0]: 458 ============================== [1 1 1 1 1 1 1 0]: 254 [1 1 0 1 0 0 0 1]: 209 [1 1 1 0 0 1 1 1 1]: 463 ============================== [0 0 0 0 0 0 1 0]: 2 [0 0 0 0 1 1 0 1]: 13 [0 0 0 0 0 1 1 1 1]: 15 ============================== [0 1 1 0 0 1 1 1]: 103 [1 0 1 1 1 1 1 0]: 190 [1 0 0 1 0 0 1 0 1]: 293 ============================== [1 1 1 1 0 1 1 0]: 246 [0 1 0 1 0 0 1 0]: 82 [1 0 1 0 0 1 0 0 0]: 328 ============================== [0 1 1 1 0 0 1 1]: 115 [0 0 1 1 1 0 1 1]: 59 [0 1 0 1 0 1 1 1 0]: 174 ============================== [0 1 0 1 1 0 0 1]: 89 [0 1 1 0 1 0 1 1]: 107 [0 1 1 0 0 0 1 0 0]: 196 ============================== [0 1 0 0 0 1 0 0]: 68 [0 0 1 1 1 0 0 0]: 56 [0 0 1 1 1 1 1 0 0]: 124 ============================== [1 1 0 0 1 0 0 0]: 200 [1 0 1 0 0 0 1 0]: 162 [1 0 1 1 0 1 0 1 0]: 362 ============================== [1 1 1 1 0 0 1 1]: 243 [0 1 1 0 0 0 1 1]: 99 [1 0 1 0 1 0 1 1 0]: 342 ============================== [0 0 1 0 1 0 0 1]: 41 [0 1 0 0 1 0 0 1]: 73 [0 0 1 1 1 0 0 1 0]: 114 ============================== [0 0 0 1 1 1 0 1]: 29 [1 0 1 0 1 1 1 0]: 174 [0 1 1 0 0 1 0 1 1]: 203 ============================== [0 0 0 0 1 1 1 1]: 15 [0 0 1 1 1 1 0 1]: 61 [0 0 1 0 0 1 1 0 0]: 76 ============================== [1 1 1 1 1 0 1 1]: 251 [1 1 0 1 0 0 0 0]: 208 [1 1 1 0 0 1 0 1 1]: 459 ============================== [1 1 1 0 1 0 0 0]: 232 [0 1 1 0 0 0 1 0]: 98 [1 0 1 0 0 1 0 1 0]: 330 ============================== [1 0 1 1 0 1 0 0]: 180 [0 1 0 1 0 1 1 1]: 87 [1 0 0 0 0 1 0 1 1]: 267 ============================== [1 0 0 0 0 1 1 0]: 134 [1 0 0 1 0 1 0 1]: 149 [1 0 0 0 1 1 0 1 1]: 283 ============================== [1 0 1 0 1 1 0 1]: 173 [0 1 1 1 1 1 0 0]: 124 [1 0 0 1 0 1 0 0 1]: 297 ============================== [0 1 0 0 1 0 0 0]: 72 [0 1 1 0 0 0 1 1]: 99 [0 1 0 1 0 1 0 1 1]: 171 ============================== [1 1 0 1 0 1 0 1]: 213 [0 0 0 1 1 1 1 0]: 30 [0 1 1 1 1 0 0 1 1]: 243   可以看到,这个简单的LSTM模型的预测的结果全部正确。因此,这就可以用来模拟0-255内的整数的加法运算,是不是很神奇呢?   如果需要想将加数的范围扩大,只需要改变代码中的BINARY_DIM变量即可。但是,加数的范围越大,样本就越大,如2**10=1024内的加法,就会有1024*1024=1048576个样本,这样大的样本量的无疑需要更多的训练时间。   本文到此结束,感谢阅读~如果不当之处,请速联系笔者,欢迎大家交流~祝您好运~ 注意:本人现已开通微信公众号: Python爬虫与算法(微信号为:easy_web_scrape), 欢迎大家关注哦~~ 完整的Python代码如下: import numpy as np from keras.models import Sequential from keras.layers import Dense from keras.layers import Dropout from keras.layers import LSTM from k
50000+
5万行代码练就真实本领
17年
创办于2008年老牌培训机构
1000+
合作企业
98%
就业率

联系我们

电话咨询

0532-85025005

扫码添加微信