mexopencv  0.1
mex interface for opencv library
DTrees_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 using namespace std;
11 using namespace cv;
12 using namespace cv::ml;
13 
14 // Persistent objects
15 namespace {
17 int last_id = 0;
19 map<int,Ptr<DTrees> > obj_;
20 }
21 
29 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
30 {
31  // Check the number of arguments
32  nargchk(nrhs>=2 && nlhs<=2);
33 
34  // Argument vector
35  vector<MxArray> rhs(prhs, prhs+nrhs);
36  int id = rhs[0].toInt();
37  string method(rhs[1].toString());
38 
39  // Constructor is called. Create a new object from argument
40  if (method == "new") {
41  nargchk(nrhs==2 && nlhs<=1);
42  obj_[++last_id] = DTrees::create();
43  plhs[0] = MxArray(last_id);
44  return;
45  }
46 
47  // Big operation switch
48  Ptr<DTrees> obj = obj_[id];
49  if (method == "delete") {
50  nargchk(nrhs==2 && nlhs==0);
51  obj_.erase(id);
52  }
53  else if (method == "clear") {
54  nargchk(nrhs==2 && nlhs==0);
55  obj->clear();
56  }
57  else if (method == "load") {
58  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
59  string objname;
60  bool loadFromString = false;
61  for (int i=3; i<nrhs; i+=2) {
62  string key(rhs[i].toString());
63  if (key == "ObjName")
64  objname = rhs[i+1].toString();
65  else if (key == "FromString")
66  loadFromString = rhs[i+1].toBool();
67  else
68  mexErrMsgIdAndTxt("mexopencv:error",
69  "Unrecognized option %s", key.c_str());
70  }
71  obj_[id] = (loadFromString ?
72  Algorithm::loadFromString<DTrees>(rhs[2].toString(), objname) :
73  Algorithm::load<DTrees>(rhs[2].toString(), objname));
74  }
75  else if (method == "save") {
76  nargchk(nrhs==3 && nlhs<=1);
77  string fname(rhs[2].toString());
78  if (nlhs > 0) {
79  // write to memory, and return string
80  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
81  fs << obj->getDefaultName() << "{";
82  fs << "format" << 3;
83  obj->write(fs);
84  fs << "}";
85  plhs[0] = MxArray(fs.releaseAndGetString());
86  }
87  else
88  // write to disk
89  obj->save(fname);
90  }
91  else if (method == "empty") {
92  nargchk(nrhs==2 && nlhs<=1);
93  plhs[0] = MxArray(obj->empty());
94  }
95  else if (method == "getDefaultName") {
96  nargchk(nrhs==2 && nlhs<=1);
97  plhs[0] = MxArray(obj->getDefaultName());
98  }
99  else if (method == "getVarCount") {
100  nargchk(nrhs==2 && nlhs<=1);
101  plhs[0] = MxArray(obj->getVarCount());
102  }
103  else if (method == "isClassifier") {
104  nargchk(nrhs==2 && nlhs<=1);
105  plhs[0] = MxArray(obj->isClassifier());
106  }
107  else if (method == "isTrained") {
108  nargchk(nrhs==2 && nlhs<=1);
109  plhs[0] = MxArray(obj->isTrained());
110  }
111  else if (method == "train") {
112  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
113  vector<MxArray> dataOptions;
114  int flags = 0;
115  for (int i=4; i<nrhs; i+=2) {
116  string key(rhs[i].toString());
117  if (key == "Data")
118  dataOptions = rhs[i+1].toVector<MxArray>();
119  else if (key == "Flags")
120  flags = rhs[i+1].toInt();
121  else
122  mexErrMsgIdAndTxt("mexopencv:error",
123  "Unrecognized option %s", key.c_str());
124  }
125  Ptr<TrainData> data;
126  if (rhs[2].isChar())
127  data = loadTrainData(rhs[2].toString(),
128  dataOptions.begin(), dataOptions.end());
129  else
130  data = createTrainData(
131  rhs[2].toMat(CV_32F),
132  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
133  dataOptions.begin(), dataOptions.end());
134  bool b = obj->train(data, flags);
135  plhs[0] = MxArray(b);
136  }
137  else if (method == "calcError") {
138  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
139  vector<MxArray> dataOptions;
140  bool test = false;
141  for (int i=4; i<nrhs; i+=2) {
142  string key(rhs[i].toString());
143  if (key == "Data")
144  dataOptions = rhs[i+1].toVector<MxArray>();
145  else if (key == "TestError")
146  test = rhs[i+1].toBool();
147  else
148  mexErrMsgIdAndTxt("mexopencv:error",
149  "Unrecognized option %s", key.c_str());
150  }
151  Ptr<TrainData> data;
152  if (rhs[2].isChar())
153  data = loadTrainData(rhs[2].toString(),
154  dataOptions.begin(), dataOptions.end());
155  else
156  data = createTrainData(
157  rhs[2].toMat(CV_32F),
158  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
159  dataOptions.begin(), dataOptions.end());
160  Mat resp;
161  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
162  plhs[0] = MxArray(err);
163  if (nlhs>1)
164  plhs[1] = MxArray(resp);
165  }
166  else if (method == "predict") {
167  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
168  int flags = 0;
169  for (int i=3; i<nrhs; i+=2) {
170  string key(rhs[i].toString());
171  if (key == "Flags")
172  flags = rhs[i+1].toInt();
173  else if (key == "RawOutput")
174  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
175  else if (key == "CompressedInput")
176  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
177  else if (key == "PreprocessedInput")
178  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
179  else if (key == "PredictAuto") {
180  //UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_AUTO);
181  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
182  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
183  }
184  else if (key == "PredictSum")
185  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
186  else if (key == "PredictMaxVote")
187  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
188  else
189  mexErrMsgIdAndTxt("mexopencv:error",
190  "Unrecognized option %s", key.c_str());
191  }
192  Mat samples(rhs[2].toMat(CV_32F)),
193  results;
194  float f = obj->predict(samples, results, flags);
195  plhs[0] = MxArray(results);
196  if (nlhs>1)
197  plhs[1] = MxArray(f);
198  }
199  else if (method == "getNodes") {
200  nargchk(nrhs==2 && nlhs<=1);
201  plhs[0] = toStruct(obj->getNodes());
202  }
203  else if (method == "getRoots") {
204  nargchk(nrhs==2 && nlhs<=1);
205  plhs[0] = MxArray(obj->getRoots());
206  }
207  else if (method == "getSplits") {
208  nargchk(nrhs==2 && nlhs<=1);
209  plhs[0] = toStruct(obj->getSplits());
210  }
211  else if (method == "getSubsets") {
212  nargchk(nrhs==2 && nlhs<=1);
213  plhs[0] = MxArray(obj->getSubsets());
214  }
215  else if (method == "get") {
216  nargchk(nrhs==3 && nlhs<=1);
217  string prop(rhs[2].toString());
218  if (prop == "CVFolds")
219  plhs[0] = MxArray(obj->getCVFolds());
220  else if (prop == "MaxCategories")
221  plhs[0] = MxArray(obj->getMaxCategories());
222  else if (prop == "MaxDepth")
223  plhs[0] = MxArray(obj->getMaxDepth());
224  else if (prop == "MinSampleCount")
225  plhs[0] = MxArray(obj->getMinSampleCount());
226  else if (prop == "Priors")
227  plhs[0] = MxArray(obj->getPriors());
228  else if (prop == "RegressionAccuracy")
229  plhs[0] = MxArray(obj->getRegressionAccuracy());
230  else if (prop == "TruncatePrunedTree")
231  plhs[0] = MxArray(obj->getTruncatePrunedTree());
232  else if (prop == "Use1SERule")
233  plhs[0] = MxArray(obj->getUse1SERule());
234  else if (prop == "UseSurrogates")
235  plhs[0] = MxArray(obj->getUseSurrogates());
236  else
237  mexErrMsgIdAndTxt("mexopencv:error",
238  "Unrecognized property %s", prop.c_str());
239  }
240  else if (method == "set") {
241  nargchk(nrhs==4 && nlhs==0);
242  string prop(rhs[2].toString());
243  if (prop == "CVFolds")
244  obj->setCVFolds(rhs[3].toInt());
245  else if (prop == "MaxCategories")
246  obj->setMaxCategories(rhs[3].toInt());
247  else if (prop == "MaxDepth")
248  obj->setMaxDepth(rhs[3].toInt());
249  else if (prop == "MinSampleCount")
250  obj->setMinSampleCount(rhs[3].toInt());
251  else if (prop == "Priors")
252  obj->setPriors(rhs[3].toMat());
253  else if (prop == "RegressionAccuracy")
254  obj->setRegressionAccuracy(rhs[3].toFloat());
255  else if (prop == "TruncatePrunedTree")
256  obj->setTruncatePrunedTree(rhs[3].toBool());
257  else if (prop == "Use1SERule")
258  obj->setUse1SERule(rhs[3].toBool());
259  else if (prop == "UseSurrogates")
260  obj->setUseSurrogates(rhs[3].toBool());
261  else
262  mexErrMsgIdAndTxt("mexopencv:error",
263  "Unrecognized property %s", prop.c_str());
264  }
265  else
266  mexErrMsgIdAndTxt("mexopencv:error","Unrecognized operation");
267 }
#define UPDATE_FLAG(NUM, TF, BIT)
Set or clear a bit in a flag depending on a bool value.
Definition: mexopencv.hpp:159
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: DTrees_.cpp:29
int toInt() const
Convert MxArray to int.
Definition: MxArray.cpp:489
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output argument count check.
Definition: mexopencv.hpp:166
bool toBool() const
Convert MxArray to bool.
Definition: MxArray.cpp:510
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
Common definitions for the ml module.