mexopencv  0.1
mex interface for opencv library
LogisticRegression_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 #include "mexopencv_ml.hpp"
10 using namespace std;
11 using namespace cv;
12 using namespace cv::ml;
13 
14 // Persistent objects
15 namespace {
17 int last_id = 0;
19 map<int,Ptr<LogisticRegression> > obj_;
20 
22 const ConstMap<string,int> TrainingMethodType = ConstMap<string,int>
23  ("Batch", cv::ml::LogisticRegression::BATCH)
24  ("MiniBatch", cv::ml::LogisticRegression::MINI_BATCH);
25 
27 const ConstMap<int,string> InvTrainingMethodType = ConstMap<int,string>
28  (cv::ml::LogisticRegression::BATCH, "Batch")
29  (cv::ml::LogisticRegression::MINI_BATCH, "MiniBatch");
30 
32 const ConstMap<string,int> RegularizationType = ConstMap<string,int>
33  ("Disable", cv::ml::LogisticRegression::REG_DISABLE) // Regularization disabled
34  ("L1", cv::ml::LogisticRegression::REG_L1) // L1 norm
35  ("L2", cv::ml::LogisticRegression::REG_L2); // L2 norm
36 
38 const ConstMap<int,string> InvRegularizationType = ConstMap<int,string>
39  (cv::ml::LogisticRegression::REG_DISABLE, "Disable")
40  (cv::ml::LogisticRegression::REG_L1, "L1")
41  (cv::ml::LogisticRegression::REG_L2, "L2");
42 }
43 
51 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
52 {
53  // Check the number of arguments
54  nargchk(nrhs>=2 && nlhs<=2);
55 
56  // Argument vector
57  vector<MxArray> rhs(prhs, prhs+nrhs);
58  int id = rhs[0].toInt();
59  string method(rhs[1].toString());
60 
61  // Constructor is called. Create a new object from argument
62  if (method == "new") {
63  nargchk(nrhs==2 && nlhs<=1);
64  obj_[++last_id] = LogisticRegression::create();
65  plhs[0] = MxArray(last_id);
66  return;
67  }
68 
69  // Big operation switch
70  Ptr<LogisticRegression> obj = obj_[id];
71  if (method == "delete") {
72  nargchk(nrhs==2 && nlhs==0);
73  obj_.erase(id);
74  }
75  else if (method == "clear") {
76  nargchk(nrhs==2 && nlhs==0);
77  obj->clear();
78  }
79  else if (method == "load") {
80  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
81  string objname;
82  bool loadFromString = false;
83  for (int i=3; i<nrhs; i+=2) {
84  string key(rhs[i].toString());
85  if (key == "ObjName")
86  objname = rhs[i+1].toString();
87  else if (key == "FromString")
88  loadFromString = rhs[i+1].toBool();
89  else
90  mexErrMsgIdAndTxt("mexopencv:error",
91  "Unrecognized option %s", key.c_str());
92  }
93  obj_[id] = (loadFromString ?
94  Algorithm::loadFromString<LogisticRegression>(rhs[2].toString(), objname) :
95  Algorithm::load<LogisticRegression>(rhs[2].toString(), objname));
96  }
97  else if (method == "save") {
98  nargchk(nrhs==3 && nlhs<=1);
99  string fname(rhs[2].toString());
100  if (nlhs > 0) {
101  // write to memory, and return string
102  FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
103  fs << obj->getDefaultName() << "{";
104  fs << "format" << 3;
105  obj->write(fs);
106  fs << "}";
107  plhs[0] = MxArray(fs.releaseAndGetString());
108  }
109  else
110  // write to disk
111  obj->save(fname);
112  }
113  else if (method == "empty") {
114  nargchk(nrhs==2 && nlhs<=1);
115  plhs[0] = MxArray(obj->empty());
116  }
117  else if (method == "getDefaultName") {
118  nargchk(nrhs==2 && nlhs<=1);
119  plhs[0] = MxArray(obj->getDefaultName());
120  }
121  else if (method == "getVarCount") {
122  nargchk(nrhs==2 && nlhs<=1);
123  plhs[0] = MxArray(obj->getVarCount());
124  }
125  else if (method == "isClassifier") {
126  nargchk(nrhs==2 && nlhs<=1);
127  plhs[0] = MxArray(obj->isClassifier());
128  }
129  else if (method == "isTrained") {
130  nargchk(nrhs==2 && nlhs<=1);
131  plhs[0] = MxArray(obj->isTrained());
132  }
133  else if (method == "train") {
134  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
135  vector<MxArray> dataOptions;
136  int flags = 0;
137  for (int i=4; i<nrhs; i+=2) {
138  string key(rhs[i].toString());
139  if (key == "Data")
140  dataOptions = rhs[i+1].toVector<MxArray>();
141  else if (key == "Flags")
142  flags = rhs[i+1].toInt();
143  else
144  mexErrMsgIdAndTxt("mexopencv:error",
145  "Unrecognized option %s", key.c_str());
146  }
147  Ptr<TrainData> data;
148  if (rhs[2].isChar())
149  data = loadTrainData(rhs[2].toString(),
150  dataOptions.begin(), dataOptions.end());
151  else
152  data = createTrainData(
153  rhs[2].toMat(CV_32F),
154  rhs[3].toMat(CV_32F),
155  dataOptions.begin(), dataOptions.end());
156  bool b = obj->train(data, flags);
157  plhs[0] = MxArray(b);
158  }
159  else if (method == "calcError") {
160  nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
161  vector<MxArray> dataOptions;
162  bool test = false;
163  for (int i=4; i<nrhs; i+=2) {
164  string key(rhs[i].toString());
165  if (key == "Data")
166  dataOptions = rhs[i+1].toVector<MxArray>();
167  else if (key == "TestError")
168  test = rhs[i+1].toBool();
169  else
170  mexErrMsgIdAndTxt("mexopencv:error",
171  "Unrecognized option %s", key.c_str());
172  }
173  Ptr<TrainData> data;
174  if (rhs[2].isChar())
175  data = loadTrainData(rhs[2].toString(),
176  dataOptions.begin(), dataOptions.end());
177  else
178  data = createTrainData(
179  rhs[2].toMat(CV_32F),
180  rhs[3].toMat(CV_32F),
181  dataOptions.begin(), dataOptions.end());
182  Mat resp;
183  float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
184  plhs[0] = MxArray(err);
185  if (nlhs>1)
186  plhs[1] = MxArray(resp);
187  }
188  else if (method == "predict") {
189  nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
190  int flags = 0;
191  for (int i=3; i<nrhs; i+=2) {
192  string key(rhs[i].toString());
193  if (key == "Flags")
194  flags = rhs[i+1].toInt();
195  else if (key == "RawOutput")
196  UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
197  else
198  mexErrMsgIdAndTxt("mexopencv:error",
199  "Unrecognized option %s", key.c_str());
200  }
201  Mat samples(rhs[2].toMat(CV_32F)),
202  results;
203  float f = obj->predict(samples, results, flags);
204  plhs[0] = MxArray(results);
205  if (nlhs>1)
206  plhs[1] = MxArray(f);
207  }
208  else if (method == "get_learnt_thetas") {
209  nargchk(nrhs==2 && nlhs<=1);
210  plhs[0] = MxArray(obj->get_learnt_thetas());
211  }
212  else if (method == "get") {
213  nargchk(nrhs==3 && nlhs<=1);
214  string prop(rhs[2].toString());
215  if (prop == "Iterations")
216  plhs[0] = MxArray(obj->getIterations());
217  else if (prop == "LearningRate")
218  plhs[0] = MxArray(obj->getLearningRate());
219  else if (prop == "MiniBatchSize")
220  plhs[0] = MxArray(obj->getMiniBatchSize());
221  else if (prop == "Regularization")
222  plhs[0] = MxArray(InvRegularizationType[obj->getRegularization()]);
223  else if (prop == "TermCriteria")
224  plhs[0] = MxArray(obj->getTermCriteria());
225  else if (prop == "TrainMethod")
226  plhs[0] = MxArray(InvTrainingMethodType[obj->getTrainMethod()]);
227  else
228  mexErrMsgIdAndTxt("mexopencv:error",
229  "Unrecognized property %s", prop.c_str());
230  }
231  else if (method == "set") {
232  nargchk(nrhs==4 && nlhs==0);
233  string prop(rhs[2].toString());
234  if (prop == "Iterations")
235  obj->setIterations(rhs[3].toInt());
236  else if (prop == "LearningRate")
237  obj->setLearningRate(rhs[3].toDouble());
238  else if (prop == "MiniBatchSize")
239  obj->setMiniBatchSize(rhs[3].toInt());
240  else if (prop == "Regularization")
241  obj->setRegularization(RegularizationType[rhs[3].toString()]);
242  else if (prop == "TermCriteria")
243  obj->setTermCriteria(rhs[3].toTermCriteria());
244  else if (prop == "TrainMethod")
245  obj->setTrainMethod(TrainingMethodType[rhs[3].toString()]);
246  else
247  mexErrMsgIdAndTxt("mexopencv:error",
248  "Unrecognized property %s", prop.c_str());
249  }
250  else
251  mexErrMsgIdAndTxt("mexopencv:error","Unrecognized operation");
252 }
#define UPDATE_FLAG(NUM, TF, BIT)
set or clear a bit in flag depending on bool value
Definition: mexopencv.hpp:159
int toInt() const
Convert MxArray to int.
Definition: MxArray.cpp:489
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:166
bool toBool() const
Convert MxArray to bool.
Definition: MxArray.cpp:510
Global constant definitions.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
Common definitions for the ml module.