mexopencv  0.1
MEX interface for the OpenCV library
ANN_MLP_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 using namespace std;
11 using namespace cv;
12 using namespace cv::ml;
13 
14 // Persistent objects
15 namespace {
17 int last_id = 0;
19 map<int,Ptr<ANN_MLP> > obj_;
20 
22 const ConstMap<string,int> ANN_MLPTrain = ConstMap<string,int>
23  ("Backprop", cv::ml::ANN_MLP::BACKPROP)
24  ("RProp", cv::ml::ANN_MLP::RPROP);
25 
27 const ConstMap<int,string> InvANN_MLPTrain = ConstMap<int,string>
28  (cv::ml::ANN_MLP::BACKPROP, "Backprop")
29  (cv::ml::ANN_MLP::RPROP, "RProp");
30 
32 const ConstMap<string,int> ActivateFunc = ConstMap<string,int>
33  ("Identity", cv::ml::ANN_MLP::IDENTITY)
34  ("Sigmoid", cv::ml::ANN_MLP::SIGMOID_SYM)
35  ("Gaussian", cv::ml::ANN_MLP::GAUSSIAN);
36 
38 const ConstMap<int,string> InvActivateFunc = ConstMap<int,string>
39  (cv::ml::ANN_MLP::IDENTITY, "Identity")
40  (cv::ml::ANN_MLP::SIGMOID_SYM, "Sigmoid")
41  (cv::ml::ANN_MLP::GAUSSIAN, "Gaussian");
42 }
43 
51 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
52 {
53  // Check the number of arguments
54  nargchk(nrhs>=2 && nlhs<=2);
55 
56  // Argument vector
57  vector<MxArray> rhs(prhs, prhs+nrhs);
58  int id = rhs[0].toInt();
59  string method(rhs[1].toString());
60 
61  // Constructor is called. Create a new object from argument
62  if (method == "new") {
63  nargchk(nrhs==2 && nlhs<=1);
64  obj_[++last_id] = ANN_MLP::create();
65  plhs[0] = MxArray(last_id);
66  return;
67  }
68 
69  // Big operation switch
70  Ptr<ANN_MLP> obj = obj_[id];
71  if (method == "delete") {
72  nargchk(nrhs==2 && nlhs==0);
73  obj_.erase(id);
74  }
75  else if (method == "clear") {
76  nargchk(nrhs==2 && nlhs==0);
77  obj->clear();
78  }
79  else if (method == "load") {
80  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
81  string objname;
82  bool loadFromString = false;
83  for (int i=3; i<nrhs; i+=2) {
84  string key(rhs[i].toString());
85  if (key == "ObjName")
86  objname = rhs[i+1].toString();
87  else if (key == "FromString")
88  loadFromString = rhs[i+1].toBool();
89  else
90  mexErrMsgIdAndTxt("mexopencv:error",
91  "Unrecognized option %s", key.c_str());
92  }
93  obj_[id] = (loadFromString ?
94  Algorithm::loadFromString<ANN_MLP>(rhs[2].toString(), objname) :
95  Algorithm::load<ANN_MLP>(rhs[2].toString(), objname));
96  }
97  else if (method == "save") {
98  nargchk(nrhs==3 && nlhs<=1);
99  string fname(rhs[2].toString());
100  if (nlhs > 0) {
101  // write to memory, and return string
102  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
103  fs << obj->getDefaultName() << "{";
104  fs << "format" << 3;
105  obj->write(fs);
106  fs << "}";
107  plhs[0] = MxArray(fs.releaseAndGetString());
108  }
109  else
110  // write to disk
111  obj->save(fname);
112  }
113  else if (method == "empty") {
114  nargchk(nrhs==2 && nlhs<=1);
115  plhs[0] = MxArray(obj->empty());
116  }
117  else if (method == "getDefaultName") {
118  nargchk(nrhs==2 && nlhs<=1);
119  plhs[0] = MxArray(obj->getDefaultName());
120  }
121  else if (method == "getVarCount") {
122  nargchk(nrhs==2 && nlhs<=1);
123  plhs[0] = MxArray(obj->getVarCount());
124  }
125  else if (method == "isClassifier") {
126  nargchk(nrhs==2 && nlhs<=1);
127  plhs[0] = MxArray(obj->isClassifier());
128  }
129  else if (method == "isTrained") {
130  nargchk(nrhs==2 && nlhs<=1);
131  plhs[0] = MxArray(obj->isTrained());
132  }
133  else if (method == "train") {
134  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
135  vector<MxArray> dataOptions;
136  int flags = 0;
137  for (int i=4; i<nrhs; i+=2) {
138  string key(rhs[i].toString());
139  if (key == "Data")
140  dataOptions = rhs[i+1].toVector<MxArray>();
141  else if (key == "Flags")
142  flags = rhs[i+1].toInt();
143  else if (key=="UpdateWeights")
144  UPDATE_FLAG(flags, rhs[i+1].toBool(), ANN_MLP::UPDATE_WEIGHTS);
145  else if (key=="NoInputScale")
146  UPDATE_FLAG(flags, rhs[i+1].toBool(), ANN_MLP::NO_INPUT_SCALE);
147  else if (key=="NoOutputScale")
148  UPDATE_FLAG(flags, rhs[i+1].toBool(), ANN_MLP::NO_OUTPUT_SCALE);
149  else
150  mexErrMsgIdAndTxt("mexopencv:error",
151  "Unrecognized option %s", key.c_str());
152  }
153  Ptr<TrainData> data;
154  if (rhs[2].isChar())
155  data = loadTrainData(rhs[2].toString(),
156  dataOptions.begin(), dataOptions.end());
157  else
158  data = createTrainData(
159  rhs[2].toMat(CV_32F),
160  rhs[3].toMat(CV_32F),
161  dataOptions.begin(), dataOptions.end());
162  bool b = obj->train(data, flags);
163  plhs[0] = MxArray(b);
164  }
165  else if (method == "calcError") {
166  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
167  vector<MxArray> dataOptions;
168  bool test = false;
169  for (int i=4; i<nrhs; i+=2) {
170  string key(rhs[i].toString());
171  if (key == "Data")
172  dataOptions = rhs[i+1].toVector<MxArray>();
173  else if (key == "TestError")
174  test = rhs[i+1].toBool();
175  else
176  mexErrMsgIdAndTxt("mexopencv:error",
177  "Unrecognized option %s", key.c_str());
178  }
179  Ptr<TrainData> data;
180  if (rhs[2].isChar())
181  data = loadTrainData(rhs[2].toString(),
182  dataOptions.begin(), dataOptions.end());
183  else
184  data = createTrainData(
185  rhs[2].toMat(CV_32F),
186  rhs[3].toMat(CV_32F),
187  dataOptions.begin(), dataOptions.end());
188  Mat resp;
189  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
190  plhs[0] = MxArray(err);
191  if (nlhs>1)
192  plhs[1] = MxArray(resp);
193  }
194  else if (method == "predict") {
195  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
196  int flags = 0;
197  for (int i=3; i<nrhs; i+=2) {
198  string key(rhs[i].toString());
199  if (key == "Flags")
200  flags = rhs[i+1].toInt();
201  else
202  mexErrMsgIdAndTxt("mexopencv:error",
203  "Unrecognized option %s", key.c_str());
204  }
205  Mat samples(rhs[2].toMat(CV_32F)),
206  results;
207  float f = obj->predict(samples, results, flags);
208  plhs[0] = MxArray(results);
209  if (nlhs>1)
210  plhs[1] = MxArray(f);
211  }
212  else if (method == "getWeights") {
213  nargchk(nrhs==3 && nlhs<=1);
214  int layerIdx = rhs[2].toInt();
215  plhs[0] = MxArray(obj->getWeights(layerIdx));
216  }
217  else if (method == "setActivationFunction" || method == "setTrainMethod") {
218  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
219  double param1 = 0,
220  param2 = 0;
221  for (int i=3; i<nrhs; i+=2) {
222  string key(rhs[i].toString());
223  if (key=="Param1")
224  param1 = rhs[i+1].toDouble();
225  else if (key=="Param2")
226  param2 = rhs[i+1].toDouble();
227  else
228  mexErrMsgIdAndTxt("mexopencv:error",
229  "Unrecognized option %s", key.c_str());
230  }
231  if (method == "setActivationFunction") {
232  int type = ActivateFunc[rhs[2].toString()];
233  obj->setActivationFunction(type, param1, param2);
234  }
235  else {
236  int tmethod = ANN_MLPTrain[rhs[2].toString()];
237  obj->setTrainMethod(tmethod, param1, param2);
238  }
239  }
240  else if (method == "get") {
241  nargchk(nrhs==3 && nlhs<=1);
242  string prop(rhs[2].toString());
243  if (prop == "BackpropMomentumScale")
244  plhs[0] = MxArray(obj->getBackpropMomentumScale());
245  else if (prop == "BackpropWeightScale")
246  plhs[0] = MxArray(obj->getBackpropWeightScale());
247  else if (prop == "LayerSizes")
248  plhs[0] = MxArray(obj->getLayerSizes());
249  else if (prop == "RpropDW0")
250  plhs[0] = MxArray(obj->getRpropDW0());
251  else if (prop == "RpropDWMax")
252  plhs[0] = MxArray(obj->getRpropDWMax());
253  else if (prop == "RpropDWMin")
254  plhs[0] = MxArray(obj->getRpropDWMin());
255  else if (prop == "RpropDWMinus")
256  plhs[0] = MxArray(obj->getRpropDWMinus());
257  else if (prop == "RpropDWPlus")
258  plhs[0] = MxArray(obj->getRpropDWPlus());
259  else if (prop == "TermCriteria")
260  plhs[0] = MxArray(obj->getTermCriteria());
261  else if (prop == "TrainMethod")
262  plhs[0] = MxArray(InvANN_MLPTrain[obj->getTrainMethod()]);
263  else
264  mexErrMsgIdAndTxt("mexopencv:error",
265  "Unrecognized property %s", prop.c_str());
266  }
267  else if (method == "set") {
268  nargchk(nrhs==4 && nlhs==0);
269  string prop(rhs[2].toString());
270  if (prop == "BackpropMomentumScale")
271  obj->setBackpropMomentumScale(rhs[3].toDouble());
272  else if (prop == "BackpropWeightScale")
273  obj->setBackpropWeightScale(rhs[3].toDouble());
274  else if (prop == "LayerSizes")
275  obj->setLayerSizes(rhs[3].toMat());
276  else if (prop == "RpropDW0")
277  obj->setRpropDW0(rhs[3].toDouble());
278  else if (prop == "RpropDWMax")
279  obj->setRpropDWMax(rhs[3].toDouble());
280  else if (prop == "RpropDWMin")
281  obj->setRpropDWMin(rhs[3].toDouble());
282  else if (prop == "RpropDWMinus")
283  obj->setRpropDWMinus(rhs[3].toDouble());
284  else if (prop == "RpropDWPlus")
285  obj->setRpropDWPlus(rhs[3].toDouble());
286  else if (prop == "TermCriteria")
287  obj->setTermCriteria(rhs[3].toTermCriteria());
288  else if (prop == "TrainMethod")
289  obj->setTrainMethod(ANN_MLPTrain[rhs[3].toString()]);
290  else if (prop == "ActivationFunction")
291  obj->setActivationFunction(ActivateFunc[rhs[3].toString()]);
292  else
293  mexErrMsgIdAndTxt("mexopencv:error",
294  "Unrecognized property %s", prop.c_str());
295  }
296  else
297  mexErrMsgIdAndTxt("mexopencv:error","Unrecognized operation");
298 }
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:159
int toInt() const
Convert MxArray to int.
Definition: MxArray.cpp:489
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Definition: ANN_MLP_.cpp:51
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for checking the number of input/output arguments.
Definition: mexopencv.hpp:166
bool toBool() const
Convert MxArray to bool.
Definition: MxArray.cpp:510
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
Common definitions for the ml module.