// File-scope registry of DTrees instances, keyed by the integer object id
// that mexFunction reads from its first argument (rhs[0]). The "new" method
// inserts a freshly created model under ++last_id, and "load" replaces the
// entry for an existing id.
19 map<int,Ptr<DTrees> > obj_;
/**
 * Main MEX gateway for the DTrees wrapper: dispatches on a method name.
 *
 * @param nlhs number of output (left-hand side) arguments.
 * @param plhs output mxArray slots; results are written to plhs[0] (and,
 *             for some methods, plhs[1] on lines not shown here).
 * @param nrhs number of input (right-hand side) arguments.
 * @param prhs inputs: prhs[0] is the object id (read with toInt()),
 *             prhs[1] is the method name (read with toString()), and any
 *             remaining arguments are method-specific, mostly key/value
 *             option pairs.
 *
 * NOTE(review): this listing is an excerpt. Several original lines
 * (braces, some option-key comparisons, variable declarations) are
 * missing, so the comments below describe only what is visible.
 */
29 void mexFunction(
int nlhs, mxArray *plhs[],
int nrhs,
const mxArray *prhs[])
35 vector<MxArray> rhs(prhs, prhs+nrhs);
36 int id = rhs[0].toInt();
37 string method(rhs[1].toString());
// Constructor: create a new DTrees model and register it under a fresh id.
40 if (method ==
"new") {
42 obj_[++last_id] = DTrees::create();
// NOTE(review): std::map::operator[] default-inserts an empty Ptr when `id`
// is not registered. No emptiness check is visible in this excerpt before
// the obj-> dereferences below; confirm one exists on the missing lines.
48 Ptr<DTrees> obj = obj_[id];
49 if (method ==
"delete") {
53 else if (method ==
"clear") {
// Replace the model at `id` with one loaded from rhs[2], which is either a
// filename or, when FromString is true, serialized content.
57 else if (method ==
"load") {
58 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
60 bool loadFromString =
false;
61 for (
int i=3; i<nrhs; i+=2) {
62 string key(rhs[i].toString());
// (The key comparison that selects this objname branch is on a line
// missing from this excerpt.)
64 objname = rhs[i+1].toString();
65 else if (key ==
"FromString")
66 loadFromString = rhs[i+1].toBool();
68 mexErrMsgIdAndTxt(
"mexopencv:error",
69 "Unrecognized option %s", key.c_str());
71 obj_[id] = (loadFromString ?
72 Algorithm::loadFromString<DTrees>(rhs[2].toString(), objname) :
73 Algorithm::load<DTrees>(rhs[2].toString(), objname));
// Serialize the model. The visible path writes into an in-memory
// FileStorage and returns the result as a string in plhs[0]. Any
// write-to-file path is on lines not shown here.
75 else if (method ==
"save") {
77 string fname(rhs[2].toString());
80 FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
81 fs << obj->getDefaultName() <<
"{";
85 plhs[0] =
MxArray(fs.releaseAndGetString());
// Simple Algorithm / StatModel queries, each returned in plhs[0].
91 else if (method ==
"empty") {
93 plhs[0] =
MxArray(obj->empty());
95 else if (method ==
"getDefaultName") {
97 plhs[0] =
MxArray(obj->getDefaultName());
99 else if (method ==
"getVarCount") {
101 plhs[0] =
MxArray(obj->getVarCount());
103 else if (method ==
"isClassifier") {
105 plhs[0] =
MxArray(obj->isClassifier());
107 else if (method ==
"isTrained") {
109 plhs[0] =
MxArray(obj->isTrained());
// Train the model. The visible TrainData construction takes samples from
// rhs[2] (converted to CV_32F) and responses from rhs[3] (CV_32S if the
// input is int32, otherwise CV_32F). Options start at rhs[4] as key/value
// pairs; "Flags" is passed to train().
111 else if (method ==
"train") {
112 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
113 vector<MxArray> dataOptions;
115 for (
int i=4; i<nrhs; i+=2) {
116 string key(rhs[i].toString());
118 dataOptions = rhs[i+1].toVector<
MxArray>();
119 else if (key ==
"Flags")
120 flags = rhs[i+1].
toInt();
122 mexErrMsgIdAndTxt(
"mexopencv:error",
123 "Unrecognized option %s", key.c_str());
128 dataOptions.begin(), dataOptions.end());
131 rhs[2].toMat(CV_32F),
132 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
133 dataOptions.begin(), dataOptions.end());
134 bool b = obj->train(data, flags);
// Compute the model error on a dataset. Samples and responses are
// converted as in "train". The per-sample responses are requested from
// calcError() only when a second output is wanted (nlhs > 1).
137 else if (method ==
"calcError") {
138 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
139 vector<MxArray> dataOptions;
141 for (
int i=4; i<nrhs; i+=2) {
142 string key(rhs[i].toString());
144 dataOptions = rhs[i+1].toVector<
MxArray>();
145 else if (key ==
"TestError")
148 mexErrMsgIdAndTxt(
"mexopencv:error",
149 "Unrecognized option %s", key.c_str());
154 dataOptions.begin(), dataOptions.end());
157 rhs[2].toMat(CV_32F),
158 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
159 dataOptions.begin(), dataOptions.end());
161 float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
// Predict on samples in rhs[2] (converted to CV_32F). Key/value options
// from rhs[3] onward build the `flags` bitmask, either as a whole value or
// by setting or clearing individual bits via UPDATE_FLAG.
166 else if (method ==
"predict") {
167 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
169 for (
int i=3; i<nrhs; i+=2) {
170 string key(rhs[i].toString());
172 flags = rhs[i+1].toInt();
173 else if (key ==
"RawOutput")
174 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
175 else if (key ==
"CompressedInput")
176 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
177 else if (key ==
"PreprocessedInput")
178 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
// PredictAuto=true clears both the PREDICT_SUM and PREDICT_MAX_VOTE bits.
// NOTE(review): PredictAuto=false sets BOTH bits at once; confirm that
// combination is intended.
179 else if (key ==
"PredictAuto") {
181 UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
182 UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
184 else if (key ==
"PredictSum")
185 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
186 else if (key ==
"PredictMaxVote")
187 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
189 mexErrMsgIdAndTxt(
"mexopencv:error",
190 "Unrecognized option %s", key.c_str());
192 Mat samples(rhs[2].toMat(CV_32F)),
194 float f = obj->predict(samples, results, flags);
// Tree structure accessors. Nodes and splits are converted to MATLAB
// struct arrays via toStruct(); roots and subsets go through MxArray.
199 else if (method ==
"getNodes") {
201 plhs[0] =
toStruct(obj->getNodes());
203 else if (method ==
"getRoots") {
205 plhs[0] =
MxArray(obj->getRoots());
207 else if (method ==
"getSplits") {
209 plhs[0] =
toStruct(obj->getSplits());
211 else if (method ==
"getSubsets") {
213 plhs[0] =
MxArray(obj->getSubsets());
// Property getter: the property name is in rhs[2]; the value goes to plhs[0].
215 else if (method ==
"get") {
217 string prop(rhs[2].toString());
218 if (prop ==
"CVFolds")
219 plhs[0] =
MxArray(obj->getCVFolds());
220 else if (prop ==
"MaxCategories")
221 plhs[0] =
MxArray(obj->getMaxCategories());
222 else if (prop ==
"MaxDepth")
223 plhs[0] =
MxArray(obj->getMaxDepth());
224 else if (prop ==
"MinSampleCount")
225 plhs[0] =
MxArray(obj->getMinSampleCount());
226 else if (prop ==
"Priors")
227 plhs[0] =
MxArray(obj->getPriors());
228 else if (prop ==
"RegressionAccuracy")
229 plhs[0] =
MxArray(obj->getRegressionAccuracy());
230 else if (prop ==
"TruncatePrunedTree")
231 plhs[0] =
MxArray(obj->getTruncatePrunedTree());
232 else if (prop ==
"Use1SERule")
233 plhs[0] =
MxArray(obj->getUse1SERule());
234 else if (prop ==
"UseSurrogates")
235 plhs[0] =
MxArray(obj->getUseSurrogates());
237 mexErrMsgIdAndTxt(
"mexopencv:error",
238 "Unrecognized property %s", prop.c_str());
// Property setter: the property name is in rhs[2] and the new value in
// rhs[3].
// NOTE(review): Priors is converted with toMat() using no explicit depth;
// confirm the resulting Mat type is acceptable to setPriors().
240 else if (method ==
"set") {
242 string prop(rhs[2].toString());
243 if (prop ==
"CVFolds")
244 obj->setCVFolds(rhs[3].toInt());
245 else if (prop ==
"MaxCategories")
246 obj->setMaxCategories(rhs[3].toInt());
247 else if (prop ==
"MaxDepth")
248 obj->setMaxDepth(rhs[3].toInt());
249 else if (prop ==
"MinSampleCount")
250 obj->setMinSampleCount(rhs[3].toInt());
251 else if (prop ==
"Priors")
252 obj->setPriors(rhs[3].toMat());
253 else if (prop ==
"RegressionAccuracy")
254 obj->setRegressionAccuracy(rhs[3].toFloat());
255 else if (prop ==
"TruncatePrunedTree")
256 obj->setTruncatePrunedTree(rhs[3].toBool());
257 else if (prop ==
"Use1SERule")
258 obj->setUse1SERule(rhs[3].toBool());
259 else if (prop ==
"UseSurrogates")
260 obj->setUseSurrogates(rhs[3].toBool());
262 mexErrMsgIdAndTxt(
"mexopencv:error",
263 "Unrecognized property %s", prop.c_str());
// Fallthrough: the method name matched none of the branches above.
266 mexErrMsgIdAndTxt(
"mexopencv:error",
"Unrecognized operation");
#define UPDATE_FLAG(NUM, TF, BIT)
Set or clear a bit in a flag depending on a bool value.
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to a struct array.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry point called from MATLAB.
int toInt() const
Convert MxArray to int.
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for input/output argument-count check.
bool toBool() const
Convert MxArray to bool.
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using the options passed in the arguments.
Common definitions for the ml module.