// Persistent object store: maps integer object ids (rhs[0] in mexFunction)
// to the RTrees instances they refer to. "new" inserts under ++last_id,
// "load" replaces an entry, and the other operations read obj_[id].
19 map<int,Ptr<RTrees> > obj_;
/**
 * Main entry called from MATLAB: MEX gateway dispatching operations on the
 * cv::ml::RTrees objects held in obj_.
 *
 * rhs[0] is the id of the target object and rhs[1] is the operation name;
 * the remaining arguments are operation-specific. "new" creates an RTrees
 * via RTrees::create() and stores it under ++last_id; every other operation
 * acts on obj_[id]. Unrecognized operations, options and properties raise a
 * "mexopencv:error" via mexErrMsgIdAndTxt.
 *
 * NOTE(review): this is an extracted listing; intermediate lines (braces and
 * declarations such as objname, flags, data, resp, results) are not shown.
 *
 * @param nlhs number of left-hand-side (output) arguments
 * @param plhs pointers to the output mxArrays
 * @param nrhs number of right-hand-side (input) arguments
 * @param prhs pointers to the input mxArrays
 */
29 void mexFunction(
int nlhs, mxArray *plhs[],
int nrhs,
const mxArray *prhs[])
// Wrap the raw inputs: rhs[0] = object id, rhs[1] = operation name.
35 vector<MxArray> rhs(prhs, prhs+nrhs);
36 int id = rhs[0].toInt();
37 string method(rhs[1].toString());
// "new": create an RTrees via RTrees::create() and register it under the next id.
40 if (method ==
"new") {
42 obj_[++last_id] = RTrees::create();
// Look up the target object for every other operation.
// NOTE(review): std::map::operator[] default-inserts an empty Ptr when `id`
// is unknown, so an invalid id leads to obj-> on a null pointer below;
// consider obj_.find(id) plus an explicit "object not found" error.
48 Ptr<RTrees> obj = obj_[id];
49 if (method ==
"delete") {
53 else if (method ==
"clear") {
// "load": rhs[2] is passed to Algorithm::load, or to Algorithm::loadFromString
// when the FromString option is true; the loaded model replaces obj_[id].
// Options are key/value pairs starting at rhs[3].
57 else if (method ==
"load") {
58 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
60 bool loadFromString =
false;
61 for (
int i=3; i<nrhs; i+=2) {
62 string key(rhs[i].toString());
64 objname = rhs[i+1].toString();
65 else if (key ==
"FromString")
66 loadFromString = rhs[i+1].toBool();
// Any other option key is rejected.
68 mexErrMsgIdAndTxt(
"mexopencv:error",
69 "Unrecognized option %s", key.c_str());
71 obj_[id] = (loadFromString ?
72 Algorithm::loadFromString<RTrees>(rhs[2].toString(), objname) :
73 Algorithm::load<RTrees>(rhs[2].toString(), objname));
// "save": the visible path serializes the model into an in-memory FileStorage
// (opened under the default name node) and returns the resulting text in plhs[0].
75 else if (method ==
"save") {
77 string fname(rhs[2].toString());
80 FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
81 fs << obj->getDefaultName() <<
"{";
85 plhs[0] =
MxArray(fs.releaseAndGetString());
// Simple queries forwarded to the model; each result is returned in plhs[0].
91 else if (method ==
"empty") {
93 plhs[0] =
MxArray(obj->empty());
95 else if (method ==
"getDefaultName") {
97 plhs[0] =
MxArray(obj->getDefaultName());
99 else if (method ==
"getVarCount") {
101 plhs[0] =
MxArray(obj->getVarCount());
103 else if (method ==
"isClassifier") {
105 plhs[0] =
MxArray(obj->isClassifier());
107 else if (method ==
"isTrained") {
109 plhs[0] =
MxArray(obj->isTrained());
// "train": rhs[2]/rhs[3] supply the training data, followed by key/value
// options from rhs[4]. "Flags" sets the whole flag word; RawOutput,
// PredictSum and PredictMaxVote toggle single bits via UPDATE_FLAG. One
// option (its key test is not visible here) fills dataOptions, which is
// forwarded to the TrainData construction.
111 else if (method ==
"train") {
112 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
113 vector<MxArray> dataOptions;
115 for (
int i=4; i<nrhs; i+=2) {
116 string key(rhs[i].toString());
118 dataOptions = rhs[i+1].toVector<
MxArray>();
119 else if (key ==
"Flags")
120 flags = rhs[i+1].
toInt();
121 else if (key ==
"RawOutput")
122 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
123 else if (key ==
"PredictSum")
124 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
125 else if (key ==
"PredictMaxVote")
126 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
128 mexErrMsgIdAndTxt(
"mexopencv:error",
129 "Unrecognized option %s", key.c_str());
// NOTE(review): the two calls ending in dataOptions.begin()/end() presumably
// correspond to loadTrainData (dataset file) and createTrainData (in-memory
// matrices); the code selecting between them is not shown in this listing.
134 dataOptions.begin(), dataOptions.end());
// In-memory case: samples as CV_32F; responses kept as CV_32S when supplied
// as int32 (presumably class labels), otherwise converted to CV_32F.
137 rhs[2].toMat(CV_32F),
138 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
139 dataOptions.begin(), dataOptions.end());
140 bool b = obj->train(data, flags);
// "calcError": same data/option handling as "train" plus a TestError option;
// per-sample responses are only requested (resp) when a second output is
// asked for, otherwise noArray() is passed.
143 else if (method ==
"calcError") {
144 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
145 vector<MxArray> dataOptions;
147 for (
int i=4; i<nrhs; i+=2) {
148 string key(rhs[i].toString());
150 dataOptions = rhs[i+1].toVector<
MxArray>();
151 else if (key ==
"TestError")
154 mexErrMsgIdAndTxt(
"mexopencv:error",
155 "Unrecognized option %s", key.c_str());
160 dataOptions.begin(), dataOptions.end());
163 rhs[2].toMat(CV_32F),
164 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
165 dataOptions.begin(), dataOptions.end());
167 float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
// "predict": rhs[2] holds the samples (converted to CV_32F); key/value
// options from rhs[3] adjust the flags passed to obj->predict.
172 else if (method ==
"predict") {
173 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
175 for (
int i=3; i<nrhs; i+=2) {
176 string key(rhs[i].toString());
178 flags = rhs[i+1].toInt();
179 else if (key ==
"RawOutput")
180 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
181 else if (key ==
"CompressedInput")
182 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
183 else if (key ==
"PreprocessedInput")
184 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
// PredictAuto=true clears both PREDICT_SUM and PREDICT_MAX_VOTE.
// NOTE(review): PredictAuto=false sets BOTH bits at once; confirm this
// combined value is intended.
185 else if (key ==
"PredictAuto") {
187 UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
188 UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
190 else if (key ==
"PredictSum")
191 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
192 else if (key ==
"PredictMaxVote")
193 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
195 mexErrMsgIdAndTxt(
"mexopencv:error",
196 "Unrecognized option %s", key.c_str());
198 Mat samples(rhs[2].toMat(CV_32F)),
200 float f = obj->predict(samples, results, flags);
// Accessors for the trained tree structure; node and split lists are
// converted to MATLAB struct arrays via toStruct.
205 else if (method ==
"getNodes") {
207 plhs[0] =
toStruct(obj->getNodes());
209 else if (method ==
"getRoots") {
211 plhs[0] =
MxArray(obj->getRoots());
213 else if (method ==
"getSplits") {
215 plhs[0] =
toStruct(obj->getSplits());
217 else if (method ==
"getSubsets") {
219 plhs[0] =
MxArray(obj->getSubsets());
221 else if (method ==
"getVarImportance") {
223 plhs[0] =
MxArray(obj->getVarImportance());
// "get": return the property named by rhs[2] in plhs[0].
225 else if (method ==
"get") {
227 string prop(rhs[2].toString());
228 if (prop ==
"CVFolds")
229 plhs[0] =
MxArray(obj->getCVFolds());
230 else if (prop ==
"MaxCategories")
231 plhs[0] =
MxArray(obj->getMaxCategories());
232 else if (prop ==
"MaxDepth")
233 plhs[0] =
MxArray(obj->getMaxDepth());
234 else if (prop ==
"MinSampleCount")
235 plhs[0] =
MxArray(obj->getMinSampleCount());
236 else if (prop ==
"Priors")
237 plhs[0] =
MxArray(obj->getPriors());
238 else if (prop ==
"RegressionAccuracy")
239 plhs[0] =
MxArray(obj->getRegressionAccuracy());
240 else if (prop ==
"TruncatePrunedTree")
241 plhs[0] =
MxArray(obj->getTruncatePrunedTree());
242 else if (prop ==
"Use1SERule")
243 plhs[0] =
MxArray(obj->getUse1SERule());
244 else if (prop ==
"UseSurrogates")
245 plhs[0] =
MxArray(obj->getUseSurrogates());
246 else if (prop ==
"ActiveVarCount")
247 plhs[0] =
MxArray(obj->getActiveVarCount());
248 else if (prop ==
"CalculateVarImportance")
249 plhs[0] =
MxArray(obj->getCalculateVarImportance());
250 else if (prop ==
"TermCriteria")
251 plhs[0] =
MxArray(obj->getTermCriteria());
253 mexErrMsgIdAndTxt(
"mexopencv:error",
254 "Unrecognized property %s", prop.c_str());
// "set": assign rhs[3] (converted to the property's type) to the property
// named by rhs[2].
256 else if (method ==
"set") {
258 string prop(rhs[2].toString());
259 if (prop ==
"CVFolds")
260 obj->setCVFolds(rhs[3].toInt());
261 else if (prop ==
"MaxCategories")
262 obj->setMaxCategories(rhs[3].toInt());
263 else if (prop ==
"MaxDepth")
264 obj->setMaxDepth(rhs[3].toInt());
265 else if (prop ==
"MinSampleCount")
266 obj->setMinSampleCount(rhs[3].toInt());
267 else if (prop ==
"Priors")
268 obj->setPriors(rhs[3].toMat());
269 else if (prop ==
"RegressionAccuracy")
270 obj->setRegressionAccuracy(rhs[3].toFloat());
271 else if (prop ==
"TruncatePrunedTree")
272 obj->setTruncatePrunedTree(rhs[3].toBool());
273 else if (prop ==
"Use1SERule")
274 obj->setUse1SERule(rhs[3].toBool());
275 else if (prop ==
"UseSurrogates")
276 obj->setUseSurrogates(rhs[3].toBool());
277 else if (prop ==
"ActiveVarCount")
278 obj->setActiveVarCount(rhs[3].toInt());
279 else if (prop ==
"CalculateVarImportance")
280 obj->setCalculateVarImportance(rhs[3].toBool());
281 else if (prop ==
"TermCriteria")
282 obj->setTermCriteria(rhs[3].toTermCriteria());
284 mexErrMsgIdAndTxt(
"mexopencv:error",
285 "Unrecognized property %s", prop.c_str());
// Any other operation name is an error.
288 mexErrMsgIdAndTxt(
"mexopencv:error",
"Unrecognized operation");
#define UPDATE_FLAG(NUM, TF, BIT)
Set or clear a bit in a flag depending on a bool value.
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
int toInt() const
Convert MxArray to int.
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for an input/output argument-count check.
bool toBool() const
Convert MxArray to bool.
Global constant definitions.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
Common definitions for the ml module.