mexopencv  0.1
mex interface for opencv library
ConjGradSolver_.cpp
Go to the documentation of this file.
1 
8 #include "mexopencv.hpp"
9 using namespace std;
10 using namespace cv;
11 
12 // Persistent objects
13 namespace {
// Most recently issued object id; mexFunction pre-increments it when the
// "new" method registers a solver, and returns the new value to the caller.
15 int last_id = 0;
// Registry of ConjGradSolver instances keyed by the integer id handed back
// from "new"; entries are removed by the "delete" method.
17 map<int,Ptr<ConjGradSolver> > obj_;
18 
20 class MatlabFunction : public cv::MinProblemSolver::Function
21 {
22 public:
30  MatlabFunction(int num_dims, const string &func, const string &grad_func = "", double h = 1e-3)
31  : dims(num_dims), fun_name(func), grad_fun_name(grad_func), gradeps(h)
32  {}
33 
50  double calc(const double *x) const
51  {
52  // create input to evaluate objective function
53  mxArray *lhs, *rhs[2];
54  rhs[0] = MxArray(fun_name);
55  rhs[1] = MxArray(vector<double>(x, x + dims));
56 
57  // evaluate specified function in MATLAB as:
58  // val = feval("fun_name", x)
59  double val;
60  if (mexCallMATLAB(1, &lhs, 2, rhs, "feval") == 0) {
61  MxArray res(lhs);
62  CV_Assert(res.isDouble() && !res.isComplex() && res.numel() == 1);
63  val = res.at<double>(0);
64  }
65  else {
66  //TODO: error
67  val = 0;
68  }
69 
70  // cleanup
71  mxDestroyArray(lhs);
72  mxDestroyArray(rhs[0]);
73  mxDestroyArray(rhs[1]);
74 
75  // return scalar value of objective function evaluated at x
76  return val;
77  }
78 
95  void getGradient(const double* x, double* grad) /*const*/
96  {
97  // if no function is specified, approximate the gradient using
98  // finite difference method: F'(x) = (F(x+h) - F(x-h)) / 2*h
99  if (grad_fun_name.empty()) {
100  cv::MinProblemSolver::Function::getGradient(x, grad);
101  return;
102  }
103 
104  // create input to evaluate gradient function
105  mxArray *lhs, *rhs[2];
106  rhs[0] = MxArray(grad_fun_name);
107  rhs[1] = MxArray(vector<double>(x, x + dims));
108 
109  // evaluate specified function in MATLAB as:
110  // grad = feval("grad_fun_name", x)
111  if (mexCallMATLAB(1, &lhs, 2, rhs, "feval") == 0) {
112  MxArray res(lhs);
113  CV_Assert(res.isDouble() && !res.isComplex() && res.ndims() == 2);
114  vector<double> v(res.toVector<double>());
115  CV_Assert(v.size() == dims);
116  std::copy(v.begin(), v.end(), grad);
117  }
118  else {
119  //TODO: error
120  std::fill(grad, grad + dims, 0.0);
121  }
122 
123  // cleanup
124  mxDestroyArray(lhs);
125  mxDestroyArray(rhs[0]);
126  mxDestroyArray(rhs[1]);
127  }
128 
132  double getGradientEps() const
133  {
134  return gradeps;
135  }
136 
140  int getDims() const
141  {
142  return dims;
143  }
144 
148  MxArray toStruct() const
149  {
151  s.set("dims", dims);
152  s.set("fun", fun_name);
153  s.set("gradfun", grad_fun_name);
154  s.set("gradeps", gradeps);
155  return s;
156  }
157 
166  static Ptr<MatlabFunction> create(const MxArray &s)
167  {
168  if (!s.isStruct() || s.numel()!=1)
169  mexErrMsgIdAndTxt("mexopencv:error", "Invalid objective function");
170  return makePtr<MatlabFunction>(
171  s.at("dims").toInt(),
172  s.at("fun").toString(),
173  s.isField("gradfun") ? s.at("gradfun").toString() : "",
174  s.isField("gradeps") ? s.at("gradeps").toDouble() : 1e-3);
175  }
176 
177 private:
178  int dims;
179  string fun_name;
180  string grad_fun_name;
181  double gradeps;
182 };
183 }
184 
192 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
193 {
194  // Check the number of arguments
195  nargchk(nrhs>=2 && nlhs<=2);
196 
197  // Arguments vector
198  vector<MxArray> rhs(prhs, prhs+nrhs);
199  int id = rhs[0].toInt();
200  string method(rhs[1].toString());
201 
202  // Constructor is called. Create a new object from argument
203  if (method == "new") {
204  nargchk(nrhs>=2 && (nrhs%2)==0 && nlhs<=1);
205  Ptr<MinProblemSolver::Function> f;
206  TermCriteria termcrit(TermCriteria::MAX_ITER+TermCriteria::EPS, 5000, 1e-6);
207  for (int i=2; i<nrhs; i+=2) {
208  string key(rhs[i].toString());
209  if (key=="Function")
210  f = MatlabFunction::create(rhs[i+1]);
211  else if (key=="TermCriteria")
212  termcrit = rhs[i+1].toTermCriteria();
213  else
214  mexErrMsgIdAndTxt("mexopencv:error",
215  "Unrecognized option %s", key.c_str());
216  }
217  obj_[++last_id] = ConjGradSolver::create(f, termcrit);
218  plhs[0] = MxArray(last_id);
219  return;
220  }
221 
222  // Big operation switch
223  Ptr<ConjGradSolver> obj = obj_[id];
224  if (method == "delete") {
225  nargchk(nrhs==2 && nlhs==0);
226  obj_.erase(id);
227  }
228  else if (method == "clear") {
229  nargchk(nrhs==2 && nlhs==0);
230  obj->clear();
231  }
232  else if (method == "load") {
233  //TODO
234  nargchk(false);
235  }
236  else if (method == "save") {
237  //TODO
238  nargchk(false);
239  }
240  else if (method == "empty") {
241  nargchk(nrhs==2 && nlhs<=1);
242  plhs[0] = MxArray(obj->empty());
243  }
244  else if (method == "getDefaultName") {
245  nargchk(nrhs==2 && nlhs<=1);
246  plhs[0] = MxArray(obj->getDefaultName());
247  }
248  else if (method == "minimize") {
249  nargchk(nrhs==3 && nlhs<=2);
250  Mat x(rhs[2].toMat(CV_64F));
251  double fx = obj->minimize(x);
252  plhs[0] = MxArray(x);
253  if (nlhs>1)
254  plhs[1] = MxArray(fx);
255  }
256  else if (method == "get") {
257  nargchk(nrhs==3 && nlhs<=1);
258  string prop(rhs[2].toString());
259  if (prop == "Function") {
260  Ptr<MinProblemSolver::Function> f(obj->getFunction());
261  plhs[0] = (f.empty()) ? MxArray::Struct() :
262  (f.dynamicCast<MatlabFunction>())->toStruct();
263  }
264  else if (prop == "TermCriteria")
265  plhs[0] = MxArray(obj->getTermCriteria());
266  else
267  mexErrMsgIdAndTxt("mexopencv:error", "Unrecognized property %s", prop.c_str());
268  }
269  else if (method == "set") {
270  nargchk(nrhs==4 && nlhs==0);
271  string prop(rhs[2].toString());
272  if (prop == "Function")
273  obj->setFunction(MatlabFunction::create(rhs[3]));
274  else if (prop == "TermCriteria")
275  obj->setTermCriteria(rhs[3].toTermCriteria());
276  else
277  mexErrMsgIdAndTxt("mexopencv:error", "Unrecognized property %s", prop.c_str());
278  }
279  else
280  mexErrMsgIdAndTxt("mexopencv:error","Unrecognized operation");
281 }
bool isStruct() const
Determine whether input is structure array.
Definition: MxArray.hpp:708
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:166
static MxArray Struct(const char **fields=NULL, int nfields=0, mwSize m=1, mwSize n=1)
Create a new struct array.
Definition: MxArray.hpp:312
bool isField(const std::string &fieldName) const
Determine whether a struct array has a specified field.
Definition: MxArray.hpp:743
Global constant definitions.
mwSize numel() const
Number of elements in an array.
Definition: MxArray.hpp:546
T at(mwIndex index) const
Template for numeric array element accessor.
Definition: MxArray.hpp:1250