mexopencv  0.1
mex interface for opencv library
EM_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 using namespace std;
11 using namespace cv;
12 using namespace cv::ml;
13 
14 // Persistent objects
15 namespace {
17 int last_id = 0;
19 map<int,Ptr<EM> > obj_;
20 
23  ("Spherical", cv::ml::EM::COV_MAT_SPHERICAL)
24  ("Diagonal", cv::ml::EM::COV_MAT_DIAGONAL)
25  ("Generic", cv::ml::EM::COV_MAT_GENERIC)
26  ("Default", cv::ml::EM::COV_MAT_DEFAULT);
27 
29 const ConstMap<int, string> CovMatTypeInv = ConstMap<int, string>
30  (cv::ml::EM::COV_MAT_SPHERICAL, "Spherical")
31  (cv::ml::EM::COV_MAT_DIAGONAL, "Diagonal")
32  (cv::ml::EM::COV_MAT_GENERIC, "Generic")
33  (cv::ml::EM::COV_MAT_DEFAULT, "Default");
34 }
35 
43 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
44 {
45  // Check the number of arguments
46  nargchk(nrhs>=2 && nlhs<=4);
47 
48  // Argument vector
49  vector<MxArray> rhs(prhs, prhs+nrhs);
50  int id = rhs[0].toInt();
51  string method(rhs[1].toString());
52 
53  // Constructor is called. Create a new object from argument
54  if (method == "new") {
55  nargchk(nrhs==2 && nlhs<=1);
56  obj_[++last_id] = EM::create();
57  plhs[0] = MxArray(last_id);
58  return;
59  }
60 
61  // Big operation switch
62  Ptr<EM> obj = obj_[id];
63  if (method == "delete") {
64  nargchk(nrhs==2 && nlhs==0);
65  obj_.erase(id);
66  }
67  else if (method == "clear") {
68  nargchk(nrhs==2 && nlhs==0);
69  obj->clear();
70  }
71  else if (method == "load") {
72  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
73  string objname;
74  bool loadFromString = false;
75  for (int i=3; i<nrhs; i+=2) {
76  string key(rhs[i].toString());
77  if (key == "ObjName")
78  objname = rhs[i+1].toString();
79  else if (key == "FromString")
80  loadFromString = rhs[i+1].toBool();
81  else
82  mexErrMsgIdAndTxt("mexopencv:error",
83  "Unrecognized option %s", key.c_str());
84  }
85  obj_[id] = (loadFromString ?
86  Algorithm::loadFromString<EM>(rhs[2].toString(), objname) :
87  Algorithm::load<EM>(rhs[2].toString(), objname));
88  }
89  else if (method == "save") {
90  nargchk(nrhs==3 && nlhs<=1);
91  string fname(rhs[2].toString());
92  if (nlhs > 0) {
93  // write to memory, and return string
94  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
95  fs << obj->getDefaultName() << "{";
96  fs << "format" << 3;
97  obj->write(fs);
98  fs << "}";
99  plhs[0] = MxArray(fs.releaseAndGetString());
100  }
101  else
102  // write to disk
103  obj->save(fname);
104  }
105  else if (method == "empty") {
106  nargchk(nrhs==2 && nlhs<=1);
107  plhs[0] = MxArray(obj->empty());
108  }
109  else if (method == "getDefaultName") {
110  nargchk(nrhs==2 && nlhs<=1);
111  plhs[0] = MxArray(obj->getDefaultName());
112  }
113  else if (method == "getVarCount") {
114  nargchk(nrhs==2 && nlhs<=1);
115  plhs[0] = MxArray(obj->getVarCount());
116  }
117  else if (method == "isClassifier") {
118  nargchk(nrhs==2 && nlhs<=1);
119  plhs[0] = MxArray(obj->isClassifier());
120  }
121  else if (method == "isTrained") {
122  nargchk(nrhs==2 && nlhs<=1);
123  plhs[0] = MxArray(obj->isTrained());
124  }
125  else if (method == "train") {
126  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=1);
127  vector<MxArray> dataOptions;
128  int flags = 0;
129  for (int i=3; i<nrhs; i+=2) {
130  string key(rhs[i].toString());
131  if (key == "Data")
132  dataOptions = rhs[i+1].toVector<MxArray>();
133  else if (key == "Flags")
134  flags = rhs[i+1].toInt();
135  else
136  mexErrMsgIdAndTxt("mexopencv:error",
137  "Unrecognized option %s", key.c_str());
138  }
139  Ptr<TrainData> data;
140  if (rhs[2].isChar())
141  data = loadTrainData(rhs[2].toString(),
142  dataOptions.begin(), dataOptions.end());
143  else
144  data = createTrainData(
145  rhs[2].toMat(CV_32F), Mat(),
146  dataOptions.begin(), dataOptions.end());
147  bool b = obj->train(data, flags);
148  plhs[0] = MxArray(b);
149  }
150  else if (method == "calcError") {
151  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
152  vector<MxArray> dataOptions;
153  bool test = false;
154  for (int i=4; i<nrhs; i+=2) {
155  string key(rhs[i].toString());
156  if (key == "Data")
157  dataOptions = rhs[i+1].toVector<MxArray>();
158  else if (key == "TestError")
159  test = rhs[i+1].toBool();
160  else
161  mexErrMsgIdAndTxt("mexopencv:error",
162  "Unrecognized option %s", key.c_str());
163  }
164  Ptr<TrainData> data;
165  if (rhs[2].isChar())
166  data = loadTrainData(rhs[2].toString(),
167  dataOptions.begin(), dataOptions.end());
168  else
169  data = createTrainData(
170  rhs[2].toMat(CV_32F),
171  rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
172  dataOptions.begin(), dataOptions.end());
173  Mat resp;
174  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
175  plhs[0] = MxArray(err);
176  if (nlhs>1)
177  plhs[1] = MxArray(resp);
178  }
179  else if (method == "predict") {
180  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
181  int flags = 0;
182  for (int i=3; i<nrhs; i+=2) {
183  string key(rhs[i].toString());
184  if (key == "Flags")
185  flags = rhs[i+1].toInt();
186  else
187  mexErrMsgIdAndTxt("mexopencv:error",
188  "Unrecognized option %s", key.c_str());
189  }
190  Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
191  results;
192  float f = obj->predict(samples, results, flags);
193  plhs[0] = MxArray(results);
194  if (nlhs>1)
195  plhs[1] = MxArray(f);
196  }
197  else if (method == "trainEM") {
198  nargchk(nrhs==3 && nlhs<=4);
199  Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
200  logLikelihoods, labels, probs;
201  bool b = obj->trainEM(samples,
202  (nlhs>0 ? logLikelihoods : noArray()),
203  (nlhs>1 ? labels : noArray()),
204  (nlhs>2 ? probs : noArray()));
205  plhs[0] = MxArray(logLikelihoods);
206  if (nlhs > 1)
207  plhs[1] = MxArray(labels);
208  if (nlhs > 2)
209  plhs[2] = MxArray(probs);
210  if (nlhs > 3)
211  plhs[3] = MxArray(b);
212  }
213  else if (method == "trainE") {
214  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=4);
215  vector<Mat> covs0;
216  Mat weights0;
217  for(int i = 4; i < nrhs; i += 2) {
218  string key(rhs[i].toString());
219  if (key == "Covs0") {
220  //covs0 = rhs[i+1].toVector<Mat>();
221  covs0.clear();
222  vector<MxArray> arr(rhs[i+1].toVector<MxArray>());
223  covs0.reserve(arr.size());
224  for (vector<MxArray>::const_iterator it = arr.begin(); it != arr.end(); ++it)
225  covs0.push_back(it->toMat(
226  it->isSingle() ? CV_32F : CV_64F));
227  }
228  else if (key == "Weights0")
229  weights0 = rhs[i+1].toMat(
230  rhs[i+1].isSingle() ? CV_32F : CV_64F);
231  else
232  mexErrMsgIdAndTxt("mexopencv:error",
233  "Unrecognized option %s", key.c_str());
234  }
235  Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
236  means0(rhs[3].toMat(rhs[3].isSingle() ? CV_32F : CV_64F)),
237  logLikelihoods, labels, probs;
238  bool b = obj->trainE(samples, means0, covs0, weights0,
239  (nlhs>0 ? logLikelihoods : noArray()),
240  (nlhs>1 ? labels : noArray()),
241  (nlhs>2 ? probs : noArray()));
242  plhs[0] = MxArray(logLikelihoods);
243  if (nlhs > 1)
244  plhs[1] = MxArray(labels);
245  if (nlhs > 2)
246  plhs[2] = MxArray(probs);
247  if (nlhs > 3)
248  plhs[3] = MxArray(b);
249  }
250  else if (method == "trainM") {
251  nargchk(nrhs==4 && nlhs<=4);
252  Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
253  probs0(rhs[3].toMat(rhs[3].isSingle() ? CV_32F : CV_64F)),
254  logLikelihoods, labels, probs;
255  bool b = obj->trainM(samples, probs0,
256  (nlhs>0 ? logLikelihoods : noArray()),
257  (nlhs>1 ? labels : noArray()),
258  (nlhs>2 ? probs : noArray()));
259  plhs[0] = MxArray(logLikelihoods);
260  if (nlhs > 1)
261  plhs[1] = MxArray(labels);
262  if (nlhs > 2)
263  plhs[2] = MxArray(probs);
264  if (nlhs > 3)
265  plhs[3] = MxArray(b);
266  }
267  else if (method == "predict2") {
268  nargchk(nrhs==3 && nlhs<=3);
269  Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
270  probs;
271  if (samples.rows == 1 || samples.cols == 1)
272  samples = samples.reshape(1,1); // ensure 1xd vector if one sample
273  if (nlhs > 1)
274  probs.create(samples.rows, obj->getClustersNumber(), CV_64F);
275  vector<Vec2d> results;
276  results.reserve(samples.rows);
277  for (size_t i = 0; i < samples.rows; ++i) {
278  Vec2d res = obj->predict2(samples.row(i),
279  (nlhs>1 ? probs.row(i) : noArray()));
280  results.push_back(res);
281  }
282  plhs[0] = MxArray(Mat(results, false).reshape(1,0)); // Nx2
283  if (nlhs > 1)
284  plhs[1] = MxArray(probs); // NxK
285  }
286  else if (method == "getCovs") {
287  nargchk(nrhs==2 && nlhs<=1);
288  vector<Mat> covs;
289  obj->getCovs(covs);
290  plhs[0] = MxArray(covs);
291  }
292  else if (method == "getMeans") {
293  nargchk(nrhs==2 && nlhs<=1);
294  plhs[0] = MxArray(obj->getMeans());
295  }
296  else if (method == "getWeights") {
297  nargchk(nrhs==2 && nlhs<=1);
298  plhs[0] = MxArray(obj->getWeights());
299  }
300  else if (method == "get") {
301  nargchk(nrhs==3 && nlhs<=1);
302  string prop(rhs[2].toString());
303  if (prop == "ClustersNumber")
304  plhs[0] = MxArray(obj->getClustersNumber());
305  else if (prop == "CovarianceMatrixType")
306  plhs[0] = MxArray(CovMatTypeInv[obj->getCovarianceMatrixType()]);
307  else if (prop == "TermCriteria")
308  plhs[0] = MxArray(obj->getTermCriteria());
309  else
310  mexErrMsgIdAndTxt("mexopencv:error",
311  "Unrecognized property %s", prop.c_str());
312  }
313  else if (method == "set") {
314  nargchk(nrhs==4 && nlhs==0);
315  string prop(rhs[2].toString());
316  if (prop == "ClustersNumber")
317  obj->setClustersNumber(rhs[3].toInt());
318  else if (prop == "CovarianceMatrixType")
319  obj->setCovarianceMatrixType(CovMatType[rhs[3].toString()]);
320  else if (prop == "TermCriteria")
321  obj->setTermCriteria(rhs[3].toTermCriteria());
322  else
323  mexErrMsgIdAndTxt("mexopencv:error",
324  "Unrecognized property %s", prop.c_str());
325  }
326  else
327  mexErrMsgIdAndTxt("mexopencv:error","Unrecognized operation");
328 }
int toInt() const
Convert MxArray to int.
Definition: MxArray.cpp:489
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: EM_.cpp:43
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:166
bool toBool() const
Convert MxArray to bool.
Definition: MxArray.cpp:510
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
Common definitions for the ml module.