// Registry of ANN_MLP instances, keyed by integer object id. New entries are
// created by the "new" method of mexFunction (obj_[++last_id] = ANN_MLP::create()).
// NOTE(review): this listing looks extracted from a rendered source view; the
// leading integers on each line (e.g. "19") appear to be original line numbers,
// not code — confirm against the real ANN_MLP_.cpp.
19 map<int,Ptr<ANN_MLP> > obj_;
// Training-method option name -> cv::ml::ANN_MLP constant.
// Presumably the initializer of ANN_MLPTrain, which is indexed by a string in
// "setTrainMethod" and set("TrainMethod", ...); the declaration line is not
// visible in this listing.
23 (
"Backprop", cv::ml::ANN_MLP::BACKPROP)
24 (
"RProp", cv::ml::ANN_MLP::RPROP);
// Inverse mapping: cv::ml::ANN_MLP training-method constant -> option name.
// Presumably InvANN_MLPTrain, which get("TrainMethod") indexes with
// obj->getTrainMethod().
28 (cv::ml::ANN_MLP::BACKPROP,
"Backprop")
29 (cv::ml::ANN_MLP::RPROP,
"RProp");
// Activation-function option name -> cv::ml::ANN_MLP constant.
// Presumably the initializer of ActivateFunc, which is indexed by a string in
// "setActivationFunction" and set("ActivationFunction", ...).
// Note that the option "Sigmoid" maps to the symmetric sigmoid SIGMOID_SYM.
33 (
"Identity", cv::ml::ANN_MLP::IDENTITY)
34 (
"Sigmoid", cv::ml::ANN_MLP::SIGMOID_SYM)
35 (
"Gaussian", cv::ml::ANN_MLP::GAUSSIAN);
// Inverse mapping: activation-function constant -> option name.
// No lookup of this inverse map is visible in this listing.
39 (cv::ml::ANN_MLP::IDENTITY,
"Identity")
40 (cv::ml::ANN_MLP::SIGMOID_SYM,
"Sigmoid")
41 (cv::ml::ANN_MLP::GAUSSIAN,
"Gaussian");
/**
 * Main entry point called from MATLAB.
 *
 * Dispatches on a method name to operate on the ANN_MLP instance stored
 * in obj_ under the given id.
 *
 * @param nlhs number of left-hand-side (output) arguments.
 * @param plhs output arguments.
 * @param nrhs number of right-hand-side (input) arguments.
 * @param prhs input arguments: prhs[0] is the object id, prhs[1] is the
 *             method name, and the remaining entries are method-specific.
 *
 * NOTE(review): this listing is incomplete. Many original lines are missing
 * and the original line numbers are fused into the text, so these comments
 * describe only the visible fragments.
 */
