基于分位数长短期记忆网络(QRLSTM)的数据回归预测 matlab代码,要求2018及以上版本

分位数长短期记忆网络(QRLSTM)在数据预测任务中越来越受欢迎,特别是需要评估预测不确定性的场景。和普通LSTM只会预测一个值不同,这哥们能同时输出多个分位数的预测结果,相当于自带置信区间分析功能。咱们用MATLAB实操一把,先准备个正弦波加噪声的示例数据:

% 数据生成
x = linspace(0, 10, 500)';
y = sin(x*2) + 0.2*randn(size(x));
t = 0.1:0.1:10; % 时间序列

% 数据标准化
y_normalized = (y - mean(y))/std(y);

这里故意加了点高斯噪声,模拟真实场景的数据波动。标准化操作是常规操作了,防止梯度爆炸。接下来构建网络结构是关键,注意这里用了双LSTM层设计:

numFeatures = 1;
numHiddenUnits = 32;
quantiles = [0.1, 0.5, 0.9];  % 10%、50%、90%分位数

layers = [...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    dropoutLayer(0.2)
    lstmLayer(numHiddenUnits/2)
    fullyConnectedLayer(length(quantiles))
    quantileLossLayer(quantiles)];

第二层LSTM神经元数减半是个经验值,防止过拟合。重点是这个quantileLossLayer需要自定义实现:

classdef quantileLossLayer < nnet.layer.Layer
    properties
        quantiles
    end
    
    methods
        function layer = quantileLossLayer(quantiles)
            layer.quantiles = quantiles;
            layer.Name = 'QuantileLoss';
        end
        
        function loss = forwardLoss(layer, Y, T)
            residuals = T - Y;
            loss = mean(max(layer.quantiles'.*residuals, (layer.quantiles'-1).*residuals), [1 3]);
        end
    end
end

这个损失函数实现得挺有意思——当预测值高于真实值时,用quantile系数惩罚;低于时用(1-quantile)惩罚。比如对于0.9分位数,更倾向于让预测值偏高而不是偏低。

基于分位数长短期记忆网络(QRLSTM)的数据回归预测 matlab代码,要求2018及以上版本

训练参数设置有个小技巧,初始学习率设高点然后逐步衰减:

options = trainingOptions('adam', ...
    'MaxEpochs', 150, ...
    'InitialLearnRate',0.1, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',50, ...
    'Verbose',0);

实测发现这样收敛更快。最后预测阶段要注意反标准化:

YPred = predict(net, XTest);
YPred = YPred * std(y) + mean(y);

跑出来的效果如图所示,三条分位数曲线像三明治一样包着真实数据波动。中间那条50%分位线基本贴合正弦波主线,上下两条形成预测区间。有意思的是在波峰波谷处,预测区间会自动变宽——这说明模型确实学到了序列变化的规律性。

有个坑得提醒:处理长序列时最好用GPU加速,不然等个把小时是常事。另外分位数设置不宜太多,否则网络容易精神分裂。实际项目中可以先用0.5分位做基准,再根据需求添加两侧分位点。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