mexopencv  0.1
MEX interface for the OpenCV library
NormalBayesClassifier_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 using namespace std;
11 using namespace cv;
12 using namespace cv::ml;
13 
14 // Persistent objects
15 namespace {
17 int last_id = 0;
19 map<int,Ptr<NormalBayesClassifier> > obj_;
20 }
21 
29 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
30 {
31  // Check the number of arguments
32  nargchk(nrhs>=2 && nlhs<=3);
33 
34  // Argument vector
35  vector<MxArray> rhs(prhs, prhs+nrhs);
36  int id = rhs[0].toInt();
37  string method(rhs[1].toString());
38 
39  // Constructor is called. Create a new object from argument
40  if (method == "new") {
41  nargchk(nrhs==2 && nlhs<=1);
42  obj_[++last_id] = NormalBayesClassifier::create();
43  plhs[0] = MxArray(last_id);
44  return;
45  }
46 
47  // Big operation switch
48  Ptr<NormalBayesClassifier> obj = obj_[id];
49  if (method == "delete") {
50  nargchk(nrhs==2 && nlhs==0);
51  obj_.erase(id);
52  }
53  else if (method == "clear") {
54  nargchk(nrhs==2 && nlhs==0);
55  obj->clear();
56  }
57  else if (method == "load") {
58  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
59  string objname;
60  bool loadFromString = false;
61  for (int i=3; i<nrhs; i+=2) {
62  string key(rhs[i].toString());
63  if (key == "ObjName")
64  objname = rhs[i+1].toString();
65  else if (key == "FromString")
66  loadFromString = rhs[i+1].toBool();
67  else
68  mexErrMsgIdAndTxt("mexopencv:error",
69  "Unrecognized option %s", key.c_str());
70  }
71  obj_[id] = (loadFromString ?
72  Algorithm::loadFromString<NormalBayesClassifier>(rhs[2].toString(), objname) :
73  Algorithm::load<NormalBayesClassifier>(rhs[2].toString(), objname));
74  }
75  else if (method == "save") {
76  nargchk(nrhs==3 && nlhs<=1);
77  string fname(rhs[2].toString());
78  if (nlhs > 0) {
79  // write to memory, and return string
80  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
81  fs << obj->getDefaultName() << "{";
82  fs << "format" << 3;
83  obj->write(fs);
84  fs << "}";
85  plhs[0] = MxArray(fs.releaseAndGetString());
86  }
87  else
88  // write to disk
89  obj->save(fname);
90  }
91  else if (method == "empty") {
92  nargchk(nrhs==2 && nlhs<=1);
93  plhs[0] = MxArray(obj->empty());
94  }
95  else if (method == "getDefaultName") {
96  nargchk(nrhs==2 && nlhs<=1);
97  plhs[0] = MxArray(obj->getDefaultName());
98  }
99  else if (method == "getVarCount") {
100  nargchk(nrhs==2 && nlhs<=1);
101  plhs[0] = MxArray(obj->getVarCount());
102  }
103  else if (method == "isClassifier") {
104  nargchk(nrhs==2 && nlhs<=1);
105  plhs[0] = MxArray(obj->isClassifier());
106  }
107  else if (method == "isTrained") {
108  nargchk(nrhs==2 && nlhs<=1);
109  plhs[0] = MxArray(obj->isTrained());
110  }
111  else if (method == "train") {
112  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
113  vector<MxArray> dataOptions;
114  int flags = 0;
115  for (int i=4; i<nrhs; i+=2) {
116  string key(rhs[i].toString());
117  if (key == "Data")
118  dataOptions = rhs[i+1].toVector<MxArray>();
119  else if (key == "Flags")
120  flags = rhs[i+1].toInt();
121  else if (key == "UpdateModel")
122  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::UPDATE_MODEL);
123  else
124  mexErrMsgIdAndTxt("mexopencv:error",
125  "Unrecognized option %s", key.c_str());
126  }
127  Ptr<TrainData> data;
128  if (rhs[2].isChar())
129  data = loadTrainData(rhs[2].toString(),
130  dataOptions.begin(), dataOptions.end());
131  else
132  data = createTrainData(
133  rhs[2].toMat(CV_32F),
134  rhs[3].toMat(CV_32S),
135  dataOptions.begin(), dataOptions.end());
136  bool b = obj->train(data, flags);
137  plhs[0] = MxArray(b);
138  }
139  else if (method == "calcError") {
140  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
141  vector<MxArray> dataOptions;
142  bool test = false;
143  for (int i=4; i<nrhs; i+=2) {
144  string key(rhs[i].toString());
145  if (key == "Data")
146  dataOptions = rhs[i+1].toVector<MxArray>();
147  else if (key == "TestError")
148  test = rhs[i+1].toBool();
149  else
150  mexErrMsgIdAndTxt("mexopencv:error",
151  "Unrecognized option %s", key.c_str());
152  }
153  Ptr<TrainData> data;
154  if (rhs[2].isChar())
155  data = loadTrainData(rhs[2].toString(),
156  dataOptions.begin(), dataOptions.end());
157  else
158  data = createTrainData(
159  rhs[2].toMat(CV_32F),
160  rhs[3].toMat(CV_32S),
161  dataOptions.begin(), dataOptions.end());
162  Mat resp;
163  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
164  plhs[0] = MxArray(err);
165  if (nlhs>1)
166  plhs[1] = MxArray(resp);
167  }
168  else if (method == "predict") {
169  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
170  int flags = 0;
171  for (int i=3; i<nrhs; i+=2) {
172  string key(rhs[i].toString());
173  if (key == "Flags")
174  flags = rhs[i+1].toInt();
175  else if (key == "RawOutput")
176  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
177  else
178  mexErrMsgIdAndTxt("mexopencv:error",
179  "Unrecognized option %s", key.c_str());
180  }
181  Mat samples(rhs[2].toMat(CV_32F)),
182  results;
183  float f = obj->predict(samples, results, flags);
184  plhs[0] = MxArray(results);
185  if (nlhs>1)
186  plhs[1] = MxArray(f);
187  }
188  else if (method == "predictProb") {
189  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=3);
190  int flags = 0;
191  for (int i=3; i<nrhs; i+=2) {
192  string key(rhs[i].toString());
193  if (key == "Flags")
194  flags = rhs[i+1].toInt();
195  else if (key == "RawOutput")
196  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
197  else
198  mexErrMsgIdAndTxt("mexopencv:error",
199  "Unrecognized option %s", key.c_str());
200  }
201  //HACK: we must do this one sample at a time to avoid incorrect outputProbs
202  //TODO: https://github.com/Itseez/opencv/issues/5911
203  Mat inputs(rhs[2].toMat(CV_32F));
204  Mat outputs(inputs.rows, 1, CV_32S),
205  outputProbs;
206  if (nlhs > 1) {
207  // we need to determine the number of classes
208  // to allocate the output probabilities matrix
209  int nclasses = 1;
210  if (!inputs.empty()) {
211  Mat tmp;
212  obj->predictProb(inputs.row(0), noArray(), tmp, flags);
213  nclasses = tmp.cols;
214  }
215  outputProbs.create(inputs.rows, nclasses, CV_32F);
216  }
217  float f = 0;
218  for (size_t i=0; i<inputs.rows; ++i)
219  f = obj->predictProb(inputs.row(i), outputs.row(i),
220  (nlhs>1 ? outputProbs.row(i) : noArray()), flags);
221  plhs[0] = MxArray(outputs);
222  if (nlhs>1)
223  plhs[1] = MxArray(outputProbs);
224  if (nlhs>2)
225  plhs[2] = MxArray(f);
226  }
227  else
228  mexErrMsgIdAndTxt("mexopencv:error","Unrecognized operation");
229 }
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:159
int toInt() const
Convert MxArray to int.
Definition: MxArray.cpp:489
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:166
bool toBool() const
Convert MxArray to bool.
Definition: MxArray.cpp:510
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
Common definitions for the ml module.