00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00028 #ifndef _FUNC_EVALUATOR_H_
00029 #define _FUNC_EVALUATOR_H_
00030
00031 #include <coconut_config.h>
00032 #include <evaluator.h>
00033 #include <expression.h>
00034 #include <model.h>
00035 #include <eval_main.h>
00036 #include <linalg.h>
00037 #include <math.h>
00038 #include <api_exception.h>
00039
00040 using namespace vgtl;
00041
00042 namespace coco {
00043
00045 typedef double (*func_evaluator)(const std::vector<double>* __x,
00046 const variable_indicator& __v);
00047
00049
/// One slot of the per-node value cache used by func_eval.
///
/// `f` holds the last computed value of the node. `v` says whether `f` is
/// valid for the current evaluation point. func_eval::new_point() clears
/// `v`, and func_eval::update() sets it.
struct func_cache
{
  double f;   ///< cached function value of the node
  bool   v;   ///< true while f is valid for the current point
};
00057
00059
00063 struct func_eval_type
00064 {
00065 const std::vector<double>* x;
00066 std::vector<func_cache>* cache;
00067 const model* mod;
00068 union { void *p; double d; } u;
00069 double r;
00070 unsigned int n;
00071 };
00072
00074
00079 class func_eval :
00080 public cached_forward_evaluator_base<func_eval_type,expression_node,double,
00081 expression_const_walker>
00082 {
00083 private:
00085 typedef cached_forward_evaluator_base<func_eval_type,expression_node,
00086 double, expression_const_walker> _Base;
00087
00088 protected:
00091 bool is_cached(const node_data_type& __data)
00092 {
00093 if(__data.operator_type == EXPRINFO_LIN ||
00094 __data.operator_type == EXPRINFO_QUAD)
00095 return true;
00096 if(eval_data.cache && __data.n_parents > 1 && __data.n_children > 0 &&
00097 (*eval_data.cache)[__data.node_num].v)
00098 return true;
00099 else
00100 return false;
00101 }
00102
00103 private:
00106 double __power(double __coeff, double __x, int __exp)
00107 {
00108 double ret;
00109 if(__exp == 0)
00110 ret = 1.;
00111 else
00112 {
00113 double k = __coeff*__x;
00114 switch(__exp)
00115 {
00116 case 1:
00117 ret = k;
00118 break;
00119 case 2:
00120 ret = k*k;
00121 break;
00122 case -1:
00123 ret = 1./k;
00124 break;
00125 case -2:
00126 ret = 1./(k*k);
00127 break;
00128 default:
00129 if(__exp & 1)
00130 {
00131 if(k < 0)
00132 ret = -std::pow(-k, __exp);
00133 else
00134 ret = std::pow(k, __exp);
00135 }
00136 else
00137 ret = std::pow(fabs(k), __exp);
00138 break;
00139 }
00140 }
00141 return ret;
00142 }
00143
00144 public:
00150 func_eval(const std::vector<double>& __x, const variable_indicator& __v,
00151 const model& __m, std::vector<func_cache>* __c) : _Base()
00152 {
00153 eval_data.x = &__x;
00154 eval_data.cache = __c;
00155 eval_data.mod = &__m;
00156 eval_data.u.d = 0.;
00157 v_ind = &__v;
00158 eval_data.n = 0;
00159 }
00160
00162 func_eval(const func_eval& __v) : _Base(__v) {}
00163
00165 ~func_eval() {}
00166
00168 expression_const_walker short_cut_to(const expression_node& __data)
00169 { return eval_data.mod->node(0); }
00170
00174 void new_point(const std::vector<double>& __x, const variable_indicator& __v)
00175 {
00176 eval_data.x = &__x;
00177 v_ind = &__v;
00178 if(eval_data.cache)
00179 for(int i=0; i<eval_data.mod->number_of_nodes(); ++i)
00180 (*eval_data.cache)[i].v = false;
00181 }
00182
00184
00185 void initialize() { return; }
00186
00187 int initialize(const expression_node& __data)
00188 {
00189 eval_data.n = 0;
00190 if(__data.ev != NULL && (*__data.ev)[FUNC_EVALUATOR] != NULL)
00191
00192 {
00193 eval_data.r = (*(func_evaluator)(*__data.ev)[FUNC_EVALUATOR])(eval_data.x,
00194 *v_ind);
00195 return 0;
00196 }
00197 else
00198 {
00199 switch(__data.operator_type)
00200 {
00201 case EXPRINFO_VARIABLE:
00202 eval_data.r = (*eval_data.x)[__data.params.nn()];
00203 break;
00204 case EXPRINFO_SUM:
00205 case EXPRINFO_PROD:
00206 case EXPRINFO_MAX:
00207 case EXPRINFO_MIN:
00208 case EXPRINFO_INVERT:
00209 case EXPRINFO_DIV:
00210 case EXPRINFO_CONSTANT:
00211 eval_data.r = __data.params.nd();
00212 break;
00213 case EXPRINFO_MONOME:
00214 case EXPRINFO_IN:
00215 case EXPRINFO_AND:
00216 case EXPRINFO_NOGOOD:
00217 eval_data.r = 1.;
00218 break;
00219 case EXPRINFO_HISTOGRAM:
00220 eval_data.u.p =
00221 (void*)new std::vector<unsigned int>(0,
00222 (__data.params.i().size()+1)>>1);
00223 eval_data.r = 1.;
00224 break;
00225 case EXPRINFO_ALLDIFF:
00226 eval_data.u.p = (void*) new std::vector<double>;
00227 ((std::vector<double>*)eval_data.u.p)->reserve(__data.n_children);
00228
00229 case EXPRINFO_MEAN:
00230 case EXPRINFO_IF:
00231 case EXPRINFO_OR:
00232 case EXPRINFO_NOT:
00233 case EXPRINFO_COUNT:
00234 case EXPRINFO_SCPROD:
00235 case EXPRINFO_NORM:
00236 case EXPRINFO_LEVEL:
00237 eval_data.r = 0.;
00238 break;
00239 case EXPRINFO_DET:
00240 case EXPRINFO_PSD:
00241
00242 break;
00243 case EXPRINFO_COND:
00244 case EXPRINFO_FEM:
00245 case EXPRINFO_MPROD:
00246
00247 break;
00248 case EXPRINFO_LOOKUP:
00249 case EXPRINFO_PWLIN:
00250 case EXPRINFO_SPLINE:
00251 case EXPRINFO_PWCONSTLC:
00252 case EXPRINFO_PWCONSTRC:
00253 eval_data.r = -COCO_INF;
00254 break;
00255 }
00256 return 1;
00257 }
00258 }
00259
00260 void calculate(const expression_node& __data)
00261 {
00262 if(__data.operator_type > 0)
00263 {
00264 eval_data.r = __data.f_evaluate(-1, __data.params.nn(), *eval_data.x,
00265 *v_ind, eval_data.r, 0, NULL);
00266 }
00267 }
00268
00269 void retrieve_from_cache(const expression_node& __data)
00270 {
00271 if(__data.operator_type == EXPRINFO_LIN)
00272 eval_data.r = linalg::linalg_dot(eval_data.mod->lin[__data.params.nn()],
00273 *eval_data.x,0.);
00274 else if(__data.operator_type == EXPRINFO_QUAD)
00275 {
00276 std::vector<double> irslt;
00277
00278 linalg::linalg_matvec(__data.params.m(), *eval_data.x, irslt);
00279 eval_data.r = irslt.back();
00280 irslt.pop_back();
00281 eval_data.r += linalg::linalg_dot(irslt,*eval_data.x,0.);
00282 }
00283 else
00284 eval_data.r = (*eval_data.cache)[__data.node_num].f;
00285 }
00286
00287 int update(const double& __rval)
00288 {
00289 eval_data.r = __rval;
00290 return 0;
00291 }
00292
00293 int update(const expression_node& __data, const double& __rval)
00294 {
00295 double __x;
00296 int ret = 0;
00297 if(__data.operator_type < 0)
00298 {
00299 switch(__data.operator_type)
00300 {
00301 case EXPRINFO_CONSTANT:
00302 case EXPRINFO_VARIABLE:
00303 break;
00304 case EXPRINFO_SUM:
00305 case EXPRINFO_MEAN:
00306 eval_data.r += __data.coeffs[eval_data.n++]*__rval;
00307 break;
00308 case EXPRINFO_PROD:
00309
00310 eval_data.r *= __rval;
00311 if(eval_data.r == 0) ret=-1;
00312 break;
00313 case EXPRINFO_MONOME:
00314 eval_data.r *= __power(__data.coeffs[eval_data.n],
00315 __rval, __data.params.n()[eval_data.n]);
00316 if(eval_data.r == 0) ret=-1;
00317 eval_data.n++;
00318 break;
00319 case EXPRINFO_MAX:
00320 __x = __rval * __data.coeffs[eval_data.n++];
00321 if(__x > eval_data.r)
00322 eval_data.r = __x;
00323 break;
00324 case EXPRINFO_MIN:
00325 __x = __rval * __data.coeffs[eval_data.n++];
00326 if(__x < eval_data.r)
00327 eval_data.r = __x;
00328 break;
00329 case EXPRINFO_SCPROD:
00330 { double h = __data.coeffs[eval_data.n]*__rval;
00331
00332
00333 if(eval_data.n & 1)
00334 eval_data.r += eval_data.u.d*h;
00335 else
00336 eval_data.u.d = h;
00337 }
00338 eval_data.n++;
00339 break;
00340 case EXPRINFO_NORM:
00341 if(__data.params.nd() == COCO_INF)
00342 {
00343 __x = fabs(__rval*__data.coeffs[eval_data.n]);
00344 if(__x > eval_data.r)
00345 eval_data.r = __x;
00346 }
00347 else
00348 { __x = fabs(__data.coeffs[eval_data.n]*__rval);
00349 eval_data.r += std::pow(__x, __data.params.nd());
00350 }
00351 eval_data.n++;
00352 if(eval_data.n == __data.n_children && __data.params.nd() != COCO_INF)
00353 eval_data.r = std::pow(eval_data.r, 1./__data.params.nd());
00354 break;
00355 case EXPRINFO_INVERT:
00356 eval_data.r /= __rval;
00357 break;
00358 case EXPRINFO_DIV:
00359
00360
00361 if(eval_data.n == 0)
00362 {
00363 eval_data.r *= __rval;
00364 if(eval_data.r == 0) ret=-1;
00365 }
00366 else
00367 eval_data.r /= __rval;
00368 ++eval_data.n;
00369 break;
00370 case EXPRINFO_SQUARE:
00371 __x = __data.coeffs[0]*__rval+__data.params.nd();
00372 eval_data.r = __x*__x;
00373 break;
00374 case EXPRINFO_INTPOWER:
00375 eval_data.r = __power(__data.coeffs[0], __rval, __data.params.nn());
00376 break;
00377 case EXPRINFO_SQROOT:
00378 eval_data.r = std::sqrt(__data.coeffs[0]*__rval+__data.params.nd());
00379 break;
00380 case EXPRINFO_ABS:
00381 eval_data.r = fabs(__data.coeffs[0]*__rval+__data.params.nd());
00382 break;
00383 case EXPRINFO_POW:
00384 __x = __rval * __data.coeffs[eval_data.n];
00385 if(eval_data.n == 0)
00386 eval_data.r = __x+__data.params.nd();
00387 else
00388 eval_data.r = std::pow(eval_data.r, __x);
00389 ++eval_data.n;
00390 break;
00391 case EXPRINFO_EXP:
00392 eval_data.r = std::exp(__rval*__data.coeffs[0]+__data.params.nd());
00393 break;
00394 case EXPRINFO_LOG:
00395 eval_data.r = std::log(__rval*__data.coeffs[0]+__data.params.nd());
00396 break;
00397 case EXPRINFO_SIN:
00398 eval_data.r = std::sin(__rval*__data.coeffs[0]+__data.params.nd());
00399 break;
00400 case EXPRINFO_COS:
00401 eval_data.r = std::cos(__rval*__data.coeffs[0]+__data.params.nd());
00402 break;
00403 case EXPRINFO_ATAN2:
00404 __x = __rval * __data.coeffs[eval_data.n];
00405 if(eval_data.n == 0)
00406 eval_data.r = __x;
00407 else
00408 eval_data.r = std::atan2(eval_data.r, __x);
00409 ++eval_data.n;
00410 break;
00411 case EXPRINFO_GAUSS:
00412 __x = (__data.coeffs[0]*__rval-__data.params.d()[0])/
00413 __data.params.d()[1];
00414 eval_data.r = std::exp(-__x*__x);
00415 break;
00416 case EXPRINFO_POLY:
00417 if(!__data.params.empty())
00418 {
00419 __x = __data.coeffs[eval_data.n]*__rval;
00420 eval_data.r = __data.params.d().back();
00421 for(int i = (int)__data.params.d().size()-2; i >= 0; --i)
00422 {
00423 eval_data.r *= __x;
00424 eval_data.r += __data.params.d()[i];
00425 }
00426 }
00427 else
00428 eval_data.r = 0;
00429 break;
00430 case EXPRINFO_LIN:
00431 case EXPRINFO_QUAD:
00432
00433 break;
00434 case EXPRINFO_IN:
00435 {
00436 __x = __data.coeffs[eval_data.n]*__rval;
00437 const interval& i(__data.params.i()[eval_data.n]);
00438 if(eval_data.r != -1 && i.contains(__x))
00439 {
00440 if(eval_data.r == 1 && (__x == i.inf() || __x == i.sup()))
00441 eval_data.r = 0;
00442 }
00443 else
00444 {
00445 eval_data.r = -1;
00446 ret = -1;
00447 }
00448 }
00449 eval_data.n++;
00450 break;
00451 case EXPRINFO_IF:
00452 __x = __rval * __data.coeffs[eval_data.n];
00453 if(eval_data.n == 0)
00454 {
00455 const interval& i(__data.params.ni());
00456 if(!i.contains(__x))
00457 ret = 1;
00458 }
00459 else
00460 {
00461 eval_data.r = __x;
00462 ret = -1;
00463 }
00464 eval_data.n += ret+1;
00465 break;
00466 case EXPRINFO_AND:
00467 { __x = __data.coeffs[eval_data.n]*__rval;
00468 const interval& i(__data.params.i()[eval_data.n]);
00469 if(eval_data.r == 1 && !i.contains(__x))
00470 {
00471 eval_data.r = 0;
00472 ret = -1;
00473 }
00474 }
00475 eval_data.n++;
00476 break;
00477 case EXPRINFO_OR:
00478 { __x = __data.coeffs[eval_data.n]*__rval;
00479 const interval& i(__data.params.i()[eval_data.n]);
00480 if(eval_data.r == 0 && i.contains(__x))
00481 {
00482 eval_data.r = 1;
00483 ret = -1;
00484 }
00485 }
00486 eval_data.n++;
00487 break;
00488 case EXPRINFO_NOT:
00489 { __x = __data.coeffs[0]*__rval;
00490 const interval& i(__data.params.ni());
00491 if(i.contains(__x))
00492 eval_data.r = 0;
00493 else
00494 eval_data.r = 1;
00495 }
00496 break;
00497 case EXPRINFO_IMPLIES:
00498 { const interval& i(__data.params.i()[eval_data.n]);
00499 __x = __rval * __data.coeffs[eval_data.n];
00500 if(eval_data.n == 0)
00501 {
00502 if(!i.contains(__x))
00503 {
00504 eval_data.r = 1;
00505 ret = -1;
00506 }
00507 }
00508 else
00509 eval_data.r = i.contains(__x) ? 1 : 0;
00510 ++eval_data.n;
00511 }
00512 break;
00513 case EXPRINFO_COUNT:
00514 { __x = __data.coeffs[eval_data.n]*__rval;
00515 const interval& i(__data.params.i()[eval_data.n]);
00516 if(i.contains(__x))
00517 eval_data.r += 1;
00518 }
00519 eval_data.n++;
00520 break;
00521 case EXPRINFO_ALLDIFF:
00522 { __x = __data.coeffs[eval_data.n]*__rval;
00523 for(std::vector<double>::const_iterator _b =
00524 ((std::vector<double>*)eval_data.u.p)->begin();
00525 _b != ((std::vector<double>*)eval_data.u.p)->end(); ++_b)
00526 {
00527 if(fabs(__x-*_b) <= __data.params.nd())
00528 {
00529 eval_data.r = 0;
00530 ret = -1;
00531 break;
00532 }
00533 }
00534 if(ret != -1)
00535 ((std::vector<double>*) eval_data.u.p)->push_back(__x);
00536 }
00537 eval_data.n++;
00538 if(eval_data.n == __data.n_children || ret == -1)
00539 delete (std::vector<double>*) eval_data.u.p;
00540 break;
00541 case EXPRINFO_HISTOGRAM:
00542 { __x = __data.coeffs[eval_data.n]*__rval;
00543 unsigned int ni = (__data.params.i().size()+1)>>1;
00544 for(unsigned int i=0; i<ni; i++)
00545 {
00546 if(__data.params.i()[i<<1].contains(__x))
00547 (*(std::vector<unsigned int>*)eval_data.u.p)[i]++;
00548 }
00549 eval_data.n++;
00550 if(eval_data.n == __data.n_children || ret == -1)
00551 {
00552 for(unsigned int i=0; i<ni; i++)
00553 {
00554 if(!__data.params.i()[i<<1].contains(
00555 (*(std::vector<unsigned int>*)eval_data.u.p)[i]+0.))
00556 {
00557 eval_data.r = 0;
00558 break;
00559 }
00560 }
00561 delete (std::vector<unsigned int>*) eval_data.u.p;
00562 }
00563 }
00564 break;
00565 case EXPRINFO_LEVEL:
00566 { int h = (int)eval_data.r;
00567 __x = __data.coeffs[eval_data.n]*__rval;
00568 interval _h;
00569
00570 if(h != INT_MAX)
00571 {
00572 while(h < __data.params.im().nrows())
00573 {
00574 _h = __data.params.im()[h][eval_data.n];
00575 if(_h.contains(__x))
00576 break;
00577 h++;
00578 }
00579 if(h == __data.params.im().nrows())
00580 {
00581 ret = -1;
00582 eval_data.r = INT_MAX;
00583 }
00584 else
00585 eval_data.r = h;
00586 }
00587 }
00588 eval_data.n++;
00589 break;
00590 case EXPRINFO_NEIGHBOR:
00591 if(eval_data.n == 0)
00592 eval_data.r = __data.coeffs[0]*__rval;
00593 else
00594 {
00595 double h = eval_data.r;
00596 eval_data.r = 0;
00597 __x = __data.coeffs[1]*__rval;
00598 for(unsigned int i = 0; i < __data.params.n().size(); i+=2)
00599 {
00600 if(h == __data.params.n()[i] && __x == __data.params.n()[i+1])
00601 {
00602 eval_data.r = 1;
00603 break;
00604 }
00605 }
00606 }
00607 eval_data.n++;
00608 break;
00609 case EXPRINFO_NOGOOD:
00610 {
00611 __x = __data.coeffs[eval_data.n]*__rval;
00612 if(eval_data.r == 0 || __data.params.n()[eval_data.n] != __x)
00613 {
00614 eval_data.r = 0;
00615 ret = -1;
00616 }
00617 }
00618 eval_data.n++;
00619 break;
00620 case EXPRINFO_EXPECTATION:
00621 throw nyi_exception("func_evaluator: EXPECTATION");
00622 break;
00623 case EXPRINFO_INTEGRAL:
00624 throw nyi_exception("func_evaluator: INTEGRAL");
00625 break;
00626 case EXPRINFO_LOOKUP:
00627 throw nyi_exception("func_evaluator: LOOKUP");
00628 case EXPRINFO_PWLIN:
00629 {
00630 int idx;
00631 __x = __data.coeffs[0]*__rval;
00632 for(idx=0; idx < __data.params.m().nrows(); ++idx)
00633 if(__x < __data.params.m()[idx][0]) break;
00634 if(idx == __data.params.m().nrows())
00635 eval_data.r = COCO_INF;
00636 else if(idx == 0)
00637 eval_data.r = -COCO_INF;
00638 else
00639 {
00640 double x0 = __data.params.m()[idx-1][0];
00641 double f0 = __data.params.m()[idx-1][1];
00642 eval_data.r = (__data.params.m()[idx][1]-f0)/(__data.params.m()[idx][0]-x0);
00643 eval_data.r *= __x-x0;
00644 eval_data.r += f0;
00645 }
00646 }
00647 break;
00648 case EXPRINFO_SPLINE:
00649 throw nyi_exception("func_evaluator: SPLINE");
00650 case EXPRINFO_PWCONSTLC:
00651 {
00652 int idx;
00653 __x = __data.coeffs[0]*__rval;
00654 for(idx=0; idx < __data.params.m().nrows(); ++idx)
00655 if(__x < __data.params.m()[idx][0]) break;
00656 if(idx == __data.params.m().nrows())
00657 eval_data.r = COCO_INF;
00658 else if(idx == 0)
00659 eval_data.r = -COCO_INF;
00660 else
00661 eval_data.r = __data.params.m()[idx-1][1];
00662 }
00663 break;
00664 case EXPRINFO_PWCONSTRC:
00665 {
00666 int idx;
00667 __x = __data.coeffs[0]*__rval;
00668 for(idx=0; idx < __data.params.m().nrows(); ++idx)
00669 if(__x <= __data.params.m()[idx][0]) break;
00670 if(idx == __data.params.m().nrows())
00671 eval_data.r = COCO_INF;
00672 else if(idx == 0)
00673 eval_data.r = -COCO_INF;
00674 else
00675 eval_data.r = __data.params.m()[idx-1][1];
00676 }
00677 break;
00678 case EXPRINFO_DET:
00679 case EXPRINFO_COND:
00680 case EXPRINFO_PSD:
00681 case EXPRINFO_MPROD:
00682 case EXPRINFO_FEM:
00683 throw nyi_exception("func_evaluator: Matrix Operations");
00684 break;
00685 case EXPRINFO_RE:
00686 case EXPRINFO_IM:
00687 case EXPRINFO_ARG:
00688 case EXPRINFO_CPLXCONJ:
00689 throw nyi_exception("func_evaluator: Complex Operations");
00690 break;
00691 case EXPRINFO_CMPROD:
00692 case EXPRINFO_CGFEM:
00693 throw nyi_exception("func_evaluator: Const Matrix Operations");
00694 default:
00695 throw api_exception(apiee_evaluator,
00696 std::string("func_evaluator: unknown function type ")+
00697 convert_to_str(__data.operator_type));
00698 break;
00699 }
00700 }
00701 else if(__data.operator_type > 0)
00702
00703 eval_data.r = __data.f_evaluate(eval_data.n++, __data.params.nn(),
00704 *eval_data.x, *v_ind, eval_data.r,
00705 __rval, NULL);
00706
00707 if(eval_data.cache )
00708 {
00709 (*eval_data.cache)[__data.node_num].f = eval_data.r;
00710 (*eval_data.cache)[__data.node_num].v = true;
00711 }
00712 return ret;
00713 }
00714
00715 double calculate_value(bool eval_all)
00716 {
00717 return eval_data.r;
00718 }
00720 };
00721
00722 }
00723
00724 #endif