mexopencv  0.1
mex interface for opencv library
KNearest_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 using namespace std;
11 using namespace cv;
12 using namespace cv::ml;
13 
14 // Persistent objects
15 namespace {
17 int last_id = 0;
19 map<int,Ptr<KNearest> > obj_;
20 
23  ("BruteForce", KNearest::BRUTE_FORCE)
24  ("KDTree", KNearest::KDTREE);
25 
28  (KNearest::BRUTE_FORCE, "BruteForce")
29  (KNearest::KDTREE, "KDTree");
30 }
31 
39 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
40 {
41  // Check the number of arguments
42  nargchk(nrhs>=2 && nlhs<=4);
43 
44  // Argument vector
45  vector<MxArray> rhs(prhs, prhs+nrhs);
46  int id = rhs[0].toInt();
47  string method(rhs[1].toString());
48 
49  // Constructor is called. Create a new object from argument
50  if (method == "new") {
51  nargchk(nrhs==2 && nlhs<=1);
52  obj_[++last_id] = KNearest::create();
53  plhs[0] = MxArray(last_id);
54  return;
55  }
56 
57  // Big operation switch
58  Ptr<KNearest> obj = obj_[id];
59  if (method == "delete") {
60  nargchk(nrhs==2 && nlhs==0);
61  obj_.erase(id);
62  }
63  else if (method == "clear") {
64  nargchk(nrhs==2 && nlhs==0);
65  obj->clear();
66  }
67  else if (method == "load") {
68  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
69  string objname;
70  bool loadFromString = false;
71  for (int i=3; i<nrhs; i+=2) {
72  string key(rhs[i].toString());
73  if (key == "ObjName")
74  objname = rhs[i+1].toString();
75  else if (key == "FromString")
76  loadFromString = rhs[i+1].toBool();
77  else
78  mexErrMsgIdAndTxt("mexopencv:error",
79  "Unrecognized option %s", key.c_str());
80  }
81  obj_[id] = (loadFromString ?
82  Algorithm::loadFromString<KNearest>(rhs[2].toString(), objname) :
83  Algorithm::load<KNearest>(rhs[2].toString(), objname));
84  }
85  else if (method == "save") {
86  nargchk(nrhs==3 && nlhs<=1);
87  string fname(rhs[2].toString());
88  if (nlhs > 0) {
89  // write to memory, and return string
90  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
91  fs << obj->getDefaultName() << "{";
92  fs << "format" << 3;
93  obj->write(fs);
94  fs << "}";
95  plhs[0] = MxArray(fs.releaseAndGetString());
96  }
97  else
98  // write to disk
99  obj->save(fname);
100  }
101  else if (method == "empty") {
102  nargchk(nrhs==2 && nlhs<=1);
103  plhs[0] = MxArray(obj->empty());
104  }
105  else if (method == "getDefaultName") {
106  nargchk(nrhs==2 && nlhs<=1);
107  plhs[0] = MxArray(obj->getDefaultName());
108  }
109  else if (method == "getVarCount") {
110  nargchk(nrhs==2 && nlhs<=1);
111  plhs[0] = MxArray(obj->getVarCount());
112  }
113  else if (method == "isClassifier") {
114  nargchk(nrhs==2 && nlhs<=1);
115  plhs[0] = MxArray(obj->isClassifier());
116  }
117  else if (method == "isTrained") {
118  nargchk(nrhs==2 && nlhs<=1);
119  plhs[0] = MxArray(obj->isTrained());
120  }
121  else if (method == "train") {
122  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
123  vector<MxArray> dataOptions;
124  int flags = 0;
125  for (int i=4; i<nrhs; i+=2) {
126  string key(rhs[i].toString());
127  if (key == "Data")
128  dataOptions = rhs[i+1].toVector<MxArray>();
129  else if (key == "Flags")
130  flags = rhs[i+1].toInt();
131  else if (key == "UpdateModel")
132  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::UPDATE_MODEL);
133  else
134  mexErrMsgIdAndTxt("mexopencv:error",
135  "Unrecognized option %s", key.c_str());
136  }
137  Ptr<TrainData> data;
138  if (rhs[2].isChar())
139  data = loadTrainData(rhs[2].toString(),
140  dataOptions.begin(), dataOptions.end());
141  else
142  data = createTrainData(
143  rhs[2].toMat(CV_32F),
144  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
145  dataOptions.begin(), dataOptions.end());
146  bool b = obj->train(data, flags);
147  plhs[0] = MxArray(b);
148  }
149  else if (method == "calcError") {
150  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
151  vector<MxArray> dataOptions;
152  bool test = false;
153  for (int i=4; i<nrhs; i+=2) {
154  string key(rhs[i].toString());
155  if (key == "Data")
156  dataOptions = rhs[i+1].toVector<MxArray>();
157  else if (key == "TestError")
158  test = rhs[i+1].toBool();
159  else
160  mexErrMsgIdAndTxt("mexopencv:error",
161  "Unrecognized option %s", key.c_str());
162  }
163  Ptr<TrainData> data;
164  if (rhs[2].isChar())
165  data = loadTrainData(rhs[2].toString(),
166  dataOptions.begin(), dataOptions.end());
167  else
168  data = createTrainData(
169  rhs[2].toMat(CV_32F),
170  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
171  dataOptions.begin(), dataOptions.end());
172  Mat resp;
173  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
174  plhs[0] = MxArray(err);
175  if (nlhs>1)
176  plhs[1] = MxArray(resp);
177  }
178  else if (method == "predict") {
179  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
180  int flags = 0;
181  for (int i=3; i<nrhs; i+=2) {
182  string key(rhs[i].toString());
183  if (key == "Flags")
184  flags = rhs[i+1].toInt();
185  else
186  mexErrMsgIdAndTxt("mexopencv:error",
187  "Unrecognized option %s", key.c_str());
188  }
189  Mat samples(rhs[2].toMat(CV_32F)),
190  results;
191  float f = obj->predict(samples, results, flags);
192  plhs[0] = MxArray(results);
193  if (nlhs>1)
194  plhs[1] = MxArray(f);
195  }
196  else if (method == "findNearest") {
197  nargchk(nrhs==4 && nlhs<=4);
198  Mat samples(rhs[2].toMat(CV_32F));
199  int k = rhs[3].toInt();
200  Mat results, neighborResponses, dist;
201  float f = obj->findNearest(samples, k, results, neighborResponses, dist);
202  plhs[0] = MxArray(results);
203  if (nlhs>1)
204  plhs[1] = MxArray(neighborResponses);
205  if (nlhs>2)
206  plhs[2] = MxArray(dist);
207  if (nlhs>3)
208  plhs[3] = MxArray(f);
209  }
210  else if (method == "get") {
211  nargchk(nrhs==3 && nlhs<=1);
212  string prop(rhs[2].toString());
213  if (prop == "AlgorithmType")
214  plhs[0] = MxArray(InvKNNAlgType[obj->getAlgorithmType()]);
215  else if (prop == "DefaultK")
216  plhs[0] = MxArray(obj->getDefaultK());
217  else if (prop == "Emax")
218  plhs[0] = MxArray(obj->getEmax());
219  else if (prop == "IsClassifier")
220  plhs[0] = MxArray(obj->getIsClassifier());
221  else
222  mexErrMsgIdAndTxt("mexopencv:error",
223  "Unrecognized property %s", prop.c_str());
224  }
225  else if (method == "set") {
226  nargchk(nrhs==4 && nlhs==0);
227  string prop(rhs[2].toString());
228  if (prop == "AlgorithmType")
229  obj->setAlgorithmType(KNNAlgType[rhs[3].toString()]);
230  else if (prop == "DefaultK")
231  obj->setDefaultK(rhs[3].toInt());
232  else if (prop == "Emax")
233  obj->setEmax(rhs[3].toInt());
234  else if (prop == "IsClassifier")
235  obj->setIsClassifier(rhs[3].toBool());
236  else
237  mexErrMsgIdAndTxt("mexopencv:error",
238  "Unrecognized property %s", prop.c_str());
239  }
240  else
241  mexErrMsgIdAndTxt("mexopencv:error","Unrecognized operation");
242 }
#define UPDATE_FLAG(NUM, TF, BIT)
Set or clear a bit in a flag, depending on a bool value.
Definition: mexopencv.hpp:159
int toInt() const
Convert MxArray to int.
Definition: MxArray.cpp:489
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for checking the number of input/output arguments.
Definition: mexopencv.hpp:166
bool toBool() const
Convert MxArray to bool.
Definition: MxArray.cpp:510
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: KNearest_.cpp:39
Common definitions for the ml module.