// Table of EM instances keyed by integer id: mexFunction stores newly created
// instances here (the "new" method) and looks them up by the id taken from rhs[0].
19 map<int,Ptr<EM> > obj_;
23 (
"Spherical", cv::ml::EM::COV_MAT_SPHERICAL)
24 (
"Diagonal", cv::ml::EM::COV_MAT_DIAGONAL)
25 (
"Generic", cv::ml::EM::COV_MAT_GENERIC)
26 (
"Default", cv::ml::EM::COV_MAT_DEFAULT);
30 (cv::ml::EM::COV_MAT_SPHERICAL,
"Spherical")
31 (cv::ml::EM::COV_MAT_DIAGONAL,
"Diagonal")
32 (cv::ml::EM::COV_MAT_GENERIC,
"Generic")
33 (cv::ml::EM::COV_MAT_DEFAULT,
"Default");
43 void mexFunction(
int nlhs, mxArray *plhs[],
int nrhs,
const mxArray *prhs[])
49 vector<MxArray> rhs(prhs, prhs+nrhs);
50 int id = rhs[0].toInt();
51 string method(rhs[1].toString());
54 if (method ==
"new") {
56 obj_[++last_id] = EM::create();
62 Ptr<EM> obj = obj_[id];
63 if (method ==
"delete") {
67 else if (method ==
"clear") {
71 else if (method ==
"load") {
72 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
74 bool loadFromString =
false;
75 for (
int i=3; i<nrhs; i+=2) {
76 string key(rhs[i].toString());
78 objname = rhs[i+1].toString();
79 else if (key ==
"FromString")
80 loadFromString = rhs[i+1].toBool();
82 mexErrMsgIdAndTxt(
"mexopencv:error",
83 "Unrecognized option %s", key.c_str());
85 obj_[id] = (loadFromString ?
86 Algorithm::loadFromString<EM>(rhs[2].toString(), objname) :
87 Algorithm::load<EM>(rhs[2].toString(), objname));
89 else if (method ==
"save") {
91 string fname(rhs[2].toString());
94 FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
95 fs << obj->getDefaultName() <<
"{";
99 plhs[0] =
MxArray(fs.releaseAndGetString());
105 else if (method ==
"empty") {
107 plhs[0] =
MxArray(obj->empty());
109 else if (method ==
"getDefaultName") {
111 plhs[0] =
MxArray(obj->getDefaultName());
113 else if (method ==
"getVarCount") {
115 plhs[0] =
MxArray(obj->getVarCount());
117 else if (method ==
"isClassifier") {
119 plhs[0] =
MxArray(obj->isClassifier());
121 else if (method ==
"isTrained") {
123 plhs[0] =
MxArray(obj->isTrained());
125 else if (method ==
"train") {
126 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=1);
127 vector<MxArray> dataOptions;
129 for (
int i=3; i<nrhs; i+=2) {
130 string key(rhs[i].toString());
132 dataOptions = rhs[i+1].toVector<
MxArray>();
133 else if (key ==
"Flags")
134 flags = rhs[i+1].
toInt();
136 mexErrMsgIdAndTxt(
"mexopencv:error",
137 "Unrecognized option %s", key.c_str());
142 dataOptions.begin(), dataOptions.end());
145 rhs[2].toMat(CV_32F), Mat(),
146 dataOptions.begin(), dataOptions.end());
147 bool b = obj->train(data, flags);
150 else if (method ==
"calcError") {
151 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
152 vector<MxArray> dataOptions;
154 for (
int i=4; i<nrhs; i+=2) {
155 string key(rhs[i].toString());
157 dataOptions = rhs[i+1].toVector<
MxArray>();
158 else if (key ==
"TestError")
161 mexErrMsgIdAndTxt(
"mexopencv:error",
162 "Unrecognized option %s", key.c_str());
167 dataOptions.begin(), dataOptions.end());
170 rhs[2].toMat(CV_32F),
171 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
172 dataOptions.begin(), dataOptions.end());
174 float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
179 else if (method ==
"predict") {
180 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
182 for (
int i=3; i<nrhs; i+=2) {
183 string key(rhs[i].toString());
185 flags = rhs[i+1].toInt();
187 mexErrMsgIdAndTxt(
"mexopencv:error",
188 "Unrecognized option %s", key.c_str());
190 Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
192 float f = obj->predict(samples, results, flags);
197 else if (method ==
"trainEM") {
199 Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
200 logLikelihoods, labels, probs;
201 bool b = obj->trainEM(samples,
202 (nlhs>0 ? logLikelihoods : noArray()),
203 (nlhs>1 ? labels : noArray()),
204 (nlhs>2 ? probs : noArray()));
205 plhs[0] =
MxArray(logLikelihoods);
213 else if (method ==
"trainE") {
214 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=4);
217 for(
int i = 4; i < nrhs; i += 2) {
218 string key(rhs[i].toString());
219 if (key ==
"Covs0") {
222 vector<MxArray> arr(rhs[i+1].toVector<MxArray>());
223 covs0.reserve(arr.size());
224 for (vector<MxArray>::const_iterator it = arr.begin(); it != arr.end(); ++it)
225 covs0.push_back(it->toMat(
226 it->isSingle() ? CV_32F : CV_64F));
228 else if (key ==
"Weights0")
229 weights0 = rhs[i+1].toMat(
230 rhs[i+1].isSingle() ? CV_32F : CV_64F);
232 mexErrMsgIdAndTxt(
"mexopencv:error",
233 "Unrecognized option %s", key.c_str());
235 Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
236 means0(rhs[3].toMat(rhs[3].isSingle() ? CV_32F : CV_64F)),
237 logLikelihoods, labels, probs;
238 bool b = obj->trainE(samples, means0, covs0, weights0,
239 (nlhs>0 ? logLikelihoods : noArray()),
240 (nlhs>1 ? labels : noArray()),
241 (nlhs>2 ? probs : noArray()));
242 plhs[0] =
MxArray(logLikelihoods);
250 else if (method ==
"trainM") {
252 Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
253 probs0(rhs[3].toMat(rhs[3].isSingle() ? CV_32F : CV_64F)),
254 logLikelihoods, labels, probs;
255 bool b = obj->trainM(samples, probs0,
256 (nlhs>0 ? logLikelihoods : noArray()),
257 (nlhs>1 ? labels : noArray()),
258 (nlhs>2 ? probs : noArray()));
259 plhs[0] =
MxArray(logLikelihoods);
267 else if (method ==
"predict2") {
269 Mat samples(rhs[2].toMat(rhs[2].isSingle() ? CV_32F : CV_64F)),
271 if (samples.rows == 1 || samples.cols == 1)
272 samples = samples.reshape(1,1);
274 probs.create(samples.rows, obj->getClustersNumber(), CV_64F);
275 vector<Vec2d> results;
276 results.reserve(samples.rows);
277 for (
size_t i = 0; i < samples.rows; ++i) {
278 Vec2d res = obj->predict2(samples.row(i),
279 (nlhs>1 ? probs.row(i) : noArray()));
280 results.push_back(res);
282 plhs[0] =
MxArray(Mat(results,
false).reshape(1,0));
286 else if (method ==
"getCovs") {
292 else if (method ==
"getMeans") {
294 plhs[0] =
MxArray(obj->getMeans());
296 else if (method ==
"getWeights") {
298 plhs[0] =
MxArray(obj->getWeights());
300 else if (method ==
"get") {
302 string prop(rhs[2].toString());
303 if (prop ==
"ClustersNumber")
304 plhs[0] =
MxArray(obj->getClustersNumber());
305 else if (prop ==
"CovarianceMatrixType")
306 plhs[0] =
MxArray(CovMatTypeInv[obj->getCovarianceMatrixType()]);
307 else if (prop ==
"TermCriteria")
308 plhs[0] =
MxArray(obj->getTermCriteria());
310 mexErrMsgIdAndTxt(
"mexopencv:error",
311 "Unrecognized property %s", prop.c_str());
313 else if (method ==
"set") {
315 string prop(rhs[2].toString());
316 if (prop ==
"ClustersNumber")
317 obj->setClustersNumber(rhs[3].toInt());
318 else if (prop ==
"CovarianceMatrixType")
319 obj->setCovarianceMatrixType(CovMatType[rhs[3].toString()]);
320 else if (prop ==
"TermCriteria")
321 obj->setTermCriteria(rhs[3].toTermCriteria());
323 mexErrMsgIdAndTxt(
"mexopencv:error",
324 "Unrecognized property %s", prop.c_str());
327 mexErrMsgIdAndTxt(
"mexopencv:error",
"Unrecognized operation");
int toInt() const
Convert MxArray to int.
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for input/output arguments number check.
bool toBool() const
Convert MxArray to bool.
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Common definitions for the ml module.