19 map<int,Ptr<NormalBayesClassifier> > obj_;
29 void mexFunction(
int nlhs, mxArray *plhs[],
int nrhs,
const mxArray *prhs[])
35 vector<MxArray> rhs(prhs, prhs+nrhs);
36 int id = rhs[0].toInt();
37 string method(rhs[1].toString());
40 if (method ==
"new") {
42 obj_[++last_id] = NormalBayesClassifier::create();
48 Ptr<NormalBayesClassifier> obj = obj_[id];
49 if (method ==
"delete") {
53 else if (method ==
"clear") {
57 else if (method ==
"load") {
58 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
60 bool loadFromString =
false;
61 for (
int i=3; i<nrhs; i+=2) {
62 string key(rhs[i].toString());
64 objname = rhs[i+1].toString();
65 else if (key ==
"FromString")
66 loadFromString = rhs[i+1].toBool();
68 mexErrMsgIdAndTxt(
"mexopencv:error",
69 "Unrecognized option %s", key.c_str());
71 obj_[id] = (loadFromString ?
72 Algorithm::loadFromString<NormalBayesClassifier>(rhs[2].toString(), objname) :
73 Algorithm::load<NormalBayesClassifier>(rhs[2].toString(), objname));
75 else if (method ==
"save") {
77 string fname(rhs[2].toString());
80 FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
81 fs << obj->getDefaultName() <<
"{";
85 plhs[0] =
MxArray(fs.releaseAndGetString());
91 else if (method ==
"empty") {
93 plhs[0] =
MxArray(obj->empty());
95 else if (method ==
"getDefaultName") {
97 plhs[0] =
MxArray(obj->getDefaultName());
99 else if (method ==
"getVarCount") {
101 plhs[0] =
MxArray(obj->getVarCount());
103 else if (method ==
"isClassifier") {
105 plhs[0] =
MxArray(obj->isClassifier());
107 else if (method ==
"isTrained") {
109 plhs[0] =
MxArray(obj->isTrained());
111 else if (method ==
"train") {
112 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
113 vector<MxArray> dataOptions;
115 for (
int i=4; i<nrhs; i+=2) {
116 string key(rhs[i].toString());
118 dataOptions = rhs[i+1].toVector<
MxArray>();
119 else if (key ==
"Flags")
120 flags = rhs[i+1].
toInt();
121 else if (key ==
"UpdateModel")
122 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::UPDATE_MODEL);
124 mexErrMsgIdAndTxt(
"mexopencv:error",
125 "Unrecognized option %s", key.c_str());
130 dataOptions.begin(), dataOptions.end());
133 rhs[2].toMat(CV_32F),
134 rhs[3].toMat(CV_32S),
135 dataOptions.begin(), dataOptions.end());
136 bool b = obj->train(data, flags);
139 else if (method ==
"calcError") {
140 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
141 vector<MxArray> dataOptions;
143 for (
int i=4; i<nrhs; i+=2) {
144 string key(rhs[i].toString());
146 dataOptions = rhs[i+1].toVector<
MxArray>();
147 else if (key ==
"TestError")
150 mexErrMsgIdAndTxt(
"mexopencv:error",
151 "Unrecognized option %s", key.c_str());
156 dataOptions.begin(), dataOptions.end());
159 rhs[2].toMat(CV_32F),
160 rhs[3].toMat(CV_32S),
161 dataOptions.begin(), dataOptions.end());
163 float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
168 else if (method ==
"predict") {
169 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
171 for (
int i=3; i<nrhs; i+=2) {
172 string key(rhs[i].toString());
174 flags = rhs[i+1].toInt();
175 else if (key ==
"RawOutput")
176 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
178 mexErrMsgIdAndTxt(
"mexopencv:error",
179 "Unrecognized option %s", key.c_str());
181 Mat samples(rhs[2].toMat(CV_32F)),
183 float f = obj->predict(samples, results, flags);
188 else if (method ==
"predictProb") {
189 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=3);
191 for (
int i=3; i<nrhs; i+=2) {
192 string key(rhs[i].toString());
194 flags = rhs[i+1].toInt();
195 else if (key ==
"RawOutput")
196 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
198 mexErrMsgIdAndTxt(
"mexopencv:error",
199 "Unrecognized option %s", key.c_str());
203 Mat inputs(rhs[2].toMat(CV_32F));
204 Mat outputs(inputs.rows, 1, CV_32S),
210 if (!inputs.empty()) {
212 obj->predictProb(inputs.row(0), noArray(), tmp, flags);
215 outputProbs.create(inputs.rows, nclasses, CV_32F);
218 for (
size_t i=0; i<inputs.rows; ++i)
219 f = obj->predictProb(inputs.row(i), outputs.row(i),
220 (nlhs>1 ? outputProbs.row(i) : noArray()), flags);
223 plhs[1] =
MxArray(outputProbs);
228 mexErrMsgIdAndTxt(
"mexopencv:error",
"Unrecognized operation");
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
int toInt() const
Convert MxArray to int.
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for input/ouput arguments number check.
bool toBool() const
Convert MxArray to bool.
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
Common definitions for the ml module.