1. Prepare the Dataset
The dataset is the breast cancer dataset from sklearn: 569 samples, each with 30 features. Before training, the features are scaled to the [0, 1] interval with MinMax normalization, and the data is split 75/25 into a training set and a validation set.
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

# Load the breast cancer dataset and scale every feature to [0, 1]
breast = datasets.load_breast_cancer()
scaler = preprocessing.MinMaxScaler()
data = scaler.fit_transform(breast['data'])
target = breast['target']

# 75/25 train/validation split (0.25 is sklearn's default test_size)
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.25)
2. Model Structure Diagram
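Logistic regression here is the simplest possible network: a 30-dimensional input x feeds a single linear unit z = wᵀx + b, a sigmoid activation a = σ(z) maps z to the predicted probability of the positive class, and the binary cross-entropy J serves as the loss; this is exactly the pipeline implemented in section 4.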
3. Forward and Backward Propagation Formulas
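With m training samples, a feature matrix X of shape (n, m), labels Y of shape (1, m), a weight vector w of shape (n, 1), and a scalar bias b, the forward pass X → Z → A → J computes

$$Z = w^{T}X + b,\qquad A = \sigma(Z) = \frac{1}{1+e^{-Z}},\qquad J = \frac{1}{m}\sum_{i=1}^{m}\Bigl[-y^{(i)}\log a^{(i)} - \bigl(1-y^{(i)}\bigr)\log\bigl(1-a^{(i)}\bigr)\Bigr]$$

and the backward pass J → dA → dZ → (dw, db) propagates the gradients

$$\frac{\partial J}{\partial A} = \frac{1}{m}\Bigl(-\frac{Y}{A} + \frac{1-Y}{1-A}\Bigr),\qquad \frac{\partial J}{\partial Z} = \frac{1}{m}(A - Y),\qquad \frac{\partial J}{\partial w} = X\Bigl(\frac{\partial J}{\partial Z}\Bigr)^{T},\qquad \frac{\partial J}{\partial b} = \sum \frac{\partial J}{\partial Z}$$

Gradient descent then updates w ← w − α ∂J/∂w and b ← b − α ∂J/∂b with learning rate α. These are exactly the quantities named Z, A, J, dA, dZ, dw, and db in the code below.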
4. LR Implementation Code
import numpy as np
import pandas as pd

class LogisticRegression(object):
    def __init__(self, alpha=0.1, ITERNUM=200000):
        self.alpha, self.ITERNUM = alpha, ITERNUM
        # record the loss J at every iteration for later plotting
        self.dfJ = pd.DataFrame(data=np.zeros((ITERNUM, 1)), columns=['J'])
        self.w, self.b = np.nan, np.nan

    def fit(self, X_train, y_train):
        # X: (n_features, m_samples), Y: (1, m_samples)
        X, Y = X_train.T, y_train.reshape(1, -1)
        n, m = X.shape
        w, b = np.zeros((n, 1)), 0
        for i in range(self.ITERNUM):
            # forward pass: X --> Z --> A --> J
            Z = np.dot(w.T, X) + b
            A = 1 / (1 + np.exp(-Z))
            J = (1 / m) * np.sum(-Y * np.log(A) - (1 - Y) * np.log(1 - A))
            self.dfJ.loc[i, 'J'] = J  # .loc[i, 'J'] avoids pandas chained assignment
            # backward pass: J --> dA --> dZ --> (dw, db)
            dA = 1 / m * (-Y / A + (1 - Y) / (1 - A))
            # dZ = dA * A * (1 - A) simplifies to (A - Y) / m
            dZ = 1 / m * (A - Y)
            dw = np.dot(X, dZ.T)
            db = np.sum(dZ)
            # gradient descent update
            w = w - self.alpha * dw
            b = b - self.alpha * db
        self.w, self.b = w, b

    def predict_prob(self, X_test):
        # predicted probability of the positive class
        Z_test = np.dot(self.w.T, X_test.T) + self.b
        Y_prob = 1 / (1 + np.exp(-Z_test))
        return Y_prob.reshape(-1)

    def predict(self, X_test):
        # threshold probabilities at 0.5 to get hard labels
        Y_prob = self.predict_prob(X_test)
        Y_test = Y_prob.copy()
        Y_test[Y_prob >= 0.5] = 1
        Y_test[Y_prob < 0.5] = 0
        return Y_test
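To verify that the backward pass matches the formulas in section 3, a quick numerical gradient check can be run on a tiny random problem. This is a minimal sketch, not part of the original post; loss_at and the toy shapes are made up for illustration:

import numpy as np

def loss_at(w, b, X, Y):
    # forward pass only: cross-entropy loss at parameters (w, b)
    A = 1 / (1 + np.exp(-(np.dot(w.T, X) + b)))
    m = X.shape[1]
    return (1 / m) * np.sum(-Y * np.log(A) - (1 - Y) * np.log(1 - A))

rng = np.random.RandomState(0)
X = rng.rand(3, 5)                 # 3 features, 5 samples
Y = rng.randint(0, 2, (1, 5))
w, b = rng.randn(3, 1), 0.1

# analytic gradient, same formulas as in fit()
A = 1 / (1 + np.exp(-(np.dot(w.T, X) + b)))
dZ = (A - Y) / X.shape[1]
dw = np.dot(X, dZ.T)

# centered finite difference for the first weight
eps = 1e-6
w_plus, w_minus = w.copy(), w.copy()
w_plus[0, 0] += eps
w_minus[0, 0] -= eps
dw0_num = (loss_at(w_plus, b, X, Y) - loss_at(w_minus, b, X, Y)) / (2 * eps)
print(dw[0, 0], dw0_num)           # the two values should closely agree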
5. Testing on the Dataset
# Feed the model the training data
clf = LogisticRegression(alpha=0.1, ITERNUM=200000)
clf.fit(X_train=X_train, y_train=y_train)

# Plot the loss J against the iteration number (Jupyter magic for inline plots)
%matplotlib inline
clf.dfJ.plot(y='J', kind='line', figsize=(10, 7))
# AUC score on the validation set
from sklearn.metrics import roc_auc_score
Y_prob = clf.predict_prob(X_test)
roc_auc_score(y_test, Y_prob)
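Besides AUC, a quick extra check (not in the original post) of hard-label accuracy using the predict method:

from sklearn.metrics import accuracy_score
accuracy_score(y_test, clf.predict(X_test))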
# Compare against sklearn's built-in model
from sklearn.linear_model import LogisticRegressionCV as LRCV
lr = LRCV()
lr.fit(X_train, y_train)
Y_proba = lr.predict_proba(X_test)
roc_auc_score(y_test, Y_proba[:, 1])