mexopencv  0.1
mex interface for opencv library
SVM_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 using namespace std;
11 using namespace cv;
12 using namespace cv::ml;
13 
14 // Persistent objects
15 namespace {
17 int last_id = 0;
19 map<int,Ptr<SVM> > obj_;
20 
23  ("C_SVC", cv::ml::SVM::C_SVC)
24  ("NU_SVC", cv::ml::SVM::NU_SVC)
25  ("ONE_CLASS", cv::ml::SVM::ONE_CLASS)
26  ("EPS_SVR", cv::ml::SVM::EPS_SVR)
27  ("NU_SVR", cv::ml::SVM::NU_SVR);
28 
31  (cv::ml::SVM::C_SVC, "C_SVC")
32  (cv::ml::SVM::NU_SVC, "NU_SVC")
33  (cv::ml::SVM::ONE_CLASS, "ONE_CLASS")
34  (cv::ml::SVM::EPS_SVR, "EPS_SVR")
35  (cv::ml::SVM::NU_SVR, "NU_SVR");
36 
38 const ConstMap<string,int> SVMKernelType = ConstMap<string,int>
39  ("Custom", cv::ml::SVM::CUSTOM)
40  ("Linear", cv::ml::SVM::LINEAR)
41  ("Poly", cv::ml::SVM::POLY)
42  ("RBF", cv::ml::SVM::RBF)
43  ("Sigmoid", cv::ml::SVM::SIGMOID)
44  ("Chi2", cv::ml::SVM::CHI2)
45  ("Intersection", cv::ml::SVM::INTER);
46 
48 const ConstMap<int,string> InvSVMKernelType = ConstMap<int,string>
49  (cv::ml::SVM::CUSTOM, "Custom")
50  (cv::ml::SVM::LINEAR, "Linear")
51  (cv::ml::SVM::POLY, "Poly")
52  (cv::ml::SVM::RBF, "RBF")
53  (cv::ml::SVM::SIGMOID, "Sigmoid")
54  (cv::ml::SVM::CHI2, "Chi2")
55  (cv::ml::SVM::INTER, "Intersection");
56 
58 const ConstMap<string,int> SVMParamType = ConstMap<string,int>
59  ("C", cv::ml::SVM::C)
60  ("Gamma", cv::ml::SVM::GAMMA)
61  ("P", cv::ml::SVM::P)
62  ("Nu", cv::ml::SVM::NU)
63  ("Coef", cv::ml::SVM::COEF)
64  ("Degree", cv::ml::SVM::DEGREE);
65 
70 ParamGrid toParamGrid(const MxArray& m)
71 {
72  ParamGrid g;
73  if (m.isNumeric() && m.numel()==3) {
74  g.minVal = m.at<double>(0);
75  g.maxVal = m.at<double>(1);
76  g.logStep = m.at<double>(2);
77  }
78  else if (m.isStruct() && m.numel()==1) {
79  if (m.isField("minVal"))
80  g.minVal = m.at("minVal").toDouble();
81  if (m.isField("maxVal"))
82  g.maxVal = m.at("maxVal").toDouble();
83  if (m.isField("logStep"))
84  g.logStep = m.at("logStep").toDouble();
85  }
86  else if (m.isChar())
87  g = SVM::getDefaultGrid(SVMParamType[m.toString()]);
88  else
89  mexErrMsgIdAndTxt("mexopencv:error",
90  "Invalid argument to grid parameter");
91  // SVM::trainAuto permits setting step<=1 if we want to disable optimizing
92  // a certain paramter, in which case the value is taken from the props.
93  // Besides the check is done by function itself, so its not needed here.
94  //if (!g.check())
95  // mexErrMsgIdAndTxt("mexopencv:error","Invalid argument to grid parameter");
96  return g;
97 }
98 
101 class MatlabFunction : public cv::ml::SVM::Kernel
102 {
103 public:
107  explicit MatlabFunction(const string &func)
108  : fun_name(func)
109  {}
110 
122  void calc(int vcount, int n, const float* vecs,
123  const float* another, float* results)
124  {
125  // create input to evaluate kernel function
126  mxArray *lhs, *rhs[3];
127  rhs[0] = MxArray(fun_name);
128  rhs[1] = MxArray(Mat(vcount, n, CV_32F, const_cast<float*>(vecs)));
129  rhs[2] = MxArray(Mat(1, n, CV_32F, const_cast<float*>(another)));
130 
131  //TODO: mexCallMATLAB is not thread-safe!
132  // evaluate specified function in MATLAB as:
133  // results = feval("fun_name", vecs, another)
134  if (mexCallMATLAB(1, &lhs, 3, rhs, "feval") == 0) {
135  MxArray res(lhs);
136  CV_Assert(res.isSingle() && !res.isComplex() && res.ndims() == 2);
137  vector<float> v(res.toVector<float>());
138  CV_Assert(v.size() == vcount);
139  std::copy(v.begin(), v.end(), results);
140  }
141  else {
142  //TODO: error
143  std::fill(results, results + vcount, 0.0f);
144  }
145 
146  // cleanup
147  mxDestroyArray(lhs);
148  mxDestroyArray(rhs[0]);
149  mxDestroyArray(rhs[1]);
150  mxDestroyArray(rhs[2]);
151  }
152 
156  int getType() const
157  {
158  return cv::ml::SVM::CUSTOM;
159  }
160 
165  static Ptr<MatlabFunction> create(const string &func)
166  {
167  return makePtr<MatlabFunction>(func);
168  }
169 
170 private:
171  string fun_name;
172 };
173 }
174 
182 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
183 {
184  // Check the number of arguments
185  nargchk(nrhs>=2 && nlhs<=3);
186 
187  // Argument vector
188  vector<MxArray> rhs(prhs, prhs+nrhs);
189  int id = rhs[0].toInt();
190  string method(rhs[1].toString());
191 
192  // Constructor is called. Create a new object from argument
193  if (method == "new") {
194  nargchk(nrhs==2 && nlhs<=1);
195  obj_[++last_id] = SVM::create();
196  plhs[0] = MxArray(last_id);
197  return;
198  }
199 
200  // Big operation switch
201  Ptr<SVM> obj = obj_[id];
202  if (method == "delete") {
203  nargchk(nrhs==2 && nlhs==0);
204  obj_.erase(id);
205  }
206  else if (method == "clear") {
207  nargchk(nrhs==2 && nlhs==0);
208  obj->clear();
209  }
210  else if (method == "load") {
211  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
212  string objname;
213  bool loadFromString = false;
214  for (int i=3; i<nrhs; i+=2) {
215  string key(rhs[i].toString());
216  if (key == "ObjName")
217  objname = rhs[i+1].toString();
218  else if (key == "FromString")
219  loadFromString = rhs[i+1].toBool();
220  else
221  mexErrMsgIdAndTxt("mexopencv:error",
222  "Unrecognized option %s", key.c_str());
223  }
224  obj_[id] = (loadFromString ?
225  Algorithm::loadFromString<SVM>(rhs[2].toString(), objname) :
226  Algorithm::load<SVM>(rhs[2].toString(), objname));
227  }
228  else if (method == "save") {
229  nargchk(nrhs==3 && nlhs<=1);
230  string fname(rhs[2].toString());
231  if (nlhs > 0) {
232  // write to memory, and return string
233  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
234  fs << obj->getDefaultName() << "{";
235  fs << "format" << 3;
236  obj->write(fs);
237  fs << "}";
238  plhs[0] = MxArray(fs.releaseAndGetString());
239  }
240  else
241  // write to disk
242  obj->save(fname);
243  }
244  else if (method == "empty") {
245  nargchk(nrhs==2 && nlhs<=1);
246  plhs[0] = MxArray(obj->empty());
247  }
248  else if (method == "getDefaultName") {
249  nargchk(nrhs==2 && nlhs<=1);
250  plhs[0] = MxArray(obj->getDefaultName());
251  }
252  else if (method == "getVarCount") {
253  nargchk(nrhs==2 && nlhs<=1);
254  plhs[0] = MxArray(obj->getVarCount());
255  }
256  else if (method == "isClassifier") {
257  nargchk(nrhs==2 && nlhs<=1);
258  plhs[0] = MxArray(obj->isClassifier());
259  }
260  else if (method == "isTrained") {
261  nargchk(nrhs==2 && nlhs<=1);
262  plhs[0] = MxArray(obj->isTrained());
263  }
264  else if (method == "train") {
265  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
266  vector<MxArray> dataOptions;
267  int flags = 0;
268  for (int i=4; i<nrhs; i+=2) {
269  string key(rhs[i].toString());
270  if (key == "Data")
271  dataOptions = rhs[i+1].toVector<MxArray>();
272  else if (key == "Flags")
273  flags = rhs[i+1].toInt();
274  else
275  mexErrMsgIdAndTxt("mexopencv:error",
276  "Unrecognized option %s", key.c_str());
277  }
278  Ptr<TrainData> data;
279  if (rhs[2].isChar())
280  data = loadTrainData(rhs[2].toString(),
281  dataOptions.begin(), dataOptions.end());
282  else
283  data = createTrainData(
284  rhs[2].toMat(CV_32F),
285  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
286  dataOptions.begin(), dataOptions.end());
287  bool b = obj->train(data, flags);
288  plhs[0] = MxArray(b);
289  }
290  else if (method == "calcError") {
291  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
292  vector<MxArray> dataOptions;
293  bool test = false;
294  for (int i=4; i<nrhs; i+=2) {
295  string key(rhs[i].toString());
296  if (key == "Data")
297  dataOptions = rhs[i+1].toVector<MxArray>();
298  else if (key == "TestError")
299  test = rhs[i+1].toBool();
300  else
301  mexErrMsgIdAndTxt("mexopencv:error",
302  "Unrecognized option %s", key.c_str());
303  }
304  Ptr<TrainData> data;
305  if (rhs[2].isChar())
306  data = loadTrainData(rhs[2].toString(),
307  dataOptions.begin(), dataOptions.end());
308  else
309  data = createTrainData(
310  rhs[2].toMat(CV_32F),
311  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
312  dataOptions.begin(), dataOptions.end());
313  Mat resp;
314  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
315  plhs[0] = MxArray(err);
316  if (nlhs>1)
317  plhs[1] = MxArray(resp);
318  }
319  else if (method == "predict") {
320  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
321  int flags = 0;
322  for (int i=3; i<nrhs; i+=2) {
323  string key(rhs[i].toString());
324  if (key == "Flags")
325  flags = rhs[i+1].toInt();
326  else if (key == "RawOutput")
327  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
328  else
329  mexErrMsgIdAndTxt("mexopencv:error",
330  "Unrecognized option %s", key.c_str());
331  }
332  Mat samples(rhs[2].toMat(CV_32F)),
333  results;
334  float f = obj->predict(samples, results, flags);
335  plhs[0] = MxArray(results);
336  if (nlhs>1)
337  plhs[1] = MxArray(f);
338  }
339  else if (method == "trainAuto") {
340  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
341  vector<MxArray> dataOptions;
342  int kFold = 10;
343  bool balanced = false;
344  ParamGrid CGrid = SVM::getDefaultGrid(SVM::C),
345  gammaGrid = SVM::getDefaultGrid(SVM::GAMMA),
346  pGrid = SVM::getDefaultGrid(SVM::P),
347  nuGrid = SVM::getDefaultGrid(SVM::NU),
348  coeffGrid = SVM::getDefaultGrid(SVM::COEF),
349  degreeGrid = SVM::getDefaultGrid(SVM::DEGREE);
350  for (int i=4; i<nrhs; i+=2) {
351  string key(rhs[i].toString());
352  if (key == "Data")
353  dataOptions = rhs[i+1].toVector<MxArray>();
354  else if (key == "KFold")
355  kFold = rhs[i+1].toInt();
356  else if (key == "Balanced")
357  balanced = rhs[i+1].toBool();
358  else if (key == "CGrid")
359  CGrid = toParamGrid(rhs[i+1]);
360  else if (key == "GammaGrid")
361  gammaGrid = toParamGrid(rhs[i+1]);
362  else if (key == "PGrid")
363  pGrid = toParamGrid(rhs[i+1]);
364  else if (key == "NuGrid")
365  nuGrid = toParamGrid(rhs[i+1]);
366  else if (key == "CoeffGrid")
367  coeffGrid = toParamGrid(rhs[i+1]);
368  else if (key == "DegreeGrid")
369  degreeGrid = toParamGrid(rhs[i+1]);
370  else
371  mexErrMsgIdAndTxt("mexopencv:error",
372  "Unrecognized option %s", key.c_str());
373  }
374  Ptr<TrainData> data;
375  if (rhs[2].isChar())
376  data = loadTrainData(rhs[2].toString(),
377  dataOptions.begin(), dataOptions.end());
378  else
379  data = createTrainData(
380  rhs[2].toMat(CV_32F),
381  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
382  dataOptions.begin(), dataOptions.end());
383  bool b = obj->trainAuto(data, kFold,
384  CGrid, gammaGrid, pGrid, nuGrid, coeffGrid, degreeGrid, balanced);
385  plhs[0] = MxArray(b);
386  }
387  else if (method == "getDecisionFunction") {
388  nargchk(nrhs==3 && nlhs<=3);
389  int index = rhs[2].toInt();
390  Mat alpha, svidx;
391  double rho = obj->getDecisionFunction(index, alpha, svidx);
392  plhs[0] = MxArray(alpha);
393  if (nlhs > 1)
394  plhs[1] = MxArray(svidx);
395  if (nlhs > 2)
396  plhs[2] = MxArray(rho);
397  }
398  else if (method == "getSupportVectors") {
399  nargchk(nrhs==2 && nlhs<=1);
400  plhs[0] = MxArray(obj->getSupportVectors());
401  }
402  else if (method == "getUncompressedSupportVectors") {
403  nargchk(nrhs==2 && nlhs<=1);
404  plhs[0] = MxArray(obj->getUncompressedSupportVectors());
405  }
406  else if (method == "setCustomKernel") {
407  nargchk(nrhs==3 && nlhs==0);
408  obj->setCustomKernel(MatlabFunction::create(rhs[2].toString()));
409  }
410  else if (method == "get") {
411  nargchk(nrhs==3 && nlhs<=1);
412  string prop(rhs[2].toString());
413  if (prop == "Type")
414  plhs[0] = MxArray(InvSVMType[obj->getType()]);
415  else if (prop == "KernelType")
416  plhs[0] = MxArray(InvSVMKernelType[obj->getKernelType()]);
417  else if (prop == "Degree")
418  plhs[0] = MxArray(obj->getDegree());
419  else if (prop == "Gamma")
420  plhs[0] = MxArray(obj->getGamma());
421  else if (prop == "Coef0")
422  plhs[0] = MxArray(obj->getCoef0());
423  else if (prop == "C")
424  plhs[0] = MxArray(obj->getC());
425  else if (prop == "Nu")
426  plhs[0] = MxArray(obj->getNu());
427  else if (prop == "P")
428  plhs[0] = MxArray(obj->getP());
429  else if (prop == "ClassWeights")
430  plhs[0] = MxArray(obj->getClassWeights());
431  else if (prop == "TermCriteria")
432  plhs[0] = MxArray(obj->getTermCriteria());
433  else
434  mexErrMsgIdAndTxt("mexopencv:error",
435  "Unrecognized property %s", prop.c_str());
436  }
437  else if (method == "set") {
438  nargchk(nrhs==4 && nlhs==0);
439  string prop(rhs[2].toString());
440  if (prop == "Type")
441  obj->setType(SVMType[rhs[3].toString()]);
442  else if (prop == "KernelType")
443  obj->setKernel(SVMKernelType[rhs[3].toString()]);
444  else if (prop == "Degree")
445  obj->setDegree(rhs[3].toDouble());
446  else if (prop == "Gamma")
447  obj->setGamma(rhs[3].toDouble());
448  else if (prop == "Coef0")
449  obj->setCoef0(rhs[3].toDouble());
450  else if (prop == "C")
451  obj->setC(rhs[3].toDouble());
452  else if (prop == "Nu")
453  obj->setNu(rhs[3].toDouble());
454  else if (prop == "P")
455  obj->setP(rhs[3].toDouble());
456  else if (prop == "ClassWeights")
457  obj->setClassWeights(rhs[3].toMat());
458  else if (prop == "TermCriteria")
459  obj->setTermCriteria(rhs[3].toTermCriteria());
460  else
461  mexErrMsgIdAndTxt("mexopencv:error",
462  "Unrecognized property %s", prop.c_str());
463  }
464  else
465  mexErrMsgIdAndTxt("mexopencv:error","Unrecognized operation");
466 }
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:159
bool isStruct() const
Determine whether input is structure array.
Definition: MxArray.hpp:708
bool isChar() const
Determine whether input is string array.
Definition: MxArray.hpp:614
int toInt() const
Convert MxArray to int.
Definition: MxArray.cpp:489
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:166
bool toBool() const
Convert MxArray to bool.
Definition: MxArray.cpp:510
bool isField(const std::string &fieldName) const
Determine whether a struct array has a specified field.
Definition: MxArray.hpp:743
Global constant definitions.
mwSize numel() const
Number of elements in an array.
Definition: MxArray.hpp:546
T at(mwIndex index) const
Template for numeric array element accessor.
Definition: MxArray.hpp:1250
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
std::string toString() const
Convert MxArray to std::string.
Definition: MxArray.cpp:517
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: SVM_.cpp:182
Common definitions for the ml module.
bool isNumeric() const
Determine whether array is numeric.
Definition: MxArray.hpp:695