mexopencv  0.1
mex interface for opencv library
RTrees_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 using namespace std;
11 using namespace cv;
12 using namespace cv::ml;
13 
14 // Persistent objects
15 namespace {
17 int last_id = 0;
19 map<int,Ptr<RTrees> > obj_;
20 }
21 
29 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
30 {
31  // Check the number of arguments
32  nargchk(nrhs>=2 && nlhs<=2);
33 
34  // Argument vector
35  vector<MxArray> rhs(prhs, prhs+nrhs);
36  int id = rhs[0].toInt();
37  string method(rhs[1].toString());
38 
39  // Constructor is called. Create a new object from argument
40  if (method == "new") {
41  nargchk(nrhs==2 && nlhs<=1);
42  obj_[++last_id] = RTrees::create();
43  plhs[0] = MxArray(last_id);
44  return;
45  }
46 
47  // Big operation switch
48  Ptr<RTrees> obj = obj_[id];
49  if (method == "delete") {
50  nargchk(nrhs==2 && nlhs==0);
51  obj_.erase(id);
52  }
53  else if (method == "clear") {
54  nargchk(nrhs==2 && nlhs==0);
55  obj->clear();
56  }
57  else if (method == "load") {
58  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
59  string objname;
60  bool loadFromString = false;
61  for (int i=3; i<nrhs; i+=2) {
62  string key(rhs[i].toString());
63  if (key == "ObjName")
64  objname = rhs[i+1].toString();
65  else if (key == "FromString")
66  loadFromString = rhs[i+1].toBool();
67  else
68  mexErrMsgIdAndTxt("mexopencv:error",
69  "Unrecognized option %s", key.c_str());
70  }
71  obj_[id] = (loadFromString ?
72  Algorithm::loadFromString<RTrees>(rhs[2].toString(), objname) :
73  Algorithm::load<RTrees>(rhs[2].toString(), objname));
74  }
75  else if (method == "save") {
76  nargchk(nrhs==3 && nlhs<=1);
77  string fname(rhs[2].toString());
78  if (nlhs > 0) {
79  // write to memory, and return string
80  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
81  fs << obj->getDefaultName() << "{";
82  fs << "format" << 3;
83  obj->write(fs);
84  fs << "}";
85  plhs[0] = MxArray(fs.releaseAndGetString());
86  }
87  else
88  // write to disk
89  obj->save(fname);
90  }
91  else if (method == "empty") {
92  nargchk(nrhs==2 && nlhs<=1);
93  plhs[0] = MxArray(obj->empty());
94  }
95  else if (method == "getDefaultName") {
96  nargchk(nrhs==2 && nlhs<=1);
97  plhs[0] = MxArray(obj->getDefaultName());
98  }
99  else if (method == "getVarCount") {
100  nargchk(nrhs==2 && nlhs<=1);
101  plhs[0] = MxArray(obj->getVarCount());
102  }
103  else if (method == "isClassifier") {
104  nargchk(nrhs==2 && nlhs<=1);
105  plhs[0] = MxArray(obj->isClassifier());
106  }
107  else if (method == "isTrained") {
108  nargchk(nrhs==2 && nlhs<=1);
109  plhs[0] = MxArray(obj->isTrained());
110  }
111  else if (method == "train") {
112  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
113  vector<MxArray> dataOptions;
114  int flags = 0;
115  for (int i=4; i<nrhs; i+=2) {
116  string key(rhs[i].toString());
117  if (key == "Data")
118  dataOptions = rhs[i+1].toVector<MxArray>();
119  else if (key == "Flags")
120  flags = rhs[i+1].toInt();
121  else if (key == "RawOutput")
122  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
123  else if (key == "PredictSum")
124  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
125  else if (key == "PredictMaxVote")
126  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
127  else
128  mexErrMsgIdAndTxt("mexopencv:error",
129  "Unrecognized option %s", key.c_str());
130  }
131  Ptr<TrainData> data;
132  if (rhs[2].isChar())
133  data = loadTrainData(rhs[2].toString(),
134  dataOptions.begin(), dataOptions.end());
135  else
136  data = createTrainData(
137  rhs[2].toMat(CV_32F),
138  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
139  dataOptions.begin(), dataOptions.end());
140  bool b = obj->train(data, flags);
141  plhs[0] = MxArray(b);
142  }
143  else if (method == "calcError") {
144  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
145  vector<MxArray> dataOptions;
146  bool test = false;
147  for (int i=4; i<nrhs; i+=2) {
148  string key(rhs[i].toString());
149  if (key == "Data")
150  dataOptions = rhs[i+1].toVector<MxArray>();
151  else if (key == "TestError")
152  test = rhs[i+1].toBool();
153  else
154  mexErrMsgIdAndTxt("mexopencv:error",
155  "Unrecognized option %s", key.c_str());
156  }
157  Ptr<TrainData> data;
158  if (rhs[2].isChar())
159  data = loadTrainData(rhs[2].toString(),
160  dataOptions.begin(), dataOptions.end());
161  else
162  data = createTrainData(
163  rhs[2].toMat(CV_32F),
164  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
165  dataOptions.begin(), dataOptions.end());
166  Mat resp;
167  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
168  plhs[0] = MxArray(err);
169  if (nlhs>1)
170  plhs[1] = MxArray(resp);
171  }
172  else if (method == "predict") {
173  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
174  int flags = 0;
175  for (int i=3; i<nrhs; i+=2) {
176  string key(rhs[i].toString());
177  if (key == "Flags")
178  flags = rhs[i+1].toInt();
179  else if (key == "RawOutput")
180  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
181  else if (key == "CompressedInput")
182  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
183  else if (key == "PreprocessedInput")
184  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
185  else if (key == "PredictAuto") {
186  //UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_AUTO);
187  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
188  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
189  }
190  else if (key == "PredictSum")
191  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
192  else if (key == "PredictMaxVote")
193  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
194  else
195  mexErrMsgIdAndTxt("mexopencv:error",
196  "Unrecognized option %s", key.c_str());
197  }
198  Mat samples(rhs[2].toMat(CV_32F)),
199  results;
200  float f = obj->predict(samples, results, flags);
201  plhs[0] = MxArray(results);
202  if (nlhs>1)
203  plhs[1] = MxArray(f);
204  }
205  else if (method == "getNodes") {
206  nargchk(nrhs==2 && nlhs<=1);
207  plhs[0] = toStruct(obj->getNodes());
208  }
209  else if (method == "getRoots") {
210  nargchk(nrhs==2 && nlhs<=1);
211  plhs[0] = MxArray(obj->getRoots());
212  }
213  else if (method == "getSplits") {
214  nargchk(nrhs==2 && nlhs<=1);
215  plhs[0] = toStruct(obj->getSplits());
216  }
217  else if (method == "getSubsets") {
218  nargchk(nrhs==2 && nlhs<=1);
219  plhs[0] = MxArray(obj->getSubsets());
220  }
221  else if (method == "getVarImportance") {
222  nargchk(nrhs==2 && nlhs<=1);
223  plhs[0] = MxArray(obj->getVarImportance());
224  }
225  else if (method == "get") {
226  nargchk(nrhs==3 && nlhs<=1);
227  string prop(rhs[2].toString());
228  if (prop == "CVFolds")
229  plhs[0] = MxArray(obj->getCVFolds());
230  else if (prop == "MaxCategories")
231  plhs[0] = MxArray(obj->getMaxCategories());
232  else if (prop == "MaxDepth")
233  plhs[0] = MxArray(obj->getMaxDepth());
234  else if (prop == "MinSampleCount")
235  plhs[0] = MxArray(obj->getMinSampleCount());
236  else if (prop == "Priors")
237  plhs[0] = MxArray(obj->getPriors());
238  else if (prop == "RegressionAccuracy")
239  plhs[0] = MxArray(obj->getRegressionAccuracy());
240  else if (prop == "TruncatePrunedTree")
241  plhs[0] = MxArray(obj->getTruncatePrunedTree());
242  else if (prop == "Use1SERule")
243  plhs[0] = MxArray(obj->getUse1SERule());
244  else if (prop == "UseSurrogates")
245  plhs[0] = MxArray(obj->getUseSurrogates());
246  else if (prop == "ActiveVarCount")
247  plhs[0] = MxArray(obj->getActiveVarCount());
248  else if (prop == "CalculateVarImportance")
249  plhs[0] = MxArray(obj->getCalculateVarImportance());
250  else if (prop == "TermCriteria")
251  plhs[0] = MxArray(obj->getTermCriteria());
252  else
253  mexErrMsgIdAndTxt("mexopencv:error",
254  "Unrecognized property %s", prop.c_str());
255  }
256  else if (method == "set") {
257  nargchk(nrhs==4 && nlhs==0);
258  string prop(rhs[2].toString());
259  if (prop == "CVFolds")
260  obj->setCVFolds(rhs[3].toInt());
261  else if (prop == "MaxCategories")
262  obj->setMaxCategories(rhs[3].toInt());
263  else if (prop == "MaxDepth")
264  obj->setMaxDepth(rhs[3].toInt());
265  else if (prop == "MinSampleCount")
266  obj->setMinSampleCount(rhs[3].toInt());
267  else if (prop == "Priors")
268  obj->setPriors(rhs[3].toMat());
269  else if (prop == "RegressionAccuracy")
270  obj->setRegressionAccuracy(rhs[3].toFloat());
271  else if (prop == "TruncatePrunedTree")
272  obj->setTruncatePrunedTree(rhs[3].toBool());
273  else if (prop == "Use1SERule")
274  obj->setUse1SERule(rhs[3].toBool());
275  else if (prop == "UseSurrogates")
276  obj->setUseSurrogates(rhs[3].toBool());
277  else if (prop == "ActiveVarCount")
278  obj->setActiveVarCount(rhs[3].toInt());
279  else if (prop == "CalculateVarImportance")
280  obj->setCalculateVarImportance(rhs[3].toBool());
281  else if (prop == "TermCriteria")
282  obj->setTermCriteria(rhs[3].toTermCriteria());
283  else
284  mexErrMsgIdAndTxt("mexopencv:error",
285  "Unrecognized property %s", prop.c_str());
286  }
287  else
288  mexErrMsgIdAndTxt("mexopencv:error","Unrecognized operation");
289 }
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:159
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
int toInt() const
Convert MxArray to int.
Definition: MxArray.cpp:489
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:166
bool toBool() const
Convert MxArray to bool.
Definition: MxArray.cpp:510
Global constant definitions.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: RTrees_.cpp:29
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
Common definitions for the ml module.