// Registry of LogisticRegression instances keyed by an integer id.
// In the visible code, mexFunction adds an entry on "new" (key ++last_id),
// overwrites the entry on "load", and reads the entry for rhs[0]'s id
// before dispatching every other operation.
19 map<int,Ptr<LogisticRegression> > obj_;
// Option-name -> cv::ml::LogisticRegression training-method constant pairs.
// NOTE(review): the declaration head is not visible in this listing;
// presumably these entries initialize the TrainingMethodType table that the
// "set" branch indexes with a string -- confirm against the full file.
23 (
"Batch", cv::ml::LogisticRegression::BATCH)
24 (
"MiniBatch", cv::ml::LogisticRegression::MINI_BATCH);
// Reverse pairs: training-method constant -> option name.
// NOTE(review): declaration head not visible; presumably the
// InvTrainingMethodType table used by the "get" branch to turn
// getTrainMethod() back into a string -- confirm.
28 (cv::ml::LogisticRegression::BATCH,
"Batch")
29 (cv::ml::LogisticRegression::MINI_BATCH,
"MiniBatch");
// Option-name -> cv::ml::LogisticRegression regularization constant pairs
// ("Disable", "L1", "L2").
// NOTE(review): declaration head not visible; presumably the
// RegularizationType table indexed by the "set" branch -- confirm.
33 (
"Disable", cv::ml::LogisticRegression::REG_DISABLE)
34 (
"L1", cv::ml::LogisticRegression::REG_L1)
35 (
"L2", cv::ml::LogisticRegression::REG_L2);
// Reverse pairs: regularization constant -> option name.
// NOTE(review): declaration head not visible; presumably the
// InvRegularizationType table used by the "get" branch to turn
// getRegularization() back into a string -- confirm.
39 (cv::ml::LogisticRegression::REG_DISABLE,
"Disable")
40 (cv::ml::LogisticRegression::REG_L1,
"L1")
41 (cv::ml::LogisticRegression::REG_L2,
"L2");
/**
 * MEX entry point: dispatches an operation on a LogisticRegression object.
 *
 * rhs[0] is read as the object id and rhs[1] as the operation name; any
 * further arguments are operation-specific (a positional part followed by
 * key/value option pairs in the "load", "train", "calcError" and "predict"
 * branches). Results are written to plhs[0].
 *
 * @param nlhs number of outputs requested (checked by the nargchk conditions)
 * @param plhs output array; branches that return a value assign plhs[0]
 * @param nrhs number of inputs
 * @param prhs input array, wrapped into a vector<MxArray> as rhs
 *
 * Unknown options, properties and operations are reported through
 * mexErrMsgIdAndTxt with the id "mexopencv:error".
 *
 * NOTE(review): this listing is incomplete (several lines, including
 * braces and some `if` heads, are not visible); comments below only
 * describe what the visible lines show.
 */
