本帖最后由 pengyao1207 于 2017-12-7 20:09 编辑
偶然看到这么一个帖子,有些小激动,传说中的机器学习。。。。
朴素贝叶斯文本分类算法学习:www.chepoo.com/naive-bayesian-text-classification-algorithm-to-learn.html
尝试着把帖子里面的伯努利模型用 C++ 实现了一下,发帖以共同学习。
代码如下
[C++] 纯文本查看 复制代码
// Bayes.cpp: 定义控制台应用程序的入口点。
//
#include<cmath>
#include<iostream>
#include<map>
#include<set>
#include<string>
#include<vector>
using namespace std;
/// Two-class Bernoulli naive Bayes text classifier.
///
/// Training records, per class, in how many documents each word occurs at
/// least once. Classification walks the whole learned vocabulary: a word that
/// is present in the document contributes P(word | class), an absent word
/// contributes 1 - P(word | class). Probabilities use add-one smoothing:
/// (documents containing the word + 1) / (documents in class + 2).
class Bayes
{
public:
    Bayes() : yes(0), no(0) {}

    /// Adds one labelled training document.
    /// @param data  words of the document; a repeated word is counted once.
    /// @param bj    class label: true trains the "yes" class (1), false the "no" class (0).
    /// @return always 0.
    int input(const std::vector<std::string>& data, bool bj)
    {
        int& docs = bj ? yes : no;
        std::map<std::string, int>& counts = bj ? sjyes : sjno;
        ++docs;

        // Bernoulli model: only presence matters, so collapse duplicates first.
        const std::set<std::string> words(data.begin(), data.end());
        for (const std::string& word : words)
        {
            ++counts[word];
        }
        return 0;
    }

    /// Classifies a document.
    /// @param data  words of the document; duplicates and words outside the
    ///              learned vocabulary do not affect the result.
    /// @return 1 if the "yes" class scores strictly higher, otherwise 0
    ///         (ties and an untrained classifier yield 0).
    int output(const std::vector<std::string>& data)
    {
        const int total = yes + no;
        if (total == 0)
        {
            // Without training data the priors would be 0/0; report class 0.
            return 0;
        }

        const std::set<std::string> present(data.begin(), data.end());

        // Vocabulary = every word seen in either class during training.
        std::set<std::string> vocabulary;
        for (const auto& entry : sjyes) vocabulary.insert(entry.first);
        for (const auto& entry : sjno) vocabulary.insert(entry.first);

        // Scores are accumulated as log-probabilities: a plain product of one
        // factor per vocabulary word underflows to 0 for both classes once the
        // vocabulary reaches a few hundred words, making the comparison useless.
        // A class with no documents gets log(0) = -infinity and can never win.
        double logYes = std::log(static_cast<double>(yes) / total);
        double logNo = std::log(static_cast<double>(no) / total);
        for (const std::string& word : vocabulary)
        {
            const bool inDocument = present.find(word) != present.end();
            logYes += logLikelihood(sjyes, yes, word, inDocument);
            logNo += logLikelihood(sjno, no, word, inDocument);
        }

        // Debug output kept from the original: the two class scores as probabilities.
        std::cout << std::exp(logYes) << ' ' << std::exp(logNo) << std::endl;
        return logYes > logNo ? 1 : 0;
    }

private:
    /// Log of the smoothed Bernoulli factor for one word in one class.
    /// @param counts     per-word document counts of the class.
    /// @param docs       number of training documents in the class.
    /// @param word       vocabulary word being scored.
    /// @param inDocument whether the word occurs in the document being classified.
    static double logLikelihood(const std::map<std::string, int>& counts, int docs,
                                const std::string& word, bool inDocument)
    {
        const auto it = counts.find(word);
        const int hits = (it == counts.end()) ? 0 : it->second;
        // Present: (hits + 1) / (docs + 2). Absent: 1 - that, written as a single
        // fraction so no floating-point subtraction is needed.
        const int numerator = inDocument ? hits + 1 : docs - hits + 1;
        return std::log(static_cast<double>(numerator) / (docs + 2));
    }

    int yes;                           // number of training documents labelled "yes"
    int no;                            // number of training documents labelled "no"
    std::map<std::string, int> sjyes;  // word -> number of "yes" documents containing it
    std::map<std::string, int> sjno;   // word -> number of "no" documents containing it
};
int main()
{
Bayes a;
//训练开始
a.input(vector<string>{"Chinese","Beijing","Chinese"}, 1);
a.input(vector<string>{"Chinese", "Chinese", "Shanghai"}, 1);
a.input(vector<string>{"Chinese", "Macao"}, 1);
a.input(vector<string>{"Tokyo", "Japan", "Chinese"}, 0);
//结果测试
cout <<"属于分类:"<<a.output(vector<string>{"Chinese", "Chinese", "Chinese", "Tokyo", "Japan"}) << endl;
return 0;
}
编译器用的是vs2017,编译环境是win10,代码有什么不足之处,还望指正
放上一张运行图~
|