19 map<int,Ptr<Boost> > obj_;
23 (
"Discrete", cv::ml::Boost::DISCRETE)
24 (
"Real", cv::ml::Boost::REAL)
25 (
"Logit", cv::ml::Boost::LOGIT)
26 (
"Gentle", cv::ml::Boost::GENTLE);
30 (cv::ml::Boost::DISCRETE,
"Discrete")
31 (cv::ml::Boost::REAL,
"Real")
32 (cv::ml::Boost::LOGIT,
"Logit")
33 (cv::ml::Boost::GENTLE,
"Gentle");
43 void mexFunction(
int nlhs, mxArray *plhs[],
int nrhs,
const mxArray *prhs[])
49 vector<MxArray> rhs(prhs, prhs+nrhs);
50 int id = rhs[0].toInt();
51 string method(rhs[1].toString());
54 if (method ==
"new") {
56 obj_[++last_id] = Boost::create();
62 Ptr<Boost> obj = obj_[id];
63 if (method ==
"delete") {
67 else if (method ==
"clear") {
71 else if (method ==
"load") {
72 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
74 bool loadFromString =
false;
75 for (
int i=3; i<nrhs; i+=2) {
76 string key(rhs[i].toString());
78 objname = rhs[i+1].toString();
79 else if (key ==
"FromString")
80 loadFromString = rhs[i+1].toBool();
82 mexErrMsgIdAndTxt(
"mexopencv:error",
83 "Unrecognized option %s", key.c_str());
85 obj_[id] = (loadFromString ?
86 Algorithm::loadFromString<Boost>(rhs[2].toString(), objname) :
87 Algorithm::load<Boost>(rhs[2].toString(), objname));
89 else if (method ==
"save") {
91 string fname(rhs[2].toString());
94 FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
95 fs << obj->getDefaultName() <<
"{";
99 plhs[0] =
MxArray(fs.releaseAndGetString());
105 else if (method ==
"empty") {
107 plhs[0] =
MxArray(obj->empty());
109 else if (method ==
"getDefaultName") {
111 plhs[0] =
MxArray(obj->getDefaultName());
113 else if (method ==
"getVarCount") {
115 plhs[0] =
MxArray(obj->getVarCount());
117 else if (method ==
"isClassifier") {
119 plhs[0] =
MxArray(obj->isClassifier());
121 else if (method ==
"isTrained") {
123 plhs[0] =
MxArray(obj->isTrained());
125 else if (method ==
"train") {
126 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
127 vector<MxArray> dataOptions;
129 for (
int i=4; i<nrhs; i+=2) {
130 string key(rhs[i].toString());
132 dataOptions = rhs[i+1].toVector<
MxArray>();
133 else if (key ==
"Flags")
134 flags = rhs[i+1].
toInt();
135 else if (key ==
"RawOutput")
136 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
137 else if (key ==
"CompressedInput")
138 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
139 else if (key ==
"PredictSum")
140 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
141 else if (key ==
"PredictMaxVote")
142 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
144 mexErrMsgIdAndTxt(
"mexopencv:error",
145 "Unrecognized option %s", key.c_str());
150 dataOptions.begin(), dataOptions.end());
153 rhs[2].toMat(CV_32F),
154 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
155 dataOptions.begin(), dataOptions.end());
156 bool b = obj->train(data, flags);
159 else if (method ==
"calcError") {
160 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
161 vector<MxArray> dataOptions;
163 for (
int i=4; i<nrhs; i+=2) {
164 string key(rhs[i].toString());
166 dataOptions = rhs[i+1].toVector<
MxArray>();
167 else if (key ==
"TestError")
170 mexErrMsgIdAndTxt(
"mexopencv:error",
171 "Unrecognized option %s", key.c_str());
176 dataOptions.begin(), dataOptions.end());
179 rhs[2].toMat(CV_32F),
180 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
181 dataOptions.begin(), dataOptions.end());
183 float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
188 else if (method ==
"predict") {
189 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
191 for (
int i=3; i<nrhs; i+=2) {
192 string key(rhs[i].toString());
194 flags = rhs[i+1].toInt();
195 else if (key ==
"RawOutput")
196 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
197 else if (key ==
"CompressedInput")
198 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::COMPRESSED_INPUT);
199 else if (key ==
"PreprocessedInput")
200 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::PREPROCESSED_INPUT);
201 else if (key ==
"PredictAuto") {
203 UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_SUM);
204 UPDATE_FLAG(flags, !rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
206 else if (key ==
"PredictSum")
207 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_SUM);
208 else if (key ==
"PredictMaxVote")
209 UPDATE_FLAG(flags, rhs[i+1].toBool(), DTrees::PREDICT_MAX_VOTE);
211 mexErrMsgIdAndTxt(
"mexopencv:error",
212 "Unrecognized option %s", key.c_str());
214 Mat samples(rhs[2].toMat(CV_32F)),
216 float f = obj->predict(samples, results, flags);
221 else if (method ==
"getNodes") {
223 plhs[0] =
toStruct(obj->getNodes());
225 else if (method ==
"getRoots") {
227 plhs[0] =
MxArray(obj->getRoots());
229 else if (method ==
"getSplits") {
231 plhs[0] =
toStruct(obj->getSplits());
233 else if (method ==
"getSubsets") {
235 plhs[0] =
MxArray(obj->getSubsets());
237 else if (method ==
"get") {
239 string prop(rhs[2].toString());
240 if (prop ==
"CVFolds")
241 plhs[0] =
MxArray(obj->getCVFolds());
242 else if (prop ==
"MaxCategories")
243 plhs[0] =
MxArray(obj->getMaxCategories());
244 else if (prop ==
"MaxDepth")
245 plhs[0] =
MxArray(obj->getMaxDepth());
246 else if (prop ==
"MinSampleCount")
247 plhs[0] =
MxArray(obj->getMinSampleCount());
248 else if (prop ==
"Priors")
249 plhs[0] =
MxArray(obj->getPriors());
250 else if (prop ==
"RegressionAccuracy")
251 plhs[0] =
MxArray(obj->getRegressionAccuracy());
252 else if (prop ==
"TruncatePrunedTree")
253 plhs[0] =
MxArray(obj->getTruncatePrunedTree());
254 else if (prop ==
"Use1SERule")
255 plhs[0] =
MxArray(obj->getUse1SERule());
256 else if (prop ==
"UseSurrogates")
257 plhs[0] =
MxArray(obj->getUseSurrogates());
258 else if (prop ==
"BoostType")
259 plhs[0] =
MxArray(InvBoostType[obj->getBoostType()]);
260 else if (prop ==
"WeakCount")
261 plhs[0] =
MxArray(obj->getWeakCount());
262 else if (prop ==
"WeightTrimRate")
263 plhs[0] =
MxArray(obj->getWeightTrimRate());
265 mexErrMsgIdAndTxt(
"mexopencv:error",
266 "Unrecognized property %s", prop.c_str());
268 else if (method ==
"set") {
270 string prop(rhs[2].toString());
271 if (prop ==
"CVFolds")
272 obj->setCVFolds(rhs[3].toInt());
273 else if (prop ==
"MaxCategories")
274 obj->setMaxCategories(rhs[3].toInt());
275 else if (prop ==
"MaxDepth")
276 obj->setMaxDepth(rhs[3].toInt());
277 else if (prop ==
"MinSampleCount")
278 obj->setMinSampleCount(rhs[3].toInt());
279 else if (prop ==
"Priors")
280 obj->setPriors(rhs[3].toMat());
281 else if (prop ==
"RegressionAccuracy")
282 obj->setRegressionAccuracy(rhs[3].toFloat());
283 else if (prop ==
"TruncatePrunedTree")
284 obj->setTruncatePrunedTree(rhs[3].toBool());
285 else if (prop ==
"Use1SERule")
286 obj->setUse1SERule(rhs[3].toBool());
287 else if (prop ==
"UseSurrogates")
288 obj->setUseSurrogates(rhs[3].toBool());
289 else if (prop ==
"BoostType")
290 obj->setBoostType(BoostType[rhs[3].toString()]);
291 else if (prop ==
"WeakCount")
292 obj->setWeakCount(rhs[3].toInt());
293 else if (prop ==
"WeightTrimRate")
294 obj->setWeightTrimRate(rhs[3].toDouble());
296 mexErrMsgIdAndTxt(
"mexopencv:error",
297 "Unrecognized property %s", prop.c_str());
300 mexErrMsgIdAndTxt(
"mexopencv:error",
"Unrecognized operation");
#define UPDATE_FLAG(NUM, TF, BIT)
Set or clear a bit in a flag depending on a bool value.
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
int toInt() const
Convert MxArray to int.
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for the input/output argument count check.
bool toBool() const
Convert MxArray to bool.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Common definitions for the ml module.