mexopencv  0.1
MEX interface for the OpenCV library
Boost_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 using namespace std;
11 using namespace cv;
12 using namespace cv::ml;
13 
14 // Persistent objects
15 namespace {
17 int last_id = 0;
19 map<int,Ptr<Boost> > obj_;
20 
23  ("Discrete", cv::ml::Boost::DISCRETE)
24  ("Real", cv::ml::Boost::REAL)
25  ("Logit", cv::ml::Boost::LOGIT)
26  ("Gentle", cv::ml::Boost::GENTLE);
27 
29 const ConstMap<int,string> InvBoostType = ConstMap<int,string>
30  (cv::ml::Boost::DISCRETE, "Discrete")
31  (cv::ml::Boost::REAL, "Real")
32  (cv::ml::Boost::LOGIT, "Logit")
33  (cv::ml::Boost::GENTLE, "Gentle");
34 }
35 
43 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
44 {
45  // Check the number of arguments
46  nargchk(nrhs>=2 && nlhs<=2);
47 
48  // Argument vector
49  vector<MxArray> rhs(prhs, prhs+nrhs);
50  int id = rhs[0].toInt();
51  string method(rhs[1].toString());
52 
53  // Constructor is called. Create a new object from argument
54  if (method == "new") {
55  nargchk(nrhs==2 && nlhs<=1);
56  obj_[++last_id] = Boost::create();
57  plhs[0] = MxArray(last_id);
58  return;
59  }
60 
61  // Big operation switch
62  Ptr<Boost> obj = obj_[id];
63  if (method == "delete") {
64  nargchk(nrhs==2 && nlhs==0);
65  obj_.erase(id);
66  }
67  else if (method == "clear") {
68  nargchk(nrhs==2 && nlhs==0);
69  obj->clear();
70  }
71  else if (method == "load") {
72  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
73  string objname;
74  bool loadFromString = false;
75  for (int i=3; i<nrhs; i+=2) {
76  string key(rhs[i].toString());
77  if (key == "ObjName")
78  objname = rhs[i+1].toString();
79  else if (key == "FromString")
80  loadFromString = rhs[i+1].toBool();
81  else
82  mexErrMsgIdAndTxt("mexopencv:error",
83  "Unrecognized option %s", key.c_str());
84  }
85  obj_[id] = (loadFromString ?
86  Algorithm::loadFromString<Boost>(rhs[2].toString(), objname) :
87  Algorithm::load<Boost>(rhs[2].toString(), objname));
88  }
89  else if (method == "save") {
90  nargchk(nrhs==3 && nlhs<=1);
91  string fname(rhs[2].toString());
92  if (nlhs > 0) {
93  // write to memory, and return string
94  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
95  fs << obj->getDefaultName() << "{";
96  fs << "format" << 3;
97  obj->write(fs);
98  fs << "}";
99  plhs[0] = MxArray(fs.releaseAndGetString());
100  }
101  else
102  // write to disk
103  obj->save(fname);
104  }
105  else if (method == "empty") {
106  nargchk(nrhs==2 && nlhs<=1);
107  plhs[0] = MxArray(obj->empty());
108  }
109  else if (method == "getDefaultName") {
110  nargchk(nrhs==2 && nlhs<=1);
111  plhs[0] = MxArray(obj->getDefaultName());
112  }
113  else if (method == "getVarCount") {
114  nargchk(nrhs==2 && nlhs<=1);
115  plhs[0] = MxArray(obj->getVarCount());
116  }
117  else if (method == "isClassifier") {
118  nargchk(nrhs==2 && nlhs<=1);
119  plhs[0] = MxArray(obj->isClassifier());
120  }
121  else if (method == "isTrained") {
122  nargchk(nrhs==2 && nlhs<=1);
123  plhs[0] = MxArray(obj->isTrained());
124  }
125  else if (method == "train") {
126  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
127  vector<MxArray> dataOptions;
128  int flags = 0;
129  for (int i=4; i<nrhs; i+=2) {
130  string key(rhs[i].toString());
131  if (key == "Data")
132  dataOptions = rhs[i+1].toVector<MxArray>();
133  else if (key == "Flags")
134  flags = rhs[i+1].toInt();
135  else if (key == "RawOutput")
136  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
137  else if (key == "CompressedInput")
138  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
139  else if (key == "PredictSum")
140  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
141  else if (key == "PredictMaxVote")
142  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
143  else
144  mexErrMsgIdAndTxt("mexopencv:error",
145  "Unrecognized option %s", key.c_str());
146  }
147  Ptr<TrainData> data;
148  if (rhs[2].isChar())
149  data = loadTrainData(rhs[2].toString(),
150  dataOptions.begin(), dataOptions.end());
151  else
152  data = createTrainData(
153  rhs[2].toMat(CV_32F),
154  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
155  dataOptions.begin(), dataOptions.end());
156  bool b = obj->train(data, flags);
157  plhs[0] = MxArray(b);
158  }
159  else if (method == "calcError") {
160  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
161  vector<MxArray> dataOptions;
162  bool test = false;
163  for (int i=4; i<nrhs; i+=2) {
164  string key(rhs[i].toString());
165  if (key == "Data")
166  dataOptions = rhs[i+1].toVector<MxArray>();
167  else if (key == "TestError")
168  test = rhs[i+1].toBool();
169  else
170  mexErrMsgIdAndTxt("mexopencv:error",
171  "Unrecognized option %s", key.c_str());
172  }
173  Ptr<TrainData> data;
174  if (rhs[2].isChar())
175  data = loadTrainData(rhs[2].toString(),
176  dataOptions.begin(), dataOptions.end());
177  else
178  data = createTrainData(
179  rhs[2].toMat(CV_32F),
180  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
181  dataOptions.begin(), dataOptions.end());
182  Mat resp;
183  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
184  plhs[0] = MxArray(err);
185  if (nlhs>1)
186  plhs[1] = MxArray(resp);
187  }
188  else if (method == "predict") {
189  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
190  int flags = 0;
191  for (int i=3; i<nrhs; i+=2) {
192  string key(rhs[i].toString());
193  if (key == "Flags")
194  flags = rhs[i+1].toInt();
195  else if (key == "RawOutput")
196  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
197  else if (key == "CompressedInput")
198  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
199  else if (key == "PreprocessedInput")
200  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
201  else if (key == "PredictAuto") {
202  //UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_AUTO);
203  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
204  UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
205  }
206  else if (key == "PredictSum")
207  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
208  else if (key == "PredictMaxVote")
209  UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
210  else
211  mexErrMsgIdAndTxt("mexopencv:error",
212  "Unrecognized option %s", key.c_str());
213  }
214  Mat samples(rhs[2].toMat(CV_32F)),
215  results;
216  float f = obj->predict(samples, results, flags);
217  plhs[0] = MxArray(results);
218  if (nlhs>1)
219  plhs[1] = MxArray(f);
220  }
221  else if (method == "getNodes") {
222  nargchk(nrhs==2 && nlhs<=1);
223  plhs[0] = toStruct(obj->getNodes());
224  }
225  else if (method == "getRoots") {
226  nargchk(nrhs==2 && nlhs<=1);
227  plhs[0] = MxArray(obj->getRoots());
228  }
229  else if (method == "getSplits") {
230  nargchk(nrhs==2 && nlhs<=1);
231  plhs[0] = toStruct(obj->getSplits());
232  }
233  else if (method == "getSubsets") {
234  nargchk(nrhs==2 && nlhs<=1);
235  plhs[0] = MxArray(obj->getSubsets());
236  }
237  else if (method == "get") {
238  nargchk(nrhs==3 && nlhs<=1);
239  string prop(rhs[2].toString());
240  if (prop == "CVFolds")
241  plhs[0] = MxArray(obj->getCVFolds());
242  else if (prop == "MaxCategories")
243  plhs[0] = MxArray(obj->getMaxCategories());
244  else if (prop == "MaxDepth")
245  plhs[0] = MxArray(obj->getMaxDepth());
246  else if (prop == "MinSampleCount")
247  plhs[0] = MxArray(obj->getMinSampleCount());
248  else if (prop == "Priors")
249  plhs[0] = MxArray(obj->getPriors());
250  else if (prop == "RegressionAccuracy")
251  plhs[0] = MxArray(obj->getRegressionAccuracy());
252  else if (prop == "TruncatePrunedTree")
253  plhs[0] = MxArray(obj->getTruncatePrunedTree());
254  else if (prop == "Use1SERule")
255  plhs[0] = MxArray(obj->getUse1SERule());
256  else if (prop == "UseSurrogates")
257  plhs[0] = MxArray(obj->getUseSurrogates());
258  else if (prop == "BoostType")
259  plhs[0] = MxArray(InvBoostType[obj->getBoostType()]);
260  else if (prop == "WeakCount")
261  plhs[0] = MxArray(obj->getWeakCount());
262  else if (prop == "WeightTrimRate")
263  plhs[0] = MxArray(obj->getWeightTrimRate());
264  else
265  mexErrMsgIdAndTxt("mexopencv:error",
266  "Unrecognized property %s", prop.c_str());
267  }
268  else if (method == "set") {
269  nargchk(nrhs==4 && nlhs==0);
270  string prop(rhs[2].toString());
271  if (prop == "CVFolds")
272  obj->setCVFolds(rhs[3].toInt());
273  else if (prop == "MaxCategories")
274  obj->setMaxCategories(rhs[3].toInt());
275  else if (prop == "MaxDepth")
276  obj->setMaxDepth(rhs[3].toInt());
277  else if (prop == "MinSampleCount")
278  obj->setMinSampleCount(rhs[3].toInt());
279  else if (prop == "Priors")
280  obj->setPriors(rhs[3].toMat());
281  else if (prop == "RegressionAccuracy")
282  obj->setRegressionAccuracy(rhs[3].toFloat());
283  else if (prop == "TruncatePrunedTree")
284  obj->setTruncatePrunedTree(rhs[3].toBool());
285  else if (prop == "Use1SERule")
286  obj->setUse1SERule(rhs[3].toBool());
287  else if (prop == "UseSurrogates")
288  obj->setUseSurrogates(rhs[3].toBool());
289  else if (prop == "BoostType")
290  obj->setBoostType(BoostType[rhs[3].toString()]);
291  else if (prop == "WeakCount")
292  obj->setWeakCount(rhs[3].toInt());
293  else if (prop == "WeightTrimRate")
294  obj->setWeightTrimRate(rhs[3].toDouble());
295  else
296  mexErrMsgIdAndTxt("mexopencv:error",
297  "Unrecognized property %s", prop.c_str());
298  }
299  else
300  mexErrMsgIdAndTxt("mexopencv:error","Unrecognized operation");
301 }
#define UPDATE_FLAG(NUM, TF, BIT)
Set or clear a bit in a flag depending on a bool value.
Definition: mexopencv.hpp:159
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
int toInt() const
Convert MxArray to int.
Definition: MxArray.cpp:489
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for the input/output argument count check.
Definition: mexopencv.hpp:166
bool toBool() const
Convert MxArray to bool.
Definition: MxArray.cpp:510
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: Boost_.cpp:43
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
Common definitions for the ml module.