00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00028 #ifndef _FUNC_EVALUATOR_H_
00029 #define _FUNC_EVALUATOR_H_
00030
00031 #include <coconut_config.h>
00032 #include <evaluator.h>
00033 #include <expression.h>
00034 #include <model.h>
00035 #include <eval_main.h>
00036 #include <linalg.h>
00037 #include <math.h>
00038 #include <api_exception.h>
00039
00040 using namespace vgtl;
00041
00042 namespace coco {
00043
00045 typedef double (*func_evaluator)(const std::vector<double>* __x,
00046 const variable_indicator& __v);
00047
00049
/// Per-node working state of the func_eval forward evaluator.
///
/// Member order is significant for aggregate initialisation; do not reorder.
struct func_eval_type
{
  /// Point at which the expression is evaluated (indexed by variable number).
  const std::vector<double>* x;
  /// Optional per-node value cache indexed by node_num; NULL disables caching.
  std::vector<double>* cache;
  /// Model the expression belongs to (supplies linear-term data and node 0).
  const model* mod;
  /// Scratch slot: p owns heap storage for ALLDIFF/HISTOGRAM nodes,
  /// d holds the pending even-indexed factor of an SCPROD node.
  union { void *p; double d; } u;
  /// Running accumulator / final value of the current node.
  double r;
  /// Index of the next child to be combined into r.
  unsigned int n;
};
00062
00064
00069 class func_eval :
00070 public cached_forward_evaluator_base<func_eval_type,expression_node,double,
00071 expression_const_walker>
00072 {
00073 private:
00075 typedef cached_forward_evaluator_base<func_eval_type,expression_node,
00076 double, expression_const_walker> _Base;
00077
00078 protected:
00081 bool is_cached(const node_data_type& __data)
00082 {
00083 if(__data.operator_type == EXPRINFO_LIN ||
00084 __data.operator_type == EXPRINFO_QUAD)
00085 return true;
00086 if(eval_data.cache && __data.n_parents > 1 && __data.n_children > 0 &&
00087 v_ind->match(__data.var_indicator()))
00088 return true;
00089 else
00090 return false;
00091 }
00092
00093 private:
00096 double __power(double __coeff, double __x, int __exp)
00097 {
00098 double ret;
00099 if(__exp == 0)
00100 ret = 1.;
00101 else
00102 {
00103 double k = __coeff*__x;
00104 switch(__exp)
00105 {
00106 case 1:
00107 ret = k;
00108 break;
00109 case 2:
00110 ret = k*k;
00111 break;
00112 case -1:
00113 ret = 1./k;
00114 break;
00115 case -2:
00116 ret = 1./(k*k);
00117 break;
00118 default:
00119 if(__exp & 1)
00120 {
00121 if(k < 0)
00122 ret = -std::pow(-k, __exp);
00123 else
00124 ret = std::pow(k, __exp);
00125 }
00126 else
00127 ret = std::pow(fabs(k), __exp);
00128 break;
00129 }
00130 }
00131 return ret;
00132 }
00133
00134 public:
00140 func_eval(const std::vector<double>& __x, const variable_indicator& __v,
00141 const model& __m, std::vector<double>* __c) : _Base()
00142 {
00143 eval_data.x = &__x;
00144 eval_data.cache = __c;
00145 eval_data.mod = &__m;
00146 eval_data.u.d = 0.;
00147 v_ind = &__v;
00148 eval_data.n = 0;
00149 }
00150
00152 func_eval(const func_eval& __v) : _Base(__v) {}
00153
00155 ~func_eval() {}
00156
00158 expression_const_walker short_cut_to(const expression_node& __data)
00159 { return eval_data.mod->node(0); }
00160
00164 void new_point(const std::vector<double>& __x, const variable_indicator& __v)
00165 {
00166 eval_data.x = &__x;
00167 v_ind = &__v;
00168 }
00169
00171
00172 void initialize() { return; }
00173
00174 int initialize(const expression_node& __data)
00175 {
00176 eval_data.n = 0;
00177 if(__data.ev != NULL && (*__data.ev)[FUNC_EVALUATOR] != NULL)
00178
00179 {
00180 eval_data.r = (*(func_evaluator)(*__data.ev)[FUNC_EVALUATOR])(eval_data.x,
00181 *v_ind);
00182 return 0;
00183 }
00184 else
00185 {
00186 switch(__data.operator_type)
00187 {
00188 case EXPRINFO_VARIABLE:
00189 eval_data.r = (*eval_data.x)[__data.params.nn()];
00190 break;
00191 case EXPRINFO_SUM:
00192 case EXPRINFO_PROD:
00193 case EXPRINFO_MAX:
00194 case EXPRINFO_MIN:
00195 case EXPRINFO_INVERT:
00196 case EXPRINFO_DIV:
00197 case EXPRINFO_CONSTANT:
00198 eval_data.r = __data.params.nd();
00199 break;
00200 case EXPRINFO_MONOME:
00201 case EXPRINFO_IN:
00202 case EXPRINFO_AND:
00203 case EXPRINFO_NOGOOD:
00204 eval_data.r = 1.;
00205 break;
00206 case EXPRINFO_HISTOGRAM:
00207 eval_data.u.p =
00208 (void*)new std::vector<unsigned int>(0,
00209 (__data.params.i().size()+1)>>1);
00210 eval_data.r = 1.;
00211 break;
00212 case EXPRINFO_ALLDIFF:
00213 eval_data.u.p = (void*) new std::vector<double>;
00214 ((std::vector<double>*)eval_data.u.p)->reserve(__data.n_children);
00215
00216 case EXPRINFO_MEAN:
00217 case EXPRINFO_IF:
00218 case EXPRINFO_OR:
00219 case EXPRINFO_NOT:
00220 case EXPRINFO_COUNT:
00221 case EXPRINFO_SCPROD:
00222 case EXPRINFO_NORM:
00223 case EXPRINFO_LEVEL:
00224 eval_data.r = 0.;
00225 break;
00226 case EXPRINFO_DET:
00227 case EXPRINFO_PSD:
00228
00229 break;
00230 case EXPRINFO_COND:
00231 case EXPRINFO_FEM:
00232 case EXPRINFO_MPROD:
00233
00234 break;
00235 case EXPRINFO_LOOKUP:
00236 case EXPRINFO_PWLIN:
00237 case EXPRINFO_SPLINE:
00238 case EXPRINFO_PWCONSTLC:
00239 case EXPRINFO_PWCONSTRC:
00240 eval_data.r = -COCO_INF;
00241 break;
00242 }
00243 return 1;
00244 }
00245 }
00246
00247 void calculate(const expression_node& __data)
00248 {
00249 if(__data.operator_type > 0)
00250 {
00251 eval_data.r = __data.f_evaluate(-1, __data.params.nn(), *eval_data.x,
00252 *v_ind, eval_data.r, 0, NULL);
00253 }
00254 }
00255
00256 void retrieve_from_cache(const expression_node& __data)
00257 {
00258 if(__data.operator_type == EXPRINFO_LIN)
00259 eval_data.r = linalg::linalg_dot(eval_data.mod->lin[__data.params.nn()],
00260 *eval_data.x,0.);
00261 else if(__data.operator_type == EXPRINFO_QUAD)
00262 {
00263 std::vector<double> irslt;
00264
00265 linalg::linalg_matvec(__data.params.m(), *eval_data.x, irslt);
00266 eval_data.r = irslt.back();
00267 irslt.pop_back();
00268 eval_data.r += linalg::linalg_dot(irslt,*eval_data.x,0.);
00269 }
00270 else
00271 eval_data.r = (*eval_data.cache)[__data.node_num];
00272 }
00273
00274 int update(const double& __rval)
00275 {
00276 eval_data.r = __rval;
00277 return 0;
00278 }
00279
00280 int update(const expression_node& __data, const double& __rval)
00281 {
00282 double __x;
00283 int ret = 0;
00284 if(__data.operator_type < 0)
00285 {
00286 switch(__data.operator_type)
00287 {
00288 case EXPRINFO_CONSTANT:
00289 case EXPRINFO_VARIABLE:
00290 break;
00291 case EXPRINFO_SUM:
00292 case EXPRINFO_MEAN:
00293 eval_data.r += __data.coeffs[eval_data.n++]*__rval;
00294 break;
00295 case EXPRINFO_PROD:
00296
00297 eval_data.r *= __rval;
00298 if(eval_data.r == 0) ret=-1;
00299 break;
00300 case EXPRINFO_MONOME:
00301 eval_data.r *= __power(__data.coeffs[eval_data.n],
00302 __rval, __data.params.n()[eval_data.n]);
00303 if(eval_data.r == 0) ret=-1;
00304 eval_data.n++;
00305 break;
00306 case EXPRINFO_MAX:
00307 __x = __rval * __data.coeffs[eval_data.n++];
00308 if(__x > eval_data.r)
00309 eval_data.r = __x;
00310 break;
00311 case EXPRINFO_MIN:
00312 __x = __rval * __data.coeffs[eval_data.n++];
00313 if(__x < eval_data.r)
00314 eval_data.r = __x;
00315 break;
00316 case EXPRINFO_SCPROD:
00317 { double h = __data.coeffs[eval_data.n]*__rval;
00318
00319
00320 if(eval_data.n & 1)
00321 eval_data.r += eval_data.u.d*h;
00322 else
00323 eval_data.u.d = h;
00324 }
00325 eval_data.n++;
00326 break;
00327 case EXPRINFO_NORM:
00328 if(__data.params.nd() == COCO_INF)
00329 {
00330 __x = fabs(__rval*__data.coeffs[eval_data.n]);
00331 if(__x > eval_data.r)
00332 eval_data.r = __x;
00333 }
00334 else
00335 { __x = fabs(__data.coeffs[eval_data.n]*__rval);
00336 eval_data.r += std::pow(__x, __data.params.nd());
00337 }
00338 eval_data.n++;
00339 if(eval_data.n == __data.n_children && __data.params.nd() != COCO_INF)
00340 eval_data.r = std::pow(eval_data.r, 1./__data.params.nd());
00341 break;
00342 case EXPRINFO_INVERT:
00343 eval_data.r /= __rval;
00344 break;
00345 case EXPRINFO_DIV:
00346
00347
00348 if(eval_data.n == 0)
00349 {
00350 eval_data.r *= __rval;
00351 if(eval_data.r == 0) ret=-1;
00352 }
00353 else
00354 eval_data.r /= __rval;
00355 ++eval_data.n;
00356 break;
00357 case EXPRINFO_SQUARE:
00358 __x = __data.coeffs[0]*__rval+__data.params.nd();
00359 eval_data.r = __x*__x;
00360 break;
00361 case EXPRINFO_INTPOWER:
00362 eval_data.r = __power(__data.coeffs[0], __rval, __data.params.nn());
00363 break;
00364 case EXPRINFO_SQROOT:
00365 eval_data.r = std::sqrt(__data.coeffs[0]*__rval+__data.params.nd());
00366 break;
00367 case EXPRINFO_ABS:
00368 eval_data.r = fabs(__data.coeffs[0]*__rval+__data.params.nd());
00369 break;
00370 case EXPRINFO_POW:
00371 __x = __rval * __data.coeffs[eval_data.n];
00372 if(eval_data.n == 0)
00373 eval_data.r = __x+__data.params.nd();
00374 else
00375 eval_data.r = std::pow(eval_data.r, __x);
00376 ++eval_data.n;
00377 break;
00378 case EXPRINFO_EXP:
00379 eval_data.r = std::exp(__rval*__data.coeffs[0]+__data.params.nd());
00380 break;
00381 case EXPRINFO_LOG:
00382 eval_data.r = std::log(__rval*__data.coeffs[0]+__data.params.nd());
00383 break;
00384 case EXPRINFO_SIN:
00385 eval_data.r = std::sin(__rval*__data.coeffs[0]+__data.params.nd());
00386 break;
00387 case EXPRINFO_COS:
00388 eval_data.r = std::cos(__rval*__data.coeffs[0]+__data.params.nd());
00389 break;
00390 case EXPRINFO_ATAN2:
00391 __x = __rval * __data.coeffs[eval_data.n];
00392 if(eval_data.n == 0)
00393 eval_data.r = __x;
00394 else
00395 eval_data.r = std::atan2(eval_data.r, __x);
00396 ++eval_data.n;
00397 break;
00398 case EXPRINFO_GAUSS:
00399 __x = (__data.coeffs[0]*__rval-__data.params.d()[0])/
00400 __data.params.d()[1];
00401 eval_data.r = std::exp(-__x*__x);
00402 break;
00403 case EXPRINFO_POLY:
00404 if(!__data.params.empty())
00405 {
00406 __x = __data.coeffs[eval_data.n]*__rval;
00407 eval_data.r = __data.params.d().back();
00408 for(int i = (int)__data.params.d().size()-2; i >= 0; --i)
00409 {
00410 eval_data.r *= __x;
00411 eval_data.r += __data.params.d()[i];
00412 }
00413 }
00414 else
00415 eval_data.r = 0;
00416 break;
00417 case EXPRINFO_LIN:
00418 case EXPRINFO_QUAD:
00419
00420 break;
00421 case EXPRINFO_IN:
00422 {
00423 __x = __data.coeffs[eval_data.n]*__rval;
00424 const interval& i(__data.params.i()[eval_data.n]);
00425 if(eval_data.r != -1 && i.contains(__x))
00426 {
00427 if(eval_data.r == 1 && (__x == i.inf() || __x == i.sup()))
00428 eval_data.r = 0;
00429 }
00430 else
00431 {
00432 eval_data.r = -1;
00433 ret = -1;
00434 }
00435 }
00436 eval_data.n++;
00437 break;
00438 case EXPRINFO_IF:
00439 __x = __rval * __data.coeffs[eval_data.n];
00440 if(eval_data.n == 0)
00441 {
00442 const interval& i(__data.params.ni());
00443 if(!i.contains(__x))
00444 ret = 1;
00445 }
00446 else
00447 {
00448 eval_data.r = __x;
00449 ret = -1;
00450 }
00451 eval_data.n += ret+1;
00452 break;
00453 case EXPRINFO_AND:
00454 { __x = __data.coeffs[eval_data.n]*__rval;
00455 const interval& i(__data.params.i()[eval_data.n]);
00456 if(eval_data.r == 1 && !i.contains(__x))
00457 {
00458 eval_data.r = 0;
00459 ret = -1;
00460 }
00461 }
00462 eval_data.n++;
00463 break;
00464 case EXPRINFO_OR:
00465 { __x = __data.coeffs[eval_data.n]*__rval;
00466 const interval& i(__data.params.i()[eval_data.n]);
00467 if(eval_data.r == 0 && i.contains(__x))
00468 {
00469 eval_data.r = 1;
00470 ret = -1;
00471 }
00472 }
00473 eval_data.n++;
00474 break;
00475 case EXPRINFO_NOT:
00476 { __x = __data.coeffs[0]*__rval;
00477 const interval& i(__data.params.ni());
00478 if(i.contains(__x))
00479 eval_data.r = 0;
00480 else
00481 eval_data.r = 1;
00482 }
00483 break;
00484 case EXPRINFO_IMPLIES:
00485 { const interval& i(__data.params.i()[eval_data.n]);
00486 __x = __rval * __data.coeffs[eval_data.n];
00487 if(eval_data.n == 0)
00488 {
00489 if(!i.contains(__x))
00490 {
00491 eval_data.r = 1;
00492 ret = -1;
00493 }
00494 }
00495 else
00496 eval_data.r = i.contains(__x) ? 1 : 0;
00497 ++eval_data.n;
00498 }
00499 break;
00500 case EXPRINFO_COUNT:
00501 { __x = __data.coeffs[eval_data.n]*__rval;
00502 const interval& i(__data.params.i()[eval_data.n]);
00503 if(i.contains(__x))
00504 eval_data.r += 1;
00505 }
00506 eval_data.n++;
00507 break;
00508 case EXPRINFO_ALLDIFF:
00509 { __x = __data.coeffs[eval_data.n]*__rval;
00510 for(std::vector<double>::const_iterator _b =
00511 ((std::vector<double>*)eval_data.u.p)->begin();
00512 _b != ((std::vector<double>*)eval_data.u.p)->end(); ++_b)
00513 {
00514 if(fabs(__x-*_b) <= __data.params.nd())
00515 {
00516 eval_data.r = 0;
00517 ret = -1;
00518 break;
00519 }
00520 }
00521 if(ret != -1)
00522 ((std::vector<double>*) eval_data.u.p)->push_back(__x);
00523 }
00524 eval_data.n++;
00525 if(eval_data.n == __data.n_children || ret == -1)
00526 delete (std::vector<double>*) eval_data.u.p;
00527 break;
00528 case EXPRINFO_HISTOGRAM:
00529 { __x = __data.coeffs[eval_data.n]*__rval;
00530 unsigned int ni = (__data.params.i().size()+1)>>1;
00531 for(unsigned int i=0; i<ni; i++)
00532 {
00533 if(__data.params.i()[i<<1].contains(__x))
00534 (*(std::vector<unsigned int>*)eval_data.u.p)[i]++;
00535 }
00536 eval_data.n++;
00537 if(eval_data.n == __data.n_children || ret == -1)
00538 {
00539 for(unsigned int i=0; i<ni; i++)
00540 {
00541 if(!__data.params.i()[i<<1].contains(
00542 (*(std::vector<unsigned int>*)eval_data.u.p)[i]+0.))
00543 {
00544 eval_data.r = 0;
00545 break;
00546 }
00547 }
00548 delete (std::vector<unsigned int>*) eval_data.u.p;
00549 }
00550 }
00551 break;
00552 case EXPRINFO_LEVEL:
00553 { int h = (int)eval_data.r;
00554 __x = __data.coeffs[eval_data.n]*__rval;
00555 interval _h;
00556
00557 if(h != INT_MAX)
00558 {
00559 while(h < __data.params.im().nrows())
00560 {
00561 _h = __data.params.im()[h][eval_data.n];
00562 if(_h.contains(__x))
00563 break;
00564 h++;
00565 }
00566 if(h == __data.params.im().nrows())
00567 {
00568 ret = -1;
00569 eval_data.r = INT_MAX;
00570 }
00571 else
00572 eval_data.r = h;
00573 }
00574 }
00575 eval_data.n++;
00576 break;
00577 case EXPRINFO_NEIGHBOR:
00578 if(eval_data.n == 0)
00579 eval_data.r = __data.coeffs[0]*__rval;
00580 else
00581 {
00582 double h = eval_data.r;
00583 eval_data.r = 0;
00584 __x = __data.coeffs[1]*__rval;
00585 for(unsigned int i = 0; i < __data.params.n().size(); i+=2)
00586 {
00587 if(h == __data.params.n()[i] && __x == __data.params.n()[i+1])
00588 {
00589 eval_data.r = 1;
00590 break;
00591 }
00592 }
00593 }
00594 eval_data.n++;
00595 break;
00596 case EXPRINFO_NOGOOD:
00597 {
00598 __x = __data.coeffs[eval_data.n]*__rval;
00599 if(eval_data.r == 0 || __data.params.n()[eval_data.n] != __x)
00600 {
00601 eval_data.r = 0;
00602 ret = -1;
00603 }
00604 }
00605 eval_data.n++;
00606 break;
00607 case EXPRINFO_EXPECTATION:
00608 throw nyi_exception("func_evaluator: EXPECTATION");
00609 break;
00610 case EXPRINFO_INTEGRAL:
00611 throw nyi_exception("func_evaluator: INTEGRAL");
00612 break;
00613 case EXPRINFO_LOOKUP:
00614 throw nyi_exception("func_evaluator: LOOKUP");
00615 case EXPRINFO_PWLIN:
00616 {
00617 int idx;
00618 __x = __data.coeffs[0]*__rval;
00619 for(idx=0; idx < __data.params.m().nrows(); ++idx)
00620 if(__x < __data.params.m()[idx][0]) break;
00621 if(idx == __data.params.m().nrows())
00622 eval_data.r = COCO_INF;
00623 else if(idx == 0)
00624 eval_data.r = -COCO_INF;
00625 else
00626 {
00627 double x0 = __data.params.m()[idx-1][0];
00628 double f0 = __data.params.m()[idx-1][1];
00629 eval_data.r = (__data.params.m()[idx][1]-f0)/(__data.params.m()[idx][0]-x0);
00630 eval_data.r *= __x-x0;
00631 eval_data.r += f0;
00632 }
00633 }
00634 break;
00635 case EXPRINFO_SPLINE:
00636 throw nyi_exception("func_evaluator: SPLINE");
00637 case EXPRINFO_PWCONSTLC:
00638 {
00639 int idx;
00640 __x = __data.coeffs[0]*__rval;
00641 for(idx=0; idx < __data.params.m().nrows(); ++idx)
00642 if(__x < __data.params.m()[idx][0]) break;
00643 if(idx == __data.params.m().nrows())
00644 eval_data.r = COCO_INF;
00645 else if(idx == 0)
00646 eval_data.r = -COCO_INF;
00647 else
00648 eval_data.r = __data.params.m()[idx-1][1];
00649 }
00650 break;
00651 case EXPRINFO_PWCONSTRC:
00652 {
00653 int idx;
00654 __x = __data.coeffs[0]*__rval;
00655 for(idx=0; idx < __data.params.m().nrows(); ++idx)
00656 if(__x <= __data.params.m()[idx][0]) break;
00657 if(idx == __data.params.m().nrows())
00658 eval_data.r = COCO_INF;
00659 else if(idx == 0)
00660 eval_data.r = -COCO_INF;
00661 else
00662 eval_data.r = __data.params.m()[idx-1][1];
00663 }
00664 break;
00665 case EXPRINFO_DET:
00666 case EXPRINFO_COND:
00667 case EXPRINFO_PSD:
00668 case EXPRINFO_MPROD:
00669 case EXPRINFO_FEM:
00670 throw nyi_exception("func_evaluator: Matrix Operations");
00671 break;
00672 case EXPRINFO_RE:
00673 case EXPRINFO_IM:
00674 case EXPRINFO_ARG:
00675 case EXPRINFO_CPLXCONJ:
00676 throw nyi_exception("func_evaluator: Complex Operations");
00677 break;
00678 case EXPRINFO_CMPROD:
00679 case EXPRINFO_CGFEM:
00680 throw nyi_exception("func_evaluator: Const Matrix Operations");
00681 default:
00682 throw api_exception(apiee_evaluator,
00683 std::string("func_evaluator: unknown function type ")+
00684 convert_to_str(__data.operator_type));
00685 break;
00686 }
00687 }
00688 else if(__data.operator_type > 0)
00689
00690 eval_data.r = __data.f_evaluate(eval_data.n++, __data.params.nn(),
00691 *eval_data.x, *v_ind, eval_data.r,
00692 __rval, NULL);
00693
00694 if(eval_data.cache )
00695 (*eval_data.cache)[__data.node_num] = eval_data.r;
00696 return ret;
00697 }
00698
00699 double calculate_value(bool eval_all)
00700 {
00701 return eval_data.r;
00702 }
00704 };
00705
00706 }
00707
00708 #endif