首先在keras环境中训练模型,得到LSTM的权重和偏置(此处推荐关闭LSTM的偏置,即设置use_bias=False,这样在进行剪枝等操作时可以更好地对比权重变化带来的影响)
①查看keras的LSTM计算过程
self.kernel_i = self.kernel[:, :self.units]
self.kernel_f = self.kernel[:, self.units: self.units * 2]
self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
self.kernel_o = self.kernel[:, self.units * 3:]
self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
self.recurrent_kernel_f = self.recurrent_kernel[:, self.units: self.units * 2]
self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2: self.units * 3]
self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]
if self.use_bias:
self.bias_i = self.bias[:self.units]
self.bias_f = self.bias[self.units: self.units * 2]
self.bias_c = self.bias[self.units * 2: self.units * 3]
self.bias_o = self.bias[self.units * 3:]
else:
self.bias_i = None
self.bias_f = None
self.bias_c = None
self.bias_o = None
self.built = True
# One LSTM cell step as quoted from Keras: combines the current `inputs` with
# the previous states (h_tm1, c_tm1) to produce the new carry state `c` and
# the new output `h`. NOTE: this excerpt has lost its indentation and stops
# after `h` is computed; no return statement is visible here.
def call(self, inputs, states, training=None):
# Build the input dropout masks only once (while the cached mask is None)
# and only when the dropout rate is strictly between 0 and 1.
if 0 < self.dropout < 1 and self._dropout_mask is None:
self._dropout_mask = _generate_dropout_mask(
K.ones_like(inputs),
self.dropout,
training=training,
count=4)
# Same lazy creation for the recurrent masks, shaped like the previous
# output states[0].
if (0 < self.recurrent_dropout < 1 and
self._recurrent_dropout_mask is None):
self._recurrent_dropout_mask = _generate_dropout_mask(
K.ones_like(states[0]),
self.recurrent_dropout,
training=training,
count=4)
# dropout matrices for input units
dp_mask = self._dropout_mask
# dropout matrices for recurrent units
rec_dp_mask = self._recurrent_dropout_mask
h_tm1 = states[0] # previous memory state
c_tm1 = states[1] # previous carry state
# implementation 1: every gate gets its own matrix products.
if self.implementation == 1:
# Masks 0..3 are applied to the i, f, c and o gate inputs respectively.
if 0 < self.dropout < 1.:
inputs_i = inputs * dp_mask[0]
inputs_f = inputs * dp_mask[1]
inputs_c = inputs * dp_mask[2]
inputs_o = inputs * dp_mask[3]
else:
inputs_i = inputs
inputs_f = inputs
inputs_c = inputs
inputs_o = inputs
# Input contribution of each gate: inputs times that gate's kernel slice.
x_i = K.dot(inputs_i, self.kernel_i)
x_f = K.dot(inputs_f, self.kernel_f)
x_c = K.dot(inputs_c, self.kernel_c)
x_o = K.dot(inputs_o, self.kernel_o)
if self.use_bias:
x_i = K.bias_add(x_i, self.bias_i)
x_f = K.bias_add(x_f, self.bias_f)
x_c = K.bias_add(x_c, self.bias_c)
x_o = K.bias_add(x_o, self.bias_o)
# Recurrent masks 0..3 are applied to the previous output per gate.
if 0 < self.recurrent_dropout < 1.:
h_tm1_i = h_tm1 * rec_dp_mask[0]
h_tm1_f = h_tm1 * rec_dp_mask[1]
h_tm1_c = h_tm1 * rec_dp_mask[2]
h_tm1_o = h_tm1 * rec_dp_mask[3]
else:
h_tm1_i = h_tm1
h_tm1_f = h_tm1
h_tm1_c = h_tm1
h_tm1_o = h_tm1
# Each gate adds its recurrent contribution (previous output times the
# gate's recurrent kernel slice) before the activation is applied.
i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
self.recurrent_kernel_f))
# New carry state: f scales the previous carry, i scales the candidate.
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
self.recurrent_kernel_o))
# Any other implementation: one fused product over the full kernels, then
# the result is sliced into four gate chunks.
else:
# Only mask 0 is used on this path.
if 0. < self.dropout < 1.:
inputs *= dp_mask[0]
z = K.dot(inputs, self.kernel)
if 0. < self.recurrent_dropout < 1.:
h_tm1 *= rec_dp_mask[0]
z += K.dot(h_tm1, self.recurrent_kernel)
if self.use_bias:
z = K.bias_add(z, self.bias)
# Consecutive column chunks of width self.units, consumed below as
# i, f, c, o in that order.
z0 = z[:, :self.units]
z1 = z[:, self.units: 2 * self.units]
z2 = z[:, 2 * self.units: 3 * self.units]
z3 = z[:, 3 * self.units:]
i = self.recurrent_activation(z0)
f = self.recurrent_activation(z1)
c = f * c_tm1 + i * self.activation(z2)
o = self.recurrent_activation(z3)
# New output: o applied to the activated carry state.
h = o * self.activation(c)
keras中将当前输入和上一时刻的输出分别进行矩阵乘法:kernel是输入的权重,recurrent_kernel是上一时刻输出的权重;同时可以确定LSTM中四个门的权重按列的排列顺序为 i、f、c、o
②将权重导出
# LSTM parameter extraction.
# get_weights() on a Keras LSTM layer returns [kernel, recurrent_kernel] and,
# only when the layer was created with use_bias=True, a third bias array.
# Indexing by position (instead of a fixed 3-way unpack) keeps this working
# for bias-free layers too.
lstm_weights = model.get_layer('lstm_1').get_weights()
weight_Dense_1 = lstm_weights[0]  # input kernel
weight_lstm2 = lstm_weights[1]  # recurrent kernel
bias_Dense_1 = lstm_weights[2] if len(lstm_weights) > 2 else None

# Both kernels hold the four gates side by side along the column axis.
# Split each into its four equal per-gate column groups, stack the groups
# along axis 0 and flatten to 1-D. Deriving the split from the array shape
# removes the hard-coded 32-unit slices and the fixed 2048/4096 sizes.
weight_Dense_1 = np.concatenate(
    np.split(weight_Dense_1, 4, axis=1), axis=0).reshape(-1)
weight_lstm2 = np.concatenate(
    np.split(weight_lstm2, 4, axis=1), axis=0).reshape(-1)

# Round to 4 decimal places for export.
win_para = np.around(weight_Dense_1, decimals=4)
win_para2 = np.around(weight_lstm2, decimals=4)
③将导出到txt中的权重数据复制到STM32工程中,用来初始化STM32中的权重数组,再利用arm_math中的矩阵乘法进行运算,得到与keras一致的正确结果
④补充:
注意STM32上的运算顺序必须与keras保持一致(四个门的权重顺序为 i、f、c、o)
版权声明:本文为CSDN博主「linxuplus」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/ZXCVBNNHU/article/details/121214495
暂无评论