19 map<int,Ptr<SVM> > obj_;
23 (
"C_SVC", cv::ml::SVM::C_SVC)
24 (
"NU_SVC", cv::ml::SVM::NU_SVC)
25 (
"ONE_CLASS", cv::ml::SVM::ONE_CLASS)
26 (
"EPS_SVR", cv::ml::SVM::EPS_SVR)
27 (
"NU_SVR", cv::ml::SVM::NU_SVR);
31 (cv::ml::SVM::C_SVC,
"C_SVC")
32 (cv::ml::SVM::NU_SVC,
"NU_SVC")
33 (cv::ml::SVM::ONE_CLASS,
"ONE_CLASS")
34 (cv::ml::SVM::EPS_SVR,
"EPS_SVR")
35 (cv::ml::SVM::NU_SVR,
"NU_SVR");
39 (
"Custom", cv::ml::SVM::CUSTOM)
40 (
"Linear", cv::ml::SVM::LINEAR)
41 (
"Poly", cv::ml::SVM::POLY)
42 (
"RBF", cv::ml::SVM::RBF)
43 (
"Sigmoid", cv::ml::SVM::SIGMOID)
44 (
"Chi2", cv::ml::SVM::CHI2)
45 (
"Intersection", cv::ml::SVM::INTER);
49 (cv::ml::SVM::CUSTOM,
"Custom")
50 (cv::ml::SVM::LINEAR,
"Linear")
51 (cv::ml::SVM::POLY,
"Poly")
52 (cv::ml::SVM::RBF,
"RBF")
53 (cv::ml::SVM::SIGMOID,
"Sigmoid")
54 (cv::ml::SVM::CHI2,
"Chi2")
55 (cv::ml::SVM::INTER,
"Intersection");
60 (
"Gamma", cv::ml::SVM::GAMMA)
62 (
"Nu", cv::ml::SVM::NU)
63 (
"Coef", cv::ml::SVM::COEF)
64 (
"Degree", cv::ml::SVM::DEGREE);
70 ParamGrid toParamGrid(
const MxArray& m)
74 g.minVal = m.
at<
double>(0);
75 g.maxVal = m.
at<
double>(1);
76 g.logStep = m.
at<
double>(2);
80 g.minVal = m.
at(
"minVal").toDouble();
82 g.maxVal = m.
at(
"maxVal").toDouble();
84 g.logStep = m.
at(
"logStep").toDouble();
87 g = SVM::getDefaultGrid(SVMParamType[m.
toString()]);
89 mexErrMsgIdAndTxt(
"mexopencv:error",
90 "Invalid argument to grid parameter");
101 class MatlabFunction :
public cv::ml::SVM::Kernel
107 explicit MatlabFunction(
const string &func)
122 void calc(
int vcount,
int n,
const float* vecs,
123 const float* another,
float* results)
126 mxArray *lhs, *rhs[3];
128 rhs[1] =
MxArray(Mat(vcount, n, CV_32F, const_cast<float*>(vecs)));
129 rhs[2] =
MxArray(Mat(1, n, CV_32F, const_cast<float*>(another)));
134 if (mexCallMATLAB(1, &lhs, 3, rhs,
"feval") == 0) {
136 CV_Assert(res.isSingle() && !res.isComplex() && res.ndims() == 2);
137 vector<float> v(res.toVector<
float>());
138 CV_Assert(v.size() == vcount);
139 std::copy(v.begin(), v.end(), results);
143 std::fill(results, results + vcount, 0.0f);
148 mxDestroyArray(rhs[0]);
149 mxDestroyArray(rhs[1]);
150 mxDestroyArray(rhs[2]);
158 return cv::ml::SVM::CUSTOM;
165 static Ptr<MatlabFunction> create(
const string &func)
167 return makePtr<MatlabFunction>(func);
182 void mexFunction(
int nlhs, mxArray *plhs[],
int nrhs,
const mxArray *prhs[])
188 vector<MxArray> rhs(prhs, prhs+nrhs);
189 int id = rhs[0].toInt();
190 string method(rhs[1].toString());
193 if (method ==
"new") {
195 obj_[++last_id] = SVM::create();
201 Ptr<SVM> obj = obj_[id];
202 if (method ==
"delete") {
206 else if (method ==
"clear") {
210 else if (method ==
"load") {
211 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
213 bool loadFromString =
false;
214 for (
int i=3; i<nrhs; i+=2) {
215 string key(rhs[i].toString());
216 if (key ==
"ObjName")
217 objname = rhs[i+1].toString();
218 else if (key ==
"FromString")
219 loadFromString = rhs[i+1].toBool();
221 mexErrMsgIdAndTxt(
"mexopencv:error",
222 "Unrecognized option %s", key.c_str());
224 obj_[id] = (loadFromString ?
225 Algorithm::loadFromString<SVM>(rhs[2].toString(), objname) :
226 Algorithm::load<SVM>(rhs[2].toString(), objname));
228 else if (method ==
"save") {
230 string fname(rhs[2].toString());
233 FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
234 fs << obj->getDefaultName() <<
"{";
238 plhs[0] =
MxArray(fs.releaseAndGetString());
244 else if (method ==
"empty") {
246 plhs[0] =
MxArray(obj->empty());
248 else if (method ==
"getDefaultName") {
250 plhs[0] =
MxArray(obj->getDefaultName());
252 else if (method ==
"getVarCount") {
254 plhs[0] =
MxArray(obj->getVarCount());
256 else if (method ==
"isClassifier") {
258 plhs[0] =
MxArray(obj->isClassifier());
260 else if (method ==
"isTrained") {
262 plhs[0] =
MxArray(obj->isTrained());
264 else if (method ==
"train") {
265 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
266 vector<MxArray> dataOptions;
268 for (
int i=4; i<nrhs; i+=2) {
269 string key(rhs[i].toString());
271 dataOptions = rhs[i+1].toVector<
MxArray>();
272 else if (key ==
"Flags")
273 flags = rhs[i+1].
toInt();
275 mexErrMsgIdAndTxt(
"mexopencv:error",
276 "Unrecognized option %s", key.c_str());
281 dataOptions.begin(), dataOptions.end());
284 rhs[2].toMat(CV_32F),
285 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
286 dataOptions.begin(), dataOptions.end());
287 bool b = obj->train(data, flags);
290 else if (method ==
"calcError") {
291 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
292 vector<MxArray> dataOptions;
294 for (
int i=4; i<nrhs; i+=2) {
295 string key(rhs[i].toString());
297 dataOptions = rhs[i+1].toVector<
MxArray>();
298 else if (key ==
"TestError")
301 mexErrMsgIdAndTxt(
"mexopencv:error",
302 "Unrecognized option %s", key.c_str());
307 dataOptions.begin(), dataOptions.end());
310 rhs[2].toMat(CV_32F),
311 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
312 dataOptions.begin(), dataOptions.end());
314 float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
319 else if (method ==
"predict") {
320 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
322 for (
int i=3; i<nrhs; i+=2) {
323 string key(rhs[i].toString());
325 flags = rhs[i+1].toInt();
326 else if (key ==
"RawOutput")
327 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
329 mexErrMsgIdAndTxt(
"mexopencv:error",
330 "Unrecognized option %s", key.c_str());
332 Mat samples(rhs[2].toMat(CV_32F)),
334 float f = obj->predict(samples, results, flags);
339 else if (method ==
"trainAuto") {
340 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
341 vector<MxArray> dataOptions;
343 bool balanced =
false;
344 ParamGrid CGrid = SVM::getDefaultGrid(SVM::C),
345 gammaGrid = SVM::getDefaultGrid(SVM::GAMMA),
346 pGrid = SVM::getDefaultGrid(SVM::P),
347 nuGrid = SVM::getDefaultGrid(SVM::NU),
348 coeffGrid = SVM::getDefaultGrid(SVM::COEF),
349 degreeGrid = SVM::getDefaultGrid(SVM::DEGREE);
350 for (
int i=4; i<nrhs; i+=2) {
351 string key(rhs[i].toString());
353 dataOptions = rhs[i+1].toVector<
MxArray>();
354 else if (key ==
"KFold")
355 kFold = rhs[i+1].
toInt();
356 else if (key ==
"Balanced")
357 balanced = rhs[i+1].toBool();
358 else if (key ==
"CGrid")
359 CGrid = toParamGrid(rhs[i+1]);
360 else if (key ==
"GammaGrid")
361 gammaGrid = toParamGrid(rhs[i+1]);
362 else if (key ==
"PGrid")
363 pGrid = toParamGrid(rhs[i+1]);
364 else if (key ==
"NuGrid")
365 nuGrid = toParamGrid(rhs[i+1]);
366 else if (key ==
"CoeffGrid")
367 coeffGrid = toParamGrid(rhs[i+1]);
368 else if (key ==
"DegreeGrid")
369 degreeGrid = toParamGrid(rhs[i+1]);
371 mexErrMsgIdAndTxt(
"mexopencv:error",
372 "Unrecognized option %s", key.c_str());
377 dataOptions.begin(), dataOptions.end());
380 rhs[2].toMat(CV_32F),
381 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
382 dataOptions.begin(), dataOptions.end());
383 bool b = obj->trainAuto(data, kFold,
384 CGrid, gammaGrid, pGrid, nuGrid, coeffGrid, degreeGrid, balanced);
387 else if (method ==
"getDecisionFunction") {
389 int index = rhs[2].toInt();
391 double rho = obj->getDecisionFunction(index, alpha, svidx);
398 else if (method ==
"getSupportVectors") {
400 plhs[0] =
MxArray(obj->getSupportVectors());
402 else if (method ==
"getUncompressedSupportVectors") {
404 plhs[0] =
MxArray(obj->getUncompressedSupportVectors());
406 else if (method ==
"setCustomKernel") {
408 obj->setCustomKernel(MatlabFunction::create(rhs[2].toString()));
410 else if (method ==
"get") {
412 string prop(rhs[2].toString());
414 plhs[0] =
MxArray(InvSVMType[obj->getType()]);
415 else if (prop ==
"KernelType")
416 plhs[0] =
MxArray(InvSVMKernelType[obj->getKernelType()]);
417 else if (prop ==
"Degree")
418 plhs[0] =
MxArray(obj->getDegree());
419 else if (prop ==
"Gamma")
420 plhs[0] =
MxArray(obj->getGamma());
421 else if (prop ==
"Coef0")
422 plhs[0] =
MxArray(obj->getCoef0());
423 else if (prop ==
"C")
424 plhs[0] =
MxArray(obj->getC());
425 else if (prop ==
"Nu")
426 plhs[0] =
MxArray(obj->getNu());
427 else if (prop ==
"P")
428 plhs[0] =
MxArray(obj->getP());
429 else if (prop ==
"ClassWeights")
430 plhs[0] =
MxArray(obj->getClassWeights());
431 else if (prop ==
"TermCriteria")
432 plhs[0] =
MxArray(obj->getTermCriteria());
434 mexErrMsgIdAndTxt(
"mexopencv:error",
435 "Unrecognized property %s", prop.c_str());
437 else if (method ==
"set") {
439 string prop(rhs[2].toString());
441 obj->setType(SVMType[rhs[3].toString()]);
442 else if (prop ==
"KernelType")
443 obj->setKernel(SVMKernelType[rhs[3].toString()]);
444 else if (prop ==
"Degree")
445 obj->setDegree(rhs[3].toDouble());
446 else if (prop ==
"Gamma")
447 obj->setGamma(rhs[3].toDouble());
448 else if (prop ==
"Coef0")
449 obj->setCoef0(rhs[3].toDouble());
450 else if (prop ==
"C")
451 obj->setC(rhs[3].toDouble());
452 else if (prop ==
"Nu")
453 obj->setNu(rhs[3].toDouble());
454 else if (prop ==
"P")
455 obj->setP(rhs[3].toDouble());
456 else if (prop ==
"ClassWeights")
457 obj->setClassWeights(rhs[3].toMat());
458 else if (prop ==
"TermCriteria")
459 obj->setTermCriteria(rhs[3].toTermCriteria());
461 mexErrMsgIdAndTxt(
"mexopencv:error",
462 "Unrecognized property %s", prop.c_str());
465 mexErrMsgIdAndTxt(
"mexopencv:error",
"Unrecognized operation");
#define UPDATE_FLAG(NUM, TF, BIT)
Set or clear a bit in flag depending on bool value.
bool isStruct() const
Determine whether input is structure array.
bool isChar() const
Determine whether input is string array.
int toInt() const
Convert MxArray to int.
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for input/output arguments number check.
bool toBool() const
Convert MxArray to bool.
bool isField(const std::string &fieldName) const
Determine whether a struct array has a specified field.
Global constant definitions.
mwSize numel() const
Number of elements in an array.
T at(mwIndex index) const
Template for numeric array element accessor.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
std::string toString() const
Convert MxArray to std::string.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Common definitions for the ml module.
bool isNumeric() const
Determine whether array is numeric.