51 void mexFunction(
int nlhs, mxArray *plhs[],
int nrhs,
const mxArray *prhs[])
// Wrap the raw inputs, then read the object id and the method name.
// NOTE(review): rhs[0] and rhs[1] are indexed unconditionally; an argument-count
// check is not visible here. Confirm that nrhs>=2 is enforced before these lines.
57 vector<MxArray> rhs(prhs, prhs+nrhs);
58 int id = rhs[0].toInt();
59 string method(rhs[1].toString());
// "new": create an ANN_MLP and store it under the next id.
62 if (method ==
"new") {
64 obj_[++last_id] = ANN_MLP::create();
// All other methods operate on the instance registered under id.
// NOTE(review): map::operator[] default-inserts an empty Ptr when id is not
// registered. No emptiness check is visible before obj-> is dereferenced below.
// Consider obj_.find(id) and reporting "mexopencv:error" for an invalid id.
70 Ptr<ANN_MLP> obj = obj_[id];
71 if (method ==
"delete") {
75 else if (method ==
"clear") {
// "load": rhs[2] is the source, either a filename or, with FromString=true, a
// serialized model string. The options are key/value pairs starting at rhs[3],
// and one option (its key is not visible here) sets objname.
79 else if (method ==
"load") {
80 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
82 bool loadFromString =
false;
83 for (
int i=3; i<nrhs; i+=2) {
84 string key(rhs[i].toString());
86 objname = rhs[i+1].toString();
87 else if (key ==
"FromString")
88 loadFromString = rhs[i+1].toBool();
90 mexErrMsgIdAndTxt(
"mexopencv:error",
91 "Unrecognized option %s", key.c_str());
// Replace the stored instance with one deserialized from a string or a file.
93 obj_[id] = (loadFromString ?
94 Algorithm::loadFromString<ANN_MLP>(rhs[2].toString(), objname) :
95 Algorithm::load<ANN_MLP>(rhs[2].toString(), objname));
// "save": rhs[2] is the target name. On the visible path, the model is written
// to an in-memory FileStorage and the serialized text is returned in plhs[0].
// The condition that selects this path is not visible in this listing.
97 else if (method ==
"save") {
99 string fname(rhs[2].toString());
102 FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
103 fs << obj->getDefaultName() <<
"{";
107 plhs[0] =
MxArray(fs.releaseAndGetString());
// Simple queries: each returns the corresponding ANN_MLP result in plhs[0].
113 else if (method ==
"empty") {
115 plhs[0] =
MxArray(obj->empty());
117 else if (method ==
"getDefaultName") {
119 plhs[0] =
MxArray(obj->getDefaultName());
121 else if (method ==
"getVarCount") {
123 plhs[0] =
MxArray(obj->getVarCount());
125 else if (method ==
"isClassifier") {
127 plhs[0] =
MxArray(obj->isClassifier());
129 else if (method ==
"isTrained") {
131 plhs[0] =
MxArray(obj->isTrained());
// "train": the options are key/value pairs starting at rhs[4]. "Flags" sets the
// flag word outright. UpdateWeights, NoInputScale and NoOutputScale each set or
// clear a single ANN_MLP flag bit via UPDATE_FLAG.
133 else if (method ==
"train") {
134 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
135 vector<MxArray> dataOptions;
137 for (
int i=4; i<nrhs; i+=2) {
138 string key(rhs[i].toString());
140 dataOptions = rhs[i+1].toVector<
MxArray>();
141 else if (key ==
"Flags")
142 flags = rhs[i+1].
toInt();
143 else if (key==
"UpdateWeights")
144 UPDATE_FLAG(flags, rhs[i+1].toBool(), ANN_MLP::UPDATE_WEIGHTS);
145 else if (key==
"NoInputScale")
146 UPDATE_FLAG(flags, rhs[i+1].toBool(), ANN_MLP::NO_INPUT_SCALE);
147 else if (key==
"NoOutputScale")
148 UPDATE_FLAG(flags, rhs[i+1].toBool(), ANN_MLP::NO_OUTPUT_SCALE);
150 mexErrMsgIdAndTxt(
"mexopencv:error",
151 "Unrecognized option %s", key.c_str());
// Build the training data. One visible path passes dataOptions, presumably to
// loadTrainData; the other converts rhs[2]/rhs[3] to CV_32F matrices, presumably
// for createTrainData. The selecting condition is not visible.
156 dataOptions.begin(), dataOptions.end());
159 rhs[2].toMat(CV_32F),
160 rhs[3].toMat(CV_32F),
161 dataOptions.begin(), dataOptions.end());
162 bool b = obj->train(data, flags);
// "calcError": the data options mirror "train". The body of the "TestError"
// option is not visible here. The responses matrix is requested only when the
// caller asks for a second output (nlhs>1).
165 else if (method ==
"calcError") {
166 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
167 vector<MxArray> dataOptions;
169 for (
int i=4; i<nrhs; i+=2) {
170 string key(rhs[i].toString());
172 dataOptions = rhs[i+1].toVector<
MxArray>();
173 else if (key ==
"TestError")
176 mexErrMsgIdAndTxt(
"mexopencv:error",
177 "Unrecognized option %s", key.c_str());
182 dataOptions.begin(), dataOptions.end());
185 rhs[2].toMat(CV_32F),
186 rhs[3].toMat(CV_32F),
187 dataOptions.begin(), dataOptions.end());
189 float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
// "predict": rhs[2] holds the samples, converted to CV_32F. One option (its key
// is not visible here) sets the integer predict flags.
194 else if (method ==
"predict") {
195 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
197 for (
int i=3; i<nrhs; i+=2) {
198 string key(rhs[i].toString());
200 flags = rhs[i+1].toInt();
202 mexErrMsgIdAndTxt(
"mexopencv:error",
203 "Unrecognized option %s", key.c_str());
205 Mat samples(rhs[2].toMat(CV_32F)),
207 float f = obj->predict(samples, results, flags);
// "getWeights": rhs[2] is the layer index.
212 else if (method ==
"getWeights") {
214 int layerIdx = rhs[2].toInt();
215 plhs[0] =
MxArray(obj->getWeights(layerIdx));
// "setActivationFunction" / "setTrainMethod": rhs[2] is the option name, looked
// up in ActivateFunc or ANN_MLPTrain. param1 is set by an option whose key is not
// visible here; param2 is set by "Param2".
217 else if (method ==
"setActivationFunction" || method ==
"setTrainMethod") {
218 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
221 for (
int i=3; i<nrhs; i+=2) {
222 string key(rhs[i].toString());
224 param1 = rhs[i+1].toDouble();
225 else if (key==
"Param2")
226 param2 = rhs[i+1].toDouble();
228 mexErrMsgIdAndTxt(
"mexopencv:error",
229 "Unrecognized option %s", key.c_str());
231 if (method ==
"setActivationFunction") {
232 int type = ActivateFunc[rhs[2].toString()];
233 obj->setActivationFunction(type, param1, param2);
236 int tmethod = ANN_MLPTrain[rhs[2].toString()];
237 obj->setTrainMethod(tmethod, param1, param2);
// "get": rhs[2] names the property and its value is returned in plhs[0].
// TrainMethod is returned as its option name via InvANN_MLPTrain.
// NOTE(review): "set" accepts "ActivationFunction", but there is no matching
// property in this branch.
240 else if (method ==
"get") {
242 string prop(rhs[2].toString());
243 if (prop ==
"BackpropMomentumScale")
244 plhs[0] =
MxArray(obj->getBackpropMomentumScale());
245 else if (prop ==
"BackpropWeightScale")
246 plhs[0] =
MxArray(obj->getBackpropWeightScale());
247 else if (prop ==
"LayerSizes")
248 plhs[0] =
MxArray(obj->getLayerSizes());
249 else if (prop ==
"RpropDW0")
250 plhs[0] =
MxArray(obj->getRpropDW0());
251 else if (prop ==
"RpropDWMax")
252 plhs[0] =
MxArray(obj->getRpropDWMax());
253 else if (prop ==
"RpropDWMin")
254 plhs[0] =
MxArray(obj->getRpropDWMin());
255 else if (prop ==
"RpropDWMinus")
256 plhs[0] =
MxArray(obj->getRpropDWMinus());
257 else if (prop ==
"RpropDWPlus")
258 plhs[0] =
MxArray(obj->getRpropDWPlus());
259 else if (prop ==
"TermCriteria")
260 plhs[0] =
MxArray(obj->getTermCriteria());
261 else if (prop ==
"TrainMethod")
262 plhs[0] =
MxArray(InvANN_MLPTrain[obj->getTrainMethod()]);
264 mexErrMsgIdAndTxt(
"mexopencv:error",
265 "Unrecognized property %s", prop.c_str());
// "set": rhs[2] names the property and rhs[3] holds the new value.
// TrainMethod and ActivationFunction use the single-argument setters, so their
// remaining parameters take the OpenCV defaults.
267 else if (method ==
"set") {
269 string prop(rhs[2].toString());
270 if (prop ==
"BackpropMomentumScale")
271 obj->setBackpropMomentumScale(rhs[3].toDouble());
272 else if (prop ==
"BackpropWeightScale")
273 obj->setBackpropWeightScale(rhs[3].toDouble());
274 else if (prop ==
"LayerSizes")
// NOTE(review): no target depth is requested here, unlike the CV_32F conversions
// used for samples. Confirm that ANN_MLP::setLayerSizes accepts the resulting Mat
// type; it likely expects integer layer sizes (e.g. toMat(CV_32S)).
275 obj->setLayerSizes(rhs[3].toMat());
276 else if (prop ==
"RpropDW0")
277 obj->setRpropDW0(rhs[3].toDouble());
278 else if (prop ==
"RpropDWMax")
279 obj->setRpropDWMax(rhs[3].toDouble());
280 else if (prop ==
"RpropDWMin")
281 obj->setRpropDWMin(rhs[3].toDouble());
282 else if (prop ==
"RpropDWMinus")
283 obj->setRpropDWMinus(rhs[3].toDouble());
284 else if (prop ==
"RpropDWPlus")
285 obj->setRpropDWPlus(rhs[3].toDouble());
286 else if (prop ==
"TermCriteria")
287 obj->setTermCriteria(rhs[3].toTermCriteria());
288 else if (prop ==
"TrainMethod")
289 obj->setTrainMethod(ANN_MLPTrain[rhs[3].toString()]);
290 else if (prop ==
"ActivationFunction")
291 obj->setActivationFunction(ActivateFunc[rhs[3].toString()]);
293 mexErrMsgIdAndTxt(
"mexopencv:error",
294 "Unrecognized property %s", prop.c_str());
// Fallback for any method name not matched above.
297 mexErrMsgIdAndTxt(
"mexopencv:error",
"Unrecognized operation");
#define UPDATE_FLAG(NUM, TF, BIT)
Set or clear a bit in a flag value depending on a boolean value.
int toInt() const
Convert MxArray to int.
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry point called from MATLAB.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for checking the number of input/output arguments.
bool toBool() const
Convert MxArray to bool.
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using the options given in the arguments.
std::map wrapper with one-line initialization and lookup method.
Common definitions for the ml module.