首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

《统计学习方法》朴素贝叶斯实例

本文旨在Python案例实现方法,对于公式推导可参考《统计学习方法》第4章,

同时实例也来源于该书,希望Python实例能给正在看此书的小伙伴一些参考;

第1步:加载数据集

def createdata(self):

train_data=np.array([[1,'S',-1],[1,"M",-1],[1,"M",1,],[1,'S',1],[1,'S',-1],[2,'S',-1],[2,"M",-1],[2,"M",1],[2,"L",1],[2,"L",1],[3,"L",1],[3,"M",1],[3,"M",1],[3,"L",1],[3,"L",-1]])

test_data=np.array([3,"M"])

return train_data[:,:2],train_data[:,-1],test_data,train_data

第2步:计算label频次及概率

def label_p(self,label):

label_dict={}

label_pro={}

label_number=0

for labelcontent in label:

label_number+=1

if labelcontent in label_dict:

label_dict[labelcontent]+=1

else:

label_dict[labelcontent]=1

for k,v in label_dict.items():

label_pro[k]=v/label_number

return label_dict,label_pro

第3步:计算训练集条件概率

def train_data_con_pro(self,train_data,k,k1,label):

return train_data[np.where(train_data[np.where(train_data[:,-1]==label)][:,k]==k1)].shape[0]/train_data[np.where(train_data[:,-1]==label)].shape[0]

def con_pro(self,train_data):

pro_result={}

train_x1_count,train_x1_set=self.label_p(train_data[:,0])

train_x2_count,train_x2_set=self.label_p(train_data[:,1])

label_count,label_pro=self.label_p(self.label)

for k,v in label_count.items():

for k1,v1 in train_x1_count.items():

key="P(X1=%s|Y=%s)"%(k1,k)

value=self.train_data_con_pro(train_data,0,k1,k)

pro_result[key]=value

for k2,v2 in train_x2_count.items():

key="P(X2=%s|Y=%s)"%(k2,k)

value=self.train_data_con_pro(train_data,1,k2,k)

pro_result[key]=value

key="P(Y=%s)"%k

pro_result[key]=label_pro[k]

return pro_result

第4步:计算测试数据集概率及结果

def train_test(self,test_data):

test_pro={}

pro_result=self.con_pro(self.train)

label_count, label_pro = self.label_p(self.label)

for k,v in label_count.items():

value=pro_result["P(Y=%s)"%k]*pro_result["P(X1=%s|Y=%s)"%(test_data[0],k)]*pro_result["P(X2=%s|Y=%s)"%(test_data[1],k)]

test_pro[k]=value

return sorted(test_pro.items(),reverse=True,key=lambda x: x[1])[0][0]

第5步:封装

def main(self):

testlabel=(self.train_test(self.test))

return testlabel

全部代码如下所示:

# Author:随心

import numpy as np

class classifierNB:

def __init__(self):

self.train_data,self.label,self.test,self.train=self.createdata()

def createdata(self):

train_data=np.array([[1,'S',-1],[1,"M",-1],[1,"M",1,],[1,'S',1],[1,'S',-1],[2,'S',-1],[2,"M",-1],[2,"M",1],[2,"L",1],[2,"L",1],[3,"L",1],[3,"M",1],[3,"M",1],[3,"L",1],[3,"L",-1]])

test_data=np.array([3,"M"])

return train_data[:,:2],train_data[:,-1],test_data,train_data

def label_p(self,label):

label_dict={}

label_pro={}

label_number=0

for labelcontent in label:

label_number+=1

if labelcontent in label_dict:

label_dict[labelcontent]+=1

else:

label_dict[labelcontent]=1

for k,v in label_dict.items():

label_pro[k]=v/label_number

return label_dict,label_pro

def train_data_con_pro(self,train_data,k,k1,label):

return train_data[np.where(train_data[np.where(train_data[:,-1]==label)][:,k]==k1)].shape[0]/train_data[np.where(train_data[:,-1]==label)].shape[0]

def con_pro(self,train_data):

pro_result={}

train_x1_count,train_x1_set=self.label_p(train_data[:,0])

train_x2_count,train_x2_set=self.label_p(train_data[:,1])

label_count,label_pro=self.label_p(self.label)

for k,v in label_count.items():

for k1,v1 in train_x1_count.items():

key="P(X1=%s|Y=%s)"%(k1,k)

value=self.train_data_con_pro(train_data,0,k1,k)

pro_result[key]=value

for k2,v2 in train_x2_count.items():

key="P(X2=%s|Y=%s)"%(k2,k)

value=self.train_data_con_pro(train_data,1,k2,k)

pro_result[key]=value

key="P(Y=%s)"%k

pro_result[key]=label_pro[k]

return pro_result

def train_test(self,test_data):

test_pro={}

pro_result=self.con_pro(self.train)

label_count, label_pro = self.label_p(self.label)

for k,v in label_count.items():

value=pro_result["P(Y=%s)"%k]*pro_result["P(X1=%s|Y=%s)"%(test_data[0],k)]*pro_result["P(X2=%s|Y=%s)"%(test_data[1],k)]

test_pro[k]=value

return sorted(test_pro.items(),reverse=True,key=lambda x: x[1])[0][0]

def main(self):

testlabel=(self.train_test(self.test))

return testlabel

if __name__=="__main__":

test=classifierNB()

print(test.main())

如果大家对此有疑问,可以给小编留言或者加群咨询

欢迎加群一起讨论学习~

加入我们,一起学习,一起成长,一起吹牛逼~

论坛地址:

Excel 相关QQ群:

Python QQ群:

微信公众号:

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180331G15BI100?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券