00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00029 #ifndef _DFUNC_EVALUATOR_H_
00030 #define _DFUNC_EVALUATOR_H_
00031
00032 #include <coconut_config.h>
00033 #include <evaluator.h>
00034 #include <expression.h>
00035 #include <model.h>
00036 #include <eval_main.h>
00037 #include <linalg.h>
00038 #include <math.h>
00039 #include <api_exception.h>
00040
00041 using namespace vgtl;
00042
00043 namespace coco {
00044
00045 typedef double (*func_evaluator)(const std::vector<double>* __x,
00046 const variable_indicator& __v);
00047
/// Per-node result record of dfunc_eval: a fixed array of DN+1 entries.
/// dfunc_eval stores the node value in f[0] and the higher-order
/// components it propagates in f[1] .. f[DN].
template <class Scalar, int DN>
struct dfunc_eval_rettype
{
  Scalar f[DN + 1];   // f[0]: value, f[1..DN]: propagated components
};
00053
00054 template <typename _T, int DN>
00055 struct dfunc_eval_type
00056 {
00057 unsigned int _0chnbase;
00058 _T x;
00059 std::vector<dfunc_eval_rettype<_T, DN> >* cache;
00060 const model* mod;
00061 void *p;
00062 _T d;
00063 dfunc_eval_rettype<_T, DN> r;
00064 unsigned int n;
00065 };
00066
00067 template <typename _T, int DN>
00068 class dfunc_eval :
00069 public cached_forward_evaluator_base<dfunc_eval_type<_T, DN>,expression_node,
00070 dfunc_eval_rettype<_T, DN>,
00071 expression_const_walker>
00072 {
00073 private:
00074 typedef dfunc_eval_rettype<_T, DN> retval;
00075 typedef cached_forward_evaluator_base<dfunc_eval_type<_T, DN>,expression_node,
00076 retval, expression_const_walker> _Base;
00077 typedef typename _Base::node_data_type node_data_type;
00078
00079 protected:
00080 bool is_cached(const node_data_type& __data)
00081 {
00082 if(__data.operator_type == EXPRINFO_LIN ||
00083 __data.operator_type == EXPRINFO_QUAD)
00084 return true;
00085 if(eval_data.cache && __data.n_parents > 1 && __data.n_children > 0 &&
00086 v_ind->match(__data.var_indicator()))
00087 return true;
00088 else
00089 return false;
00090 }
00091
00092 private:
00093 retval __power(double __coeff, retval __x, int __exp)
00094 {
00095 retval x;
00096 if(__exp == 0)
00097 {
00098 x.f[0] = _T(1.);
00099 for(int i=0; i <= DN; ++i)
00100 x.f[i] = _T(0.);
00101 }
00102 else
00103 {
00104 double k = __coeff*__x;
00105 switch(__exp)
00106 {
00107 case 1:
00108 return k;
00109 break;
00110 case 2:
00111 return k*k;
00112 break;
00113 case -1:
00114 return 1./k;
00115 break;
00116 case -2:
00117 return 1./(k*k);
00118 break;
00119 default:
00120 if(__exp & 1)
00121 {
00122 if(k < 0)
00123 return -std::pow(-k, __exp);
00124 else
00125 return std::pow(k, __exp);
00126 }
00127 else
00128 return std::pow(fabs(k), __exp);
00129 break;
00130 }
00131 }
00132 }
00133
00134 public:
00135 dfunc_eval(const _T& __x, const variable_indicator& __v,
00136 const model& __m, std::vector<retval>* __c) : _Base()
00137 {
00138 eval_data.x = __x;
00139 eval_data.cache = __c;
00140 eval_data.mod = &__m;
00141 eval_data.d = 0.;
00142 v_ind = &__v;
00143 eval_data.n = 0;
00144 }
00145
00146 dfunc_eval(const dfunc_eval& __v) : _Base(__v) {}
00147
00148 ~dfunc_eval() {}
00149
00150 expression_const_walker short_cut_to(const expression_node& __data)
00151 { return eval_data.mod->node(0); }
00152
00153
00154 void new_point(const _T& __x, const variable_indicator& __v)
00155 {
00156 eval_data.x = __x;
00157 v_ind = &__v;
00158 }
00159
00160 void initialize() { return; }
00161
00162 int initialize(const expression_node& __data)
00163 {
00164 int ret = 1;
00165 eval_data.n = 0;
00166 #if 0
00167 if(__data.ev != NULL && (*__data.ev)[DFUNC_EVALUATOR] != NULL)
00168
00169 {
00170 eval_data.r = (*(func_evaluator)(*__data.ev)[DFUNC_EVALUATOR])(eval_data.x,
00171 *v_ind);
00172 return 0;
00173 }
00174 #endif
00175 if(__data.node_num == eval_data._0chnbase)
00176 {
00177 eval_data.r.f[0] = eval_data.x;
00178 ret = 0;
00179 }
00180 else
00181 {
00182 switch(__data.operator_type)
00183 {
00184 case EXPRINFO_VARIABLE:
00185
00186 break;
00187 case EXPRINFO_SUM:
00188 case EXPRINFO_PROD:
00189 case EXPRINFO_MAX:
00190 case EXPRINFO_MIN:
00191 case EXPRINFO_INVERT:
00192 case EXPRINFO_DIV:
00193 case EXPRINFO_CONSTANT:
00194 eval_data.r.f[0] = __data.params.nd();
00195 break;
00196 case EXPRINFO_MONOME:
00197 case EXPRINFO_IN:
00198 case EXPRINFO_AND:
00199 case EXPRINFO_NOGOOD:
00200 eval_data.r.f[0] = 1.;
00201 break;
00202 case EXPRINFO_HISTOGRAM:
00203 eval_data.p =
00204 (void*)new std::vector<unsigned int>(0,
00205 (__data.params.i().size()+1)>>1);
00206 eval_data.r.f[0] = 1.;
00207 break;
00208 case EXPRINFO_ALLDIFF:
00209 eval_data.p = (void*) new std::vector<double>;
00210 ((std::vector<double>*)eval_data.p)->reserve(__data.n_children);
00211
00212 case EXPRINFO_MEAN:
00213 case EXPRINFO_IF:
00214 case EXPRINFO_OR:
00215 case EXPRINFO_NOT:
00216 case EXPRINFO_COUNT:
00217 case EXPRINFO_SCPROD:
00218 case EXPRINFO_NORM:
00219 case EXPRINFO_LEVEL:
00220 eval_data.r.f[0] = 0.;
00221 break;
00222 case EXPRINFO_DET:
00223 case EXPRINFO_PSD:
00224
00225 break;
00226 case EXPRINFO_COND:
00227 case EXPRINFO_FEM:
00228 case EXPRINFO_MPROD:
00229
00230 break;
00231 }
00232 for(int i=1; i <= DN; ++i)
00233 eval_data.r.f[i] = _T(0.);
00234 }
00235 return ret;
00236 }
00237
00238 void calculate(const expression_node& __data)
00239 {
00240 #if 0
00241 if(__data.operator_type > 0)
00242 {
00243 eval_data.r = __data.f_evaluate(-1, __data.params.nn(), *eval_data.x,
00244 *v_ind, eval_data.r, 0, NULL);
00245 }
00246 #endif
00247 }
00248
00249 void retrieve_from_cache(const expression_node& __data)
00250 {
00251 eval_data.r = (*eval_data.cache)[__data.node_num];
00252 }
00253
00254 int update(const retval& __rval)
00255 {
00256 eval_data.r = __rval;
00257 return 0;
00258 }
00259
00260 int update(const expression_node& __data, const retval& __rval)
00261 {
00262 double __x;
00263 int ret = 0;
00264 if(__data.operator_type < 0)
00265 {
00266 switch(__data.operator_type)
00267 {
00268 case EXPRINFO_CONSTANT:
00269 case EXPRINFO_VARIABLE:
00270 break;
00271 case EXPRINFO_SUM:
00272 case EXPRINFO_MEAN:
00273 for(int i=0; i <= DN; ++i)
00274 eval_data.r.f[i] += __data.coeffs[eval_data.n]*__rval.f[i];
00275 eval_data.n++;
00276 break;
00277 case EXPRINFO_PROD:
00278
00279 {
00280 int _c[DN+1];
00281 int i, j, k, _co, _cn;
00282 memset(_c+1, 0, sizeof(int)*(DN+1));
00283 retval f = eval_data.r;
00284 for(i=0; i <= DN; _c[++i]=1, _co = 1)
00285 {
00286 eval_data.r.f[i] *= __rval.f[0];
00287 for(j=1, k=i-1; j <= i; ++j, --k, _co = _cn)
00288 {
00289 _cn = _c[j];
00290 eval_data.r.f[i] += _cn*__rval.f[j]*f[k];
00291 _c[j]+=_co;
00292 }
00293 }
00294 }
00295 break;
00296 case EXPRINFO_MONOME:
00297 eval_data.r *= __power(__data.coeffs[eval_data.n],
00298 __rval, __data.params.n()[eval_data.n]);
00299 if(eval_data.r.f[0] == _T(0.)) ret=-1;
00300 eval_data.n++;
00301 break;
00302 case EXPRINFO_MAX:
00303 __x = __rval * __data.coeffs[eval_data.n++];
00304 if(__x > eval_data.r)
00305 eval_data.r = __x;
00306 break;
00307 case EXPRINFO_MIN:
00308 __x = __rval * __data.coeffs[eval_data.n++];
00309 if(__x < eval_data.r)
00310 eval_data.r = __x;
00311 break;
00312 case EXPRINFO_SCPROD:
00313 { double h = __data.coeffs[eval_data.n]*__rval;
00314
00315
00316 if(eval_data.n & 1)
00317 eval_data.r += eval_data.d*h;
00318 else
00319 eval_data.d = h;
00320 }
00321 eval_data.n++;
00322 break;
00323 case EXPRINFO_NORM:
00324 if(__data.params.nd() == COCO_INF)
00325 {
00326 __x = fabs(__rval*__data.coeffs[eval_data.n]);
00327 if(__x > eval_data.r)
00328 eval_data.r = __x;
00329 }
00330 else
00331 { __x = fabs(__data.coeffs[eval_data.n]*__rval);
00332 eval_data.r += std::pow(__x, __data.params.nd());
00333 }
00334 eval_data.n++;
00335 if(eval_data.n == __data.n_children && __data.params.nd() != COCO_INF)
00336 eval_data.r = std::pow(eval_data.r, 1./__data.params.nd());
00337 break;
00338 case EXPRINFO_INVERT:
00339 eval_data.r /= __rval;
00340 break;
00341 case EXPRINFO_DIV:
00342
00343
00344 if(eval_data.n == 0)
00345 {
00346 eval_data.r *= __rval;
00347 if(eval_data.r == 0) ret=-1;
00348 }
00349 else
00350 eval_data.r /= __rval;
00351 ++eval_data.n;
00352 break;
00353 case EXPRINFO_SQUARE:
00354 __x = __data.coeffs[0]*__rval+__data.params.nd();
00355 eval_data.r = __x*__x;
00356 break;
00357 case EXPRINFO_INTPOWER:
00358 eval_data.r = __power(__data.coeffs[0], __rval, __data.params.nn());
00359 break;
00360 case EXPRINFO_SQROOT:
00361 eval_data.r = std::sqrt(__data.coeffs[0]*__rval+__data.params.nd());
00362 break;
00363 case EXPRINFO_ABS:
00364 eval_data.r = fabs(__data.coeffs[0]*__rval+__data.params.nd());
00365 break;
00366 case EXPRINFO_POW:
00367 __x = __rval * __data.coeffs[eval_data.n];
00368 if(eval_data.n == 0)
00369 eval_data.r = __x+__data.params.nd();
00370 else
00371 eval_data.r = std::pow(eval_data.r, __x);
00372 ++eval_data.n;
00373 break;
00374 case EXPRINFO_EXP:
00375 eval_data.r = std::exp(__rval*__data.coeffs[0]+__data.params.nd());
00376 break;
00377 case EXPRINFO_LOG:
00378 eval_data.r = std::log(__rval*__data.coeffs[0]+__data.params.nd());
00379 break;
00380 case EXPRINFO_SIN:
00381 eval_data.r = std::sin(__rval*__data.coeffs[0]+__data.params.nd());
00382 break;
00383 case EXPRINFO_COS:
00384 eval_data.r = std::cos(__rval*__data.coeffs[0]+__data.params.nd());
00385 break;
00386 case EXPRINFO_ATAN2:
00387 __x = __rval * __data.coeffs[eval_data.n];
00388 if(eval_data.n == 0)
00389 eval_data.r = __x;
00390 else
00391 eval_data.r = std::atan2(eval_data.r, __x);
00392 ++eval_data.n;
00393 break;
00394 case EXPRINFO_GAUSS:
00395 __x = (__data.coeffs[0]*__rval-__data.params.d()[0])/
00396 __data.params.d()[1];
00397 eval_data.r = std::exp(-__x*__x);
00398 break;
00399 case EXPRINFO_POLY:
00400 if(!__data.params.empty())
00401 {
00402 __x = __data.coeffs[eval_data.n]*__rval;
00403 eval_data.r = __data.params.d().back();
00404 for(int i = (int)__data.params.d().size()-2; i >= 0; --i)
00405 {
00406 eval_data.r *= __x;
00407 eval_data.r += __data.params.d()[i];
00408 }
00409 }
00410 else
00411 eval_data.r = 0;
00412 break;
00413 case EXPRINFO_LIN:
00414 case EXPRINFO_QUAD:
00415
00416 break;
00417 case EXPRINFO_IN:
00418 {
00419 __x = __data.coeffs[eval_data.n]*__rval;
00420 const interval& i(__data.params.i()[eval_data.n]);
00421 if(eval_data.r != -1 && i.contains(__x))
00422 {
00423 if(eval_data.r == 1 && (__x == i.inf() || __x == i.sup()))
00424 eval_data.r = 0;
00425 }
00426 else
00427 {
00428 eval_data.r = -1;
00429 ret = -1;
00430 }
00431 }
00432 eval_data.n++;
00433 break;
00434 case EXPRINFO_IF:
00435 __x = __rval * __data.coeffs[eval_data.n];
00436 if(eval_data.n == 0)
00437 {
00438 const interval& i(__data.params.ni());
00439 if(!i.contains(__x))
00440 ret = 1;
00441 }
00442 else
00443 {
00444 eval_data.r = __x;
00445 ret = -1;
00446 }
00447 eval_data.n += ret+1;
00448 break;
00449 case EXPRINFO_AND:
00450 { __x = __data.coeffs[eval_data.n]*__rval;
00451 const interval& i(__data.params.i()[eval_data.n]);
00452 if(eval_data.r == 1 && !i.contains(__x))
00453 {
00454 eval_data.r = 0;
00455 ret = -1;
00456 }
00457 }
00458 eval_data.n++;
00459 break;
00460 case EXPRINFO_OR:
00461 { __x = __data.coeffs[eval_data.n]*__rval;
00462 const interval& i(__data.params.i()[eval_data.n]);
00463 if(eval_data.r == 0 && i.contains(__x))
00464 {
00465 eval_data.r = 1;
00466 ret = -1;
00467 }
00468 }
00469 eval_data.n++;
00470 break;
00471 case EXPRINFO_NOT:
00472 { __x = __data.coeffs[0]*__rval;
00473 const interval& i(__data.params.ni());
00474 if(i.contains(__x))
00475 eval_data.r = 0;
00476 else
00477 eval_data.r = 1;
00478 }
00479 break;
00480 case EXPRINFO_IMPLIES:
00481 { const interval& i(__data.params.i()[eval_data.n]);
00482 __x = __rval * __data.coeffs[eval_data.n];
00483 if(eval_data.n == 0)
00484 {
00485 if(!i.contains(__x))
00486 {
00487 eval_data.r = 1;
00488 ret = -1;
00489 }
00490 }
00491 else
00492 eval_data.r = i.contains(__x) ? 1 : 0;
00493 ++eval_data.n;
00494 }
00495 break;
00496 case EXPRINFO_COUNT:
00497 { __x = __data.coeffs[eval_data.n]*__rval;
00498 const interval& i(__data.params.i()[eval_data.n]);
00499 if(i.contains(__x))
00500 eval_data.r += 1;
00501 }
00502 eval_data.n++;
00503 break;
00504 case EXPRINFO_ALLDIFF:
00505 { __x = __data.coeffs[eval_data.n]*__rval;
00506 for(std::vector<double>::const_iterator _b =
00507 ((std::vector<double>*)eval_data.p)->begin();
00508 _b != ((std::vector<double>*)eval_data.p)->end(); ++_b)
00509 {
00510 if(fabs(__x-*_b) <= __data.params.nd())
00511 {
00512 eval_data.r = 0;
00513 ret = -1;
00514 break;
00515 }
00516 }
00517 if(ret != -1)
00518 ((std::vector<double>*) eval_data.p)->push_back(__x);
00519 }
00520 eval_data.n++;
00521 if(eval_data.n == __data.n_children || ret == -1)
00522 delete (std::vector<double>*) eval_data.p;
00523 break;
00524 case EXPRINFO_HISTOGRAM:
00525 { __x = __data.coeffs[eval_data.n]*__rval;
00526 unsigned int ni = (__data.params.i().size()+1)>>1;
00527 for(unsigned int i=0; i<ni; i++)
00528 {
00529 if(__data.params.i()[i<<1].contains(__x))
00530 (*(std::vector<unsigned int>*)eval_data.p)[i]++;
00531 }
00532 eval_data.n++;
00533 if(eval_data.n == __data.n_children || ret == -1)
00534 {
00535 for(unsigned int i=0; i<ni; i++)
00536 {
00537 if(!__data.params.i()[i<<1].contains(
00538 (*(std::vector<unsigned int>*)eval_data.p)[i]+0.))
00539 {
00540 eval_data.r = 0;
00541 break;
00542 }
00543 }
00544 delete (std::vector<unsigned int>*) eval_data.p;
00545 }
00546 }
00547 break;
00548 case EXPRINFO_LEVEL:
00549 { int h = (int)eval_data.r;
00550 __x = __data.coeffs[eval_data.n]*__rval;
00551 interval _h;
00552
00553 if(h != INT_MAX)
00554 {
00555 while(h < __data.params.im().nrows())
00556 {
00557 _h = __data.params.im()[h][eval_data.n];
00558 if(_h.contains(__x))
00559 break;
00560 h++;
00561 }
00562 if(h == __data.params.im().nrows())
00563 {
00564 ret = -1;
00565 eval_data.r = INT_MAX;
00566 }
00567 else
00568 eval_data.r = h;
00569 }
00570 }
00571 eval_data.n++;
00572 break;
00573 case EXPRINFO_NEIGHBOR:
00574 if(eval_data.n == 0)
00575 eval_data.r = __data.coeffs[0]*__rval;
00576 else
00577 {
00578 double h = eval_data.r;
00579 eval_data.r = 0;
00580 __x = __data.coeffs[1]*__rval;
00581 for(unsigned int i = 0; i < __data.params.n().size(); i+=2)
00582 {
00583 if(h == __data.params.n()[i] && __x == __data.params.n()[i+1])
00584 {
00585 eval_data.r = 1;
00586 break;
00587 }
00588 }
00589 }
00590 eval_data.n++;
00591 break;
00592 case EXPRINFO_NOGOOD:
00593 {
00594 __x = __data.coeffs[eval_data.n]*__rval;
00595 if(eval_data.r == 0 || __data.params.n()[eval_data.n] != __x)
00596 {
00597 eval_data.r = 0;
00598 ret = -1;
00599 }
00600 }
00601 eval_data.n++;
00602 break;
00603 case EXPRINFO_EXPECTATION:
00604 throw nyi_exception("dfunc_evaluator: EXPECTATION");
00605 break;
00606 case EXPRINFO_INTEGRAL:
00607 throw nyi_exception("dfunc_evaluator: INTEGRAL");
00608 break;
00609 case EXPRINFO_LOOKUP:
00610 case EXPRINFO_PWLIN:
00611 case EXPRINFO_SPLINE:
00612 case EXPRINFO_PWCONSTLC:
00613 case EXPRINFO_PWCONSTRC:
00614 throw nyi_exception("dfunc_evaluator: Table Operations");
00615 case EXPRINFO_DET:
00616 case EXPRINFO_COND:
00617 case EXPRINFO_PSD:
00618 case EXPRINFO_MPROD:
00619 case EXPRINFO_FEM:
00620 throw nyi_exception("dfunc_evaluator: Matrix Operations");
00621 break;
00622 case EXPRINFO_RE:
00623 case EXPRINFO_IM:
00624 case EXPRINFO_ARG:
00625 case EXPRINFO_CPLXCONJ:
00626 throw nyi_exception("dfunc_evaluator: Complex Operations");
00627 break;
00628 case EXPRINFO_CMPROD:
00629 case EXPRINFO_CGFEM:
00630 throw nyi_exception("dfunc_evaluator: Const Matrix Operations");
00631 default:
00632 throw api_exception(apiee_evaluator,
00633 std::string("dfunc_evaluator: unknown function type ")+
00634 convert_to_str(__data.operator_type));
00635 break;
00636 }
00637 }
00638 else if(__data.operator_type > 0)
00639
00640 eval_data.r = __data.f_evaluate(eval_data.n++, __data.params.nn(),
00641 *eval_data.x, *v_ind, eval_data.r,
00642 __rval, NULL);
00643
00644 if(eval_data.cache && __data.n_parents > 1 && __data.n_children > 0)
00645 (*eval_data.cache)[__data.node_num] = eval_data.r;
00646 return ret;
00647 }
00648
00649 retval calculate_value(bool eval_all)
00650 {
00651 return eval_data.r;
00652 }
00653 };
00654
00655 }
00656
00657 #endif