51 void mexFunction(
int nlhs, mxArray *plhs[],
int nrhs,
const mxArray *prhs[])
// Wrap the raw inputs, then pull out the object id and the operation name.
57 vector<MxArray> rhs(prhs, prhs+nrhs);
58 int id = rhs[0].toInt();
59 string method(rhs[1].toString());
// "new": create a fresh model and store it under the next id (++last_id).
62 if (method ==
"new") {
64 obj_[++last_id] = LogisticRegression::create();
// All remaining operations act on the object registered under `id`.
// NOTE(review): if `map` is std::map, operator[] default-inserts an empty
// Ptr for an unknown id; no validity check is visible in this listing --
// confirm whether one exists in the full file.
70 Ptr<LogisticRegression> obj = obj_[id];
71 if (method ==
"delete") {
75 else if (method ==
"clear") {
// "load": at least 3 inputs, an odd total (id, method, source + key/value
// pairs) and no outputs.
79 else if (method ==
"load") {
80 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
82 bool loadFromString =
false;
// Options start at index 3 and are consumed as key/value pairs.
83 for (
int i=3; i<nrhs; i+=2) {
84 string key(rhs[i].toString());
// NOTE(review): the key test guarding this assignment is not visible here.
86 objname = rhs[i+1].toString();
87 else if (key ==
"FromString")
88 loadFromString = rhs[i+1].toBool();
90 mexErrMsgIdAndTxt(
"mexopencv:error",
91 "Unrecognized option %s", key.c_str());
// Replace the registered object: rhs[2] is passed to loadFromString when
// the "FromString" option is true, otherwise to Algorithm::load.
93 obj_[id] = (loadFromString ?
94 Algorithm::loadFromString<LogisticRegression>(rhs[2].toString(), objname) :
95 Algorithm::load<LogisticRegression>(rhs[2].toString(), objname));
// "save": the visible lines open a FileStorage in WRITE + MEMORY mode,
// write the default name followed by "{", and return the resulting string
// in plhs[0].
// NOTE(review): the condition selecting this in-memory path (and any
// file-based path) is not visible in this listing.
97 else if (method ==
"save") {
99 string fname(rhs[2].toString());
102 FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
103 fs << obj->getDefaultName() <<
"{";
107 plhs[0] =
MxArray(fs.releaseAndGetString());
// Simple queries: each forwards to the same-named method and returns the
// result in plhs[0].
113 else if (method ==
"empty") {
115 plhs[0] =
MxArray(obj->empty());
117 else if (method ==
"getDefaultName") {
119 plhs[0] =
MxArray(obj->getDefaultName());
121 else if (method ==
"getVarCount") {
123 plhs[0] =
MxArray(obj->getVarCount());
125 else if (method ==
"isClassifier") {
127 plhs[0] =
MxArray(obj->isClassifier());
129 else if (method ==
"isTrained") {
131 plhs[0] =
MxArray(obj->isTrained());
// "train": at least 4 inputs, an even total (id, method, two positional
// arguments + key/value pairs) and at most one output.
133 else if (method ==
"train") {
134 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
135 vector<MxArray> dataOptions;
// Options start at index 4 and are consumed as key/value pairs.
137 for (
int i=4; i<nrhs; i+=2) {
138 string key(rhs[i].toString());
// NOTE(review): the key test guarding the dataOptions assignment is not
// visible here.
140 dataOptions = rhs[i+1].toVector<
MxArray>();
141 else if (key ==
"Flags")
142 flags = rhs[i+1].
toInt();
144 mexErrMsgIdAndTxt(
"mexopencv:error",
145 "Unrecognized option %s", key.c_str());
// NOTE(review): the heads of the two calls below are not visible;
// presumably loadTrainData(filename, first, last) and
// createTrainData(samples, responses, first, last) -- confirm. The second
// one receives rhs[2] and rhs[3] converted to CV_32F matrices.
150 dataOptions.begin(), dataOptions.end());
153 rhs[2].toMat(CV_32F),
154 rhs[3].toMat(CV_32F),
155 dataOptions.begin(), dataOptions.end());
156 bool b = obj->train(data, flags);
// "calcError": same argument layout as "train", but up to two outputs.
159 else if (method ==
"calcError") {
160 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
161 vector<MxArray> dataOptions;
163 for (
int i=4; i<nrhs; i+=2) {
164 string key(rhs[i].toString());
166 dataOptions = rhs[i+1].toVector<
MxArray>();
// NOTE(review): the statement executed for "TestError" is not visible;
// presumably it sets `test`, which is passed to calcError below -- confirm.
167 else if (key ==
"TestError")
170 mexErrMsgIdAndTxt(
"mexopencv:error",
171 "Unrecognized option %s", key.c_str());
176 dataOptions.begin(), dataOptions.end());
179 rhs[2].toMat(CV_32F),
180 rhs[3].toMat(CV_32F),
181 dataOptions.begin(), dataOptions.end());
// The responses output is only requested from OpenCV when the caller asked
// for a second output; otherwise noArray() is passed.
183 float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
// "predict": at least 3 inputs, an odd total (id, method, samples +
// key/value pairs) and at most two outputs.
188 else if (method ==
"predict") {
189 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
191 for (
int i=3; i<nrhs; i+=2) {
192 string key(rhs[i].toString());
// NOTE(review): the key test guarding this assignment is not visible here.
194 flags = rhs[i+1].toInt();
// "RawOutput" sets or clears the StatModel::RAW_OUTPUT bit in flags.
195 else if (key ==
"RawOutput")
196 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::RAW_OUTPUT);
198 mexErrMsgIdAndTxt(
"mexopencv:error",
199 "Unrecognized option %s", key.c_str());
// Samples are converted to a CV_32F matrix before prediction.
201 Mat samples(rhs[2].toMat(CV_32F)),
203 float f = obj->predict(samples, results, flags);
208 else if (method ==
"get_learnt_thetas") {
210 plhs[0] =
MxArray(obj->get_learnt_thetas());
// "get": rhs[2] names the property; the matching getter's value is
// returned in plhs[0]. Regularization and TrainMethod are translated back
// to their option names through the Inv* tables.
212 else if (method ==
"get") {
214 string prop(rhs[2].toString());
215 if (prop ==
"Iterations")
216 plhs[0] =
MxArray(obj->getIterations());
217 else if (prop ==
"LearningRate")
218 plhs[0] =
MxArray(obj->getLearningRate());
219 else if (prop ==
"MiniBatchSize")
220 plhs[0] =
MxArray(obj->getMiniBatchSize());
221 else if (prop ==
"Regularization")
222 plhs[0] =
MxArray(InvRegularizationType[obj->getRegularization()]);
223 else if (prop ==
"TermCriteria")
224 plhs[0] =
MxArray(obj->getTermCriteria());
225 else if (prop ==
"TrainMethod")
226 plhs[0] =
MxArray(InvTrainingMethodType[obj->getTrainMethod()]);
228 mexErrMsgIdAndTxt(
"mexopencv:error",
229 "Unrecognized property %s", prop.c_str());
// "set": rhs[2] names the property and rhs[3] carries the new value,
// converted to the type the matching setter takes. Regularization and
// TrainMethod are given as strings and mapped through the name tables.
231 else if (method ==
"set") {
233 string prop(rhs[2].toString());
234 if (prop ==
"Iterations")
235 obj->setIterations(rhs[3].toInt());
236 else if (prop ==
"LearningRate")
237 obj->setLearningRate(rhs[3].toDouble());
238 else if (prop ==
"MiniBatchSize")
239 obj->setMiniBatchSize(rhs[3].toInt());
240 else if (prop ==
"Regularization")
241 obj->setRegularization(RegularizationType[rhs[3].toString()]);
242 else if (prop ==
"TermCriteria")
243 obj->setTermCriteria(rhs[3].toTermCriteria());
244 else if (prop ==
"TrainMethod")
245 obj->setTrainMethod(TrainingMethodType[rhs[3].toString()]);
247 mexErrMsgIdAndTxt(
"mexopencv:error",
248 "Unrecognized property %s", prop.c_str());
// Fallback: the operation name matched none of the branches above.
251 mexErrMsgIdAndTxt(
"mexopencv:error",
"Unrecognized operation");
#define UPDATE_FLAG(NUM, TF, BIT)
Set or clear a bit in a flag depending on a bool value.
int toInt() const
Convert MxArray to int.
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for input/output arguments number check.
bool toBool() const
Convert MxArray to bool.
Global constant definitions.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Common definitions for the ml module.