查看: 13084|回复: 168|关注: 4

[我分享] 相关向量机 (Relevance Vector Machine, RVM) 训练和预测的实现

  [复制链接]

入门

166 麦片

财富积分


50500


9

主题

187

帖子

0

最佳答案
  • 关注者: 61
发表于 2018-11-12 08:19:42 | 显示全部楼层 |阅读模式
本帖最后由 iqiukp 于 2018-11-12 08:20 编辑

相关向量机 (Relevance Vector Machine, RVM) 在回归预测中的应用。

主要功能有:
(1)基于SB2_Release_200工具箱
(2)输出训练数据和测试数据的预测值
(3)输出相关向量的序号和对应的数值
(4)输出预测值的均值和方差(即分布)

核心函数:
(1)rvm_train
  1. function [model]= rvm_train(X,Y)
  2. % DESCRIPTION
  3. % Prediction based on Relevance Vector Machine (RVM)
  4. % Using SB2_Release_200 toolbox
  5. %
  6. % [model]= rvm_train(X,Y)
  7. %
  8. % INPUT
  9. % X Training samples (N*d)
  10. % N: number of samples
  11. % d: number of features
  12. % Y Target samples (N*1)
  13. %
  14. % OUTPUT
  15. % model RVM model
  16. %
  17. %


  18. % kernei width
  19. sigma = 5.5;


  20. L = size(X,1);
  21. %
  22. BASIS = [ones(L,1),computeKM(X,X,sigma)];
  23. %
  24. % SETTINGS = SB2_ParameterSettings('NoiseStd',0.1);
  25. OPTIONS = SB2_UserOptions('diagnosticLevel','medium','monitor',10, ...
  26. 'diagnosticFile', 'logfile.txt');
  27. [PARAMETER, HYPERPARAMETER, DIAGNOSTIC] = ...
  28. SparseBayes('Gaussian', BASIS, Y, OPTIONS);

  29. model.rv_index = PARAMETER.Relevant;
  30. model.rv_mu = PARAMETER.Value;
  31. model.width = sigma;
  32. model.X = X;
  33. model.beta = HYPERPARAMETER.beta;
  34. model.sigma = DIAGNOSTIC.Sigma;

  35. % mean of prediction (training samples)
  36. model.y_mu = BASIS(:,model.rv_index)*model.rv_mu;

  37. % variance of prediction (training samples)
  38. model.y_var = ones(L,1)*model.beta^-1+ ...
  39. diag(BASIS(:,model.rv_index)* ...
  40. model.sigma*BASIS(:,model.rv_index)');

  41. end
复制代码
(2)rvm_test
  1. function [y_mu,y_var] = rvm_test(model,X)
  2. % DESCRIPTION
  3. % Prediction based on Relevance Vector Machine (RVM)
  4. % Using SB2_Release_200 toolbox
  5. %
  6. % [y_mu,y_var] = rvm_test(model,X)
  7. %
  8. % INPUT
  9. % X Test samples (N*d)
  10. % N: number of samples
  11. % d: number of features
  12. % y_mu Mean of prediction
  13. % y_var Variance of prediction
  14. %
  15. % OUTPUT
  16. % model RVM model
  17. %
  18. %

  19. L = size(X,1);
  20. BASIS = [ones(L,1),computeKM(X,model.X,model.width)];

  21. % mean of prediction (test samples)
  22. y_mu = BASIS(:,model.rv_index)*model.rv_mu;

  23. % variance of prediction (test samples)
  24. y_var = ones(L,1)*model.beta^-1+ ...
  25. diag(BASIS(:,model.rv_index)* ...
  26. model.sigma*BASIS(:,model.rv_index)');

  27. end
复制代码
demo: sinc函数
  1. % demo
  2. clc
  3. clear all
  4. close all
  5. addpath(genpath(pwd))

  6. % sinc funciton
  7. fun = @(x) sin(abs(x))/abs(x);
  8. % training samples
  9. x = linspace(-10,10,100);
  10. y = arrayfun(fun,x);
  11. X = x';
  12. Y = y';
  13. % test samples
  14. xtest = linspace(-10,10,20);
  15. ytest = arrayfun(fun,xtest);
  16. Xtest = xtest';
  17. Ytest = ytest';

  18. % train RVM model
  19. [model]= rvm_train(X,Y);

  20. % test RVM model
  21. [y_mu,y_var] = rvm_test(model,Xtest);

  22. %
  23. ypre = y_mu;

  24. %
  25. figure
  26. plot(X,Y,':o','LineWidth',1.5,'MarkerSize',4)
  27. hold on
  28. plot(X,model.y_mu,':o','LineWidth',1.5,'MarkerSize',4)
  29. plot(X(model.rv_index),Y(model.rv_index),'go', ...
  30. 'LineWidth',1.5,'MarkerSize',8)

  31. legend('training samples','prediction','relevance Vectors')

  32. figure
  33. xs = (1:size(Xtest,1))';
  34. % 3σ
  35. f1 = [y_mu(:,1)+2*sqrt(y_var(:,1)); flip(y_mu(:,1)-2*sqrt(y_var(:,1)),1)];
  36. fill([xs; flip(xs,1)], f1, [7 7 7]/8)
  37. hold on
  38. plot(Ytest,'b:o','LineWidth',1.5,'MarkerSize',4)
  39. plot(y_mu,'r:o','LineWidth',1.5,'MarkerSize',4)

  40. xlabel('Samples')
  41. ylabel('Values')
  42. legend('3σ boundary','test samples','prediction')
复制代码
结果分别为:训练数据的预测结果和测试数据的预测结果
1.png 2.png



附件是源代码。

Relevance Vector Machine.zip (157.6 KB, 下载次数: 219)
回复主题 已获打赏: 0 积分

举报

入门

166 麦片

财富积分


50500


9

主题

187

帖子

0

最佳答案
  • 关注者: 61
 楼主| 发表于 2020-4-28 03:12:27 | 显示全部楼层
本帖最后由 iqiukp 于 2020-5-27 12:35 编辑

untitled.png



2020.4.28更新(V2.0)

源代码
RVM-MATLAB-V2.0.zip (176.72 KB, 下载次数: 66)
回复此楼 已获打赏: 20 积分

举报

新手

5 麦片

财富积分


050


1

主题

5

帖子

0

最佳答案
  • 关注者: 2
发表于 2019-1-16 18:35:44 | 显示全部楼层
楼主,为什么我用你的程序可以拟合数据,但是不能预测?
回复此楼 已获打赏: 0 积分

举报

新手

6 麦片

财富积分


050


0

主题

10

帖子

0

最佳答案
发表于 2019-2-28 16:42:08 | 显示全部楼层
非常感谢博主的分享,但是有个疑问麻烦帮忙解答:
我把sinc函数改为sin函数后,发现无法实现非线性拟合,拟合曲线偏差严重,另外相关向量曲线报错,这是什么原因?
希望能加博主的微信,可以发到我邮箱281188347@qq.com,谢谢!

sin(x)函数无法拟合

sin(x)函数无法拟合
回复此楼 已获打赏: 0 积分

举报

入门

166 麦片

财富积分


50500


9

主题

187

帖子

0

最佳答案
  • 关注者: 61
 楼主| 发表于 2019-2-28 17:57:22 | 显示全部楼层
本帖最后由 iqiukp 于 2019-2-28 18:00 编辑
xiaodong2019 发表于 2019-2-28 16:42
非常感谢博主的分享,但是有个疑问麻烦帮忙解答:
我把sinc函数改为sin函数后,发现无法实现非线性拟合,拟 ...

报错的原因:
1. 训练集中的相关向量数据点标记范围超过了样本个数,导致第一幅图的第三条曲线报错(即相关向量数据点),这个问题笔者以后再修复。
2. 可以直接绕过figure1,直接看figure2的拟合效果(测试集)

拟合性能差的原因:相关向量机的重要参数为核宽度,即rvm_train.m中的sigma:
  1. % kernei width
  2. sigma = 5.5;  
复制代码

提高拟合性能需要调整该参数,笔者试了试,在sigma等于3以下的时候,训练集和测试集的拟合效果为:
1.jpg 2.jpg


可以看出几乎完全重合了。




回复此楼 已获打赏: 0 积分

举报

新手

6 麦片

财富积分


050


0

主题

10

帖子

0

最佳答案
发表于 2019-2-28 23:30:23 | 显示全部楼层
iqiukp 发表于 2019-2-28 17:57
报错的原因:
1. 训练集中的相关向量数据点标记范围超过了样本个数,导致第一幅图的第三条曲线报错(即相 ...

非常感谢博主的解答,修改核宽度后,sin(x)拟合曲线确实与函数曲线一致。
    但是本人将训练函数替换为具体x,y数据后(如图1),在给定令一组x’数据后利用训练模型预测y'时得到的拟合曲线与训练曲线差别较大(如图2),且改变核宽度仍无效果,请问问题出在哪里?进一步,如果我的训练输入数据不止x一组,但只有一个y输出,请问如何开展训练和预测?
    如何可以的话,希望能加上博主的扣扣或微信,本人扣扣(281188347),谢谢!


12.JPG
11.JPG
回复此楼 已获打赏: 0 积分

举报

入门

166 麦片

财富积分


50500


9

主题

187

帖子

0

最佳答案
  • 关注者: 61
 楼主| 发表于 2019-3-4 20:29:51 | 显示全部楼层
xiaodong2019 发表于 2019-2-28 23:30
非常感谢博主的解答,修改核宽度后,sin(x)拟合曲线确实与函数曲线一致。
    但是本人将训练函数替换为 ...

1. 你把新的训练数据通过代码的方式贴上来,我这边试试。

2. 你的意思是多个自变量,一个因变量?这种情况的处理方法也是一样的,只要保证数据的行方向是样本,列方向是特征(属性)就行。部分情况可能还得先预处理数据。
回复此楼 已获打赏: 0 积分

举报

新手

6 麦片

财富积分


050


0

主题

10

帖子

0

最佳答案
发表于 2019-3-5 10:21:02 | 显示全部楼层
iqiukp 发表于 2019-3-4 20:29
1. 你把新的训练数据通过代码的方式贴上来,我这边试试。

2. 你的意思是多个自变量,一个因变量?这种情 ...

多谢博主解答,该问题原因已经找到,是输入数据未归一化,而且核宽度仍然需要小于3以下。现在存在两个 疑惑:
一是整个训练模型是不是只有核宽度一个参数需要适时调整大小,还有没有其他需要参数,如何影响拟合结果;二是该语句(plot(X(model.rv_index),Y(model.rv_index),'go',  'LineWidth',1.5,'MarkerSize',8))除了博主给出的sinc函数外一直报错(“索引超出矩阵维度”),例如换成sin函数,该语句是用来画出相关性好的训练数据的曲线么?为什么只有几个点?能否帮忙给出解决方案,谢谢!
回复此楼 已获打赏: 0 积分

举报

入门

166 麦片

财富积分


50500


9

主题

187

帖子

0

最佳答案
  • 关注者: 61
 楼主| 发表于 2019-3-5 11:33:49 | 显示全部楼层
本帖最后由 iqiukp 于 2019-3-5 11:35 编辑
xiaodong2019 发表于 2019-3-5 10:21
多谢博主解答,该问题原因已经找到,是输入数据未归一化,而且核宽度仍然需要小于3以下。现在存在两个 疑 ...

1. 在SB2_Release_200工具箱里面还有一个(迭代次数)可以设置。
rvm_train.m文件里面可以看到有个UserOptions, 工具箱默认是10000次(可以看工具箱的SB2_UserOptions.m文件)
  1. % SETTINGS = SB2_ParameterSettings('NoiseStd',0.1);
  2. OPTIONS = SB2_UserOptions('diagnosticLevel','medium','monitor',10, ...
  3.     'diagnosticFile', 'logfile.txt');
  4. [PARAMETER, HYPERPARAMETER, DIAGNOSTIC] = ...
  5.     SparseBayes('Gaussian', BASIS, Y, OPTIONS);
复制代码
可以尝试把上面的OPTIONS 改为:
  1. OPTIONS = SB2_UserOptions('Iterations',10000, 'diagnosticLevel','medium', 'monitor',10, ...
  2.     'diagnosticFile', 'logfile.txt');
复制代码
10000是迭代次数,可以根据你的需求改成其他数字,不过RVM的默认精度已经够高了,再修改这个参数意义不是很大,重点还是核宽度那个参数。

除此之外,最基本的RVM算法应该就没有可调的参数了。你可以关注一下近年来针对RVM原理层面的改进算法,或许有可以参考的地方。

2. 你可以先看看 ‘model’ 这个变量的意思。
  1. % SETTINGS = SB2_ParameterSettings('NoiseStd',0.1);
  2. OPTIONS = SB2_UserOptions('diagnosticLevel','medium','monitor',10, ...
  3.     'diagnosticFile', 'logfile.txt');
  4. [PARAMETER, HYPERPARAMETER, DIAGNOSTIC] = ...
  5.     SparseBayes('Gaussian', BASIS, Y, OPTIONS);

  6. % Locate = find(PARAMETER.Relevant == L+1);
  7. % PARAMETER.Relevant(Locate)=[];
  8. % PARAMETER.Value(Locate)=[];

  9. model.rv_index = PARAMETER.Relevant;
  10. model.rv_mu = PARAMETER.Value;
  11. model.width = sigma;
  12. model.X = X;
  13. model.beta = HYPERPARAMETER.beta;
  14. model.sigma = DIAGNOSTIC.Sigma;
复制代码
可以看到,model是一个结构体,整合了RVM模型的一些参数。其中 model.rv_index 是训练集里面相关向量的下标(实际上也是训练集里面的样本)。
  1. plot(X(model.rv_index),Y(model.rv_index),'go','LineWidth',1.5,'MarkerSize',8)
复制代码
实际上该语句的功能就是想把训练集中那些属于“相关向量”的样本给标记出来,因此只有少数几个点,稀疏性好。另外,报错的那个原因已经找到,但目前暂时没有找到解决的方法。
回复此楼 已获打赏: 0 积分

举报

新手

11 麦片

财富积分


050


0

主题

26

帖子

0

最佳答案
发表于 2019-3-8 14:43:12 | 显示全部楼层
请教楼主:
如果用DE来优化楼主的RVM算法,只能优化核参数sigma吗?一般的论文文献中优化的是至少两个参数:α、σ。如果要优化楼主的算法,适应度函数如何构造?
回复此楼 已获打赏: 0 积分

举报

入门

166 麦片

财富积分


50500


9

主题

187

帖子

0

最佳答案
  • 关注者: 61
 楼主| 发表于 2019-5-11 15:32:42 | 显示全部楼层
本帖最后由 iqiukp 于 2019-5-11 22:31 编辑

2019.5.11 更新

Relevance Vector Machine (RVM).zip (1.24 MB, 下载次数: 128)
回复此楼 已获打赏: 0 积分

举报

您需要登录后才可以回帖 登录 | 注册

本版积分规则

关闭

站长推荐上一条 /4 下一条

快速回复 返回顶部 返回列表