Implementing a Keras LSTM network on STM32

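First, train the model in Keras to obtain the LSTM weights (and biases, if used). It is recommended here to disable the LSTM bias option (use_bias=False), which makes it easier to compare the effect of changes after operations such as pruning.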

① Examine how Keras computes the LSTM (excerpt from the build() and call() methods of the Keras LSTMCell source)

        self.kernel_i = self.kernel[:, :self.units]
        self.kernel_f = self.kernel[:, self.units: self.units * 2]
        self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
        self.kernel_o = self.kernel[:, self.units * 3:]

        self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
        self.recurrent_kernel_f = self.recurrent_kernel[:, self.units: self.units * 2]
        self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2: self.units * 3]
        self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]

        if self.use_bias:
            self.bias_i = self.bias[:self.units]
            self.bias_f = self.bias[self.units: self.units * 2]
            self.bias_c = self.bias[self.units * 2: self.units * 3]
            self.bias_o = self.bias[self.units * 3:]
        else:
            self.bias_i = None
            self.bias_f = None
            self.bias_c = None
            self.bias_o = None
        self.built = True

    def call(self, inputs, states, training=None):
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(states[0]),
                self.recurrent_dropout,
                training=training,
                count=4)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        if self.implementation == 1:
            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
            x_i = K.dot(inputs_i, self.kernel_i)
            x_f = K.dot(inputs_f, self.kernel_f)
            x_c = K.dot(inputs_c, self.kernel_c)
            x_o = K.dot(inputs_o, self.kernel_o)
            if self.use_bias:
                x_i = K.bias_add(x_i, self.bias_i)
                x_f = K.bias_add(x_f, self.bias_f)
                x_c = K.bias_add(x_c, self.bias_c)
                x_o = K.bias_add(x_o, self.bias_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
                                                      self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
                                                      self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
                                                            self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
                                                      self.recurrent_kernel_o))
        else:
            if 0. < self.dropout < 1.:
                inputs *= dp_mask[0]
            z = K.dot(inputs, self.kernel)
            if 0. < self.recurrent_dropout < 1.:
                h_tm1 *= rec_dp_mask[0]
            z += K.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units:]

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)

        h = o * self.activation(c)

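Keras performs separate matrix multiplications for the input and for the output of the previous time step: kernel holds the weights applied to the input, and recurrent_kernel holds the weights applied to the previous output h_tm1. The slicing also shows that the weights of the four gates are ordered i, f, c, o along the last axis.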
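To make the computation concrete, here is a minimal pure-Python sketch of one time step in the order used by implementation 1 above (gates i, f, c, o; separate kernel and recurrent_kernel products, optional bias). It uses plain loops instead of NumPy so that it maps one-to-one onto loops or arm_math calls on the STM32. The function and parameter names are illustrative and not part of Keras. Assumption: the recurrent activation is sigmoid; check your model's recurrent_activation setting, since some older Keras versions default to 'hard_sigmoid'.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec(vec, mat, col_start, col_end):
    """Row vector `vec` times columns [col_start, col_end) of `mat` (a list of rows)."""
    return [sum(vec[r] * mat[r][c] for r in range(len(vec)))
            for c in range(col_start, col_end)]


def lstm_step(x, h_prev, c_prev, kernel, recurrent_kernel, units,
              bias=None, recurrent_activation=sigmoid, activation=math.tanh):
    """One LSTM time step in Keras gate order i, f, c, o.

    kernel: input_dim x (4 * units), recurrent_kernel: units x (4 * units).
    Assumption: sigmoid recurrent activation; pass hard_sigmoid if the model used it.
    """
    def gate(k):
        lo, hi = k * units, (k + 1) * units
        # x . kernel_k + h_prev . recurrent_kernel_k, as in implementation 1
        z = [a + b for a, b in zip(matvec(x, kernel, lo, hi),
                                   matvec(h_prev, recurrent_kernel, lo, hi))]
        if bias is not None:
            z = [zj + bias[lo + j] for j, zj in enumerate(z)]
        return z

    i = [recurrent_activation(v) for v in gate(0)]
    f = [recurrent_activation(v) for v in gate(1)]
    g = [activation(v) for v in gate(2)]  # candidate values (the "c" block)
    o = [recurrent_activation(v) for v in gate(3)]
    c = [f[j] * c_prev[j] + i[j] * g[j] for j in range(units)]
    h = [o[j] * activation(c[j]) for j in range(units)]
    return h, c
```

Running this on the exported weights and comparing it with model.predict on the same input is a convenient way to confirm the gate order before porting anything to C.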

② Export the weights

import numpy as np

# LSTM parameter extraction (`model` is the trained Keras model)
weights = model.get_layer('lstm_1').get_weights()
kernel = weights[0]            # input weights, shape (input_dim, 4 * units)
recurrent_kernel = weights[1]  # recurrent weights, shape (units, 4 * units)
bias = weights[2] if len(weights) == 3 else None  # only returned when use_bias=True
units = recurrent_kernel.shape[0]  # 32 in this example

def stack_gates(w):
    # Split the columns into the i, f, c, o blocks, stack the blocks
    # vertically (axis=0) and flatten row-major
    blocks = [w[:, k * units:(k + 1) * units] for k in range(4)]
    return np.concatenate(blocks, axis=0).reshape(-1)

weight_Dense_1 = stack_gates(kernel)          # 2048 values here (16 x 128)
weight_lstm2 = stack_gates(recurrent_kernel)  # 4096 values here (32 x 128)
win_para = np.around(weight_Dense_1, decimals=4)
win_para2 = np.around(weight_lstm2, decimals=4)
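The layout produced by the export step can be checked with a small pure-Python sketch; flatten_gates and flat_index are illustrative helper names, not from the original code. Because the four gate blocks are stacked along axis=0 before flattening, each gate's weights end up as one contiguous, row-major block of rows x units values.

```python
def flatten_gates(kernel, units):
    """Mirror the export step: split the columns of `kernel` (rows x 4*units)
    into the i, f, c, o blocks, stack the blocks vertically, flatten row-major."""
    flat = []
    for k in range(4):          # gate order i, f, c, o
        for row in kernel:      # every block keeps all rows of the kernel
            flat.extend(row[k * units:(k + 1) * units])
    return flat


def flat_index(gate, row, col, rows, units):
    """Position of element (row, col) of gate block `gate` in the flat array."""
    return gate * rows * units + row * units + col
```

With the sizes in this example (2048 = 16 x 4 x 32 input weights), the four input-weight blocks start at offsets 0, 512, 1024 and 1536, and each can be treated as a 16 x 32 row-major matrix on the STM32 side, which matches the row-major storage used by arm_math matrix instances.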

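③ Copy the weight data from the txt file into the STM32 project to initialize the weight arrays there, then carry out the computation with the matrix multiplication functions in arm_math (CMSIS-DSP); the results should match Keras.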
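Instead of copying numbers out of a txt file by hand, the rounded arrays can be written out directly as C source. A minimal sketch: to_c_array is a hypothetical helper, and float32_t is assumed to be the typedef provided by arm_math.h. The array is emitted without const so it can be passed to functions taking a plain float32_t pointer.

```python
def to_c_array(name, values, decimals=4):
    """Format a flat list of floats as a C array definition for the STM32 project."""
    # Same precision as np.around(..., decimals=4) in the export step
    body = ", ".join(f"{v:.{decimals}f}f" for v in values)
    return f"float32_t {name}[{len(values)}] = {{{body}}};"
```

For example, print(to_c_array("lstm_kernel", win_para.tolist())) prints a definition that can be pasted into a .c file; the array name here is only an illustration.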

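④ Notes:

Make sure the STM32 code performs the computation in the same order as Keras (gate order i, f, c, o, with the same weight layout).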
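Besides the order of operations, the activation functions must also match, and because the weights were rounded to 4 decimals the STM32 output will not match Keras bit for bit, so it is safer to compare with a tolerance. A minimal sketch; hard_sigmoid follows the Keras 2.x definition clip(0.2 * x + 0.5, 0, 1), and the helper names are illustrative.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def hard_sigmoid(x):
    # Keras 2.x piecewise-linear approximation: clip(0.2 * x + 0.5, 0, 1)
    return min(1.0, max(0.0, 0.2 * x + 0.5))


def max_abs_diff(a, b):
    """Largest element-wise gap between two output vectors of equal length."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    return max(abs(p - q) for p, q in zip(a, b))
```

sigmoid(1.0) and hard_sigmoid(1.0) already differ by about 0.03, so using the wrong one is easy to spot. A check such as max_abs_diff(keras_out, stm32_out) < 1e-3, with both vectors and the tolerance being assumptions to adapt, then catches layout or ordering mistakes.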

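Copyright notice: This is an original article by CSDN blogger "linxuplus", licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.
Original link: https://blog.csdn.net/ZXCVBNNHU/article/details/121214495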
