最小二乘法
在研究两个变量(x, y)之间的相互关系时
通常可以得到一系列成对的数据(x1, y1),(x2, y2)… (xm , ym)
将这些数据描绘在x-y直角坐标系中
若发现这些点在一条直线附近
可以令这条直线方程y= e + wx
其中:w\e是任意实数
为建立这直线方程就要确定e和w
应用《最小二乘法原理》
将实测值Yi与利用计算y= e + wx值的离差(yi-y)的平方和
即〔∑(yi - y)²〕最小
简单来说就是以下公式
y = a x + b
b = sum( y ) / n - a * sum( x ) / n
a = ( n * sum( xy ) - sum( x* ) * sum( y ) ) / ( n * sum( x^2 ) - sum(x) ^ 2 )
一个预测问题在回归模型下的解决步骤为:
1.构造训练集;
2.学习,得到输入输出间的关系;
3.预测,通过学习得到的关系预测输出
代码实现
你看,代码风格依旧良好
中间用到了Double类型的数据运算
而Double类型的数据直接加减乘除是有可能有问题的
所以附上了Double数据运算的常用方法
/**
* 使用最小二乘法实现线性回归预测
*
* @author daijiyong
*/
public class LinearRegression {
/**
* 训练集数据
*/
private Map<Double, Double> initData = new HashMap<>();
/**
* 截距
*/
private double intercept = 0.0;
//斜率
private double slope = 0.0;
/**
* x、y平均值
*/
private double averageX, averageY;
/**
* 求斜率的上下两个分式的值
*/
private double slopeUp, slopeDown;
public LinearRegression(Map<Double, Double> initData) {
this.initData = initData;
initData();
}
public LinearRegression() {
}
/**
* 根据训练集数据进行训练预测
* 并计算斜率和截距
*/
public void initData() {
if (initData.size() > 0) {
//数据个数
int number = 0;
//x值、y值总和
double sumX = 0;
double sumY = 0;
averageX = 0;
averageY = 0;
slopeUp = 0;
slopeDown = 0;
for (Double x : initData.keySet()) {
if (x == null || initData.get(x) == null) {
continue;
}
number++;
sumX = add(sumX, x);
sumY = add(sumY, initData.get(x));
}
//求x,y平均值
averageX = DoubleUtils.div(sumX, (double) number);
averageY = DoubleUtils.div(sumY, (double) number);
for (Double x : initData.keySet()) {
if (x == null || initData.get(x) == null) {
continue;
}
slopeUp = add(slopeUp, mul(sub(x, averageX), sub(initData.get(x), averageY)));
slopeDown = add(slopeDown, mul(sub(x, averageX), sub(x, averageX)));
}
initSlopeIntercept();
}
}
/**
* 计算斜率和截距
*/
private void initSlopeIntercept() {
if (slopeUp != 0 && slopeDown != 0) {
slope = slopeUp / slopeDown;
}
intercept = averageY - averageX * slope;
}
/**
* 根据x值预测y值
*
* @param x x值
* @return y值
*/
public Double getY(Double x) {
return add(intercept, mul(slope, x));
}
/**
* 根据y值预测x值
*
* @param y y值
* @return x值
*/
public Double getX(Double y) {
return div(sub(y, intercept), slope);
}
public Map<Double, Double> getInitData() {
return initData;
}
public void setInitData(Map<Double, Double> initData) {
this.initData = initData;
}
public static void main(String[] args) {
LinearRegression linearRegression = new LinearRegression();
//训练集数据
linearRegression.getInitData().put(1D, 8D);
linearRegression.getInitData().put(1.5D, 9.5D);
linearRegression.getInitData().put(2D, 11D);
linearRegression.getInitData().put(2.5D, 10D);
linearRegression.getInitData().put(3D, 14D);
//根据训练集数据进行线性函数预测
linearRegression.initData();
/*
* 给定x值,预测y值
*/
System.out.println(linearRegression.getY(8D));
/*
* 给定y值,预测x值
*/
System.out.println(linearRegression.getX(9.5D));
}
}
/**
* Created by daijiyong on 2017/4/6.
*/
public class DoubleUtils {
private static final int DEF_DIV_SCALE = 10;
/**
* * 两个Double数相加 *
*
* @param v1 *
* @param v2 *
* @return Double
*/
public static Double add(Double v1, Double v2) {
BigDecimal b1 = new BigDecimal(v1.toString());
BigDecimal b2 = new BigDecimal(v2.toString());
return b1.add(b2).doubleValue();
}
/**
* * 两个Double数相减 *
*
* @param v1 *
* @param v2 *
* @return Double
*/
public static Double sub(Double v1, Double v2) {
BigDecimal b1 = new BigDecimal(v1.toString());
BigDecimal b2 = new BigDecimal(v2.toString());
return b1.subtract(b2).doubleValue();
}
/**
* * 两个Double数相乘 *
*
* @param v1 *
* @param v2 *
* @return Double
*/
public static Double mul(Double v1, Double v2) {
BigDecimal b1 = new BigDecimal(v1.toString());
BigDecimal b2 = new BigDecimal(v2.toString());
return b1.multiply(b2).doubleValue();
}
/**
* * 两个Double数相除 *
*
* @param v1 *
* @param v2 *
* @return Double
*/
public static Double div(Double v1, Double v2) {
BigDecimal b1 = new BigDecimal(v1.toString());
BigDecimal b2 = new BigDecimal(v2.toString());
return b1.divide(b2, DEF_DIV_SCALE, BigDecimal.ROUND_HALF_UP).doubleValue();
}
/**
* * 两个Double数相除,并保留scale位小数 *
*
* @param v1 *
* @param v2 *
* @param scale *
* @return Double
*/
public static Double div(Double v1, Double v2, int scale) {
if (scale < 0) {
throw new IllegalArgumentException(
"The scale must be a positive integer or zero");
}
BigDecimal b1 = new BigDecimal(v1.toString());
BigDecimal b2 = new BigDecimal(v2.toString());
return b1.divide(b2, scale, BigDecimal.ROUND_HALF_UP).doubleValue();
}
public static int max(int a, int b) {
return Math.max(a, b);
}
public static int min(int a, int b) {
return Math.min(a, b);
}
运行测试
给个例子,测试一下
文/戴先生@2020年6月8日---end---
更多精彩推荐