00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00028 #ifndef _DER_EVALUATOR_H_
00029 #define _DER_EVALUATOR_H_
00030
00031 #include <coconut_config.h>
00032 #include <evaluator.h>
00033 #include <expression.h>
00034 #include <model.h>
00035 #include <eval_main.h>
00036 #include <linalg.h>
00037 #include <math.h>
00038 #include <api_exception.h>
00039
00040 using namespace vgtl;
00041
00042 namespace coco {
00043
00045
00047 typedef bool (*prep_d_evaluator)();
00048 typedef double (*func_d_evaluator)(const std::vector<double>* __x,
00049 const variable_indicator& __v,
00050 std::vector<double>& __d_data);
00051 typedef std::vector<double>& (*der_evaluator)(const std::vector<double>& __d_dat,
00052 const variable_indicator& __v);
00054
00056
00062 class prep_d_eval : public
00063 cached_forward_evaluator_base<std::vector<std::vector<double> >*, expression_node, bool, expression_const_walker>
00064 {
00065 private:
00067 typedef cached_forward_evaluator_base<std::vector<std::vector<double> >*,
00068 expression_node,bool,expression_const_walker> _Base;
00069
00070 public:
00074 prep_d_eval(std::vector<std::vector<double> >& __d,
00075 unsigned int _num_of_nodes)
00076 {
00077 eval_data = &__d;
00078 if((*eval_data).size() < _num_of_nodes)
00079 (*eval_data).insert((*eval_data).end(),
00080 _num_of_nodes-(*eval_data).size(), std::vector<double>());
00081 }
00082
00084 prep_d_eval(const prep_d_eval& __x) { eval_data = __x.eval_data; }
00085
00087 ~prep_d_eval() {}
00088
00090 bool is_cached(const expression_node& __data)
00091 {
00092 return (*eval_data)[__data.node_num].size() > 0;
00093 }
00094
00096 int initialize(const expression_node& __data)
00097 {
00098 (*eval_data)[__data.node_num].insert((*eval_data)[__data.node_num].end(),
00099 __data.n_children, 0);
00100 return 1;
00101 }
00102
00104
00106 void retrieve_from_cache(const expression_node& __data) { return; }
00107
00108 void initialize() { return; }
00109
00110 void calculate(const expression_node& __data) { return; }
00111
00112 int update(bool __rval) { return 0; }
00113
00114 int update(const expression_node& __data, bool __rval)
00115 { return 0; }
00116
00117 bool calculate_value(bool eval_all) { return true; }
00119 };
00120
00121
00122
00123
00125
00129 struct func_d_eval_type
00130 {
00131 const std::vector<double>* x;
00132 std::vector<double>* f_cache;
00133 std::vector<std::vector<double> >* d_data;
00134 const model* mod;
00135 union { void* p; double d; } u;
00136 double r;
00137 unsigned int n,
00138 info;
00139 };
00140
00142
00148 class func_d_eval : public
00149 cached_forward_evaluator_base<func_d_eval_type, expression_node,
00150 double, expression_const_walker>
00151 {
00152 private:
00154 typedef cached_forward_evaluator_base<func_d_eval_type,expression_node,
00155 double, expression_const_walker> _Base;
00156
00157 protected:
00160 bool is_cached(const node_data_type& __data)
00161 {
00162 if(__data.operator_type == EXPRINFO_LIN ||
00163 __data.operator_type == EXPRINFO_QUAD)
00164 return true;
00165 if(eval_data.f_cache && __data.n_parents > 1 && __data.n_children > 0 &&
00166 v_ind->match(__data.var_indicator()))
00167 return true;
00168 else
00169 return false;
00170 }
00171
00172 private:
00175 double __power(double __coeff, double __x, int __exp)
00176 {
00177 if(__exp == 0)
00178 return 1.;
00179 else
00180 {
00181 double k = __coeff*__x;
00182 switch(__exp)
00183 {
00184 case 1:
00185 return k;
00186 break;
00187 case 2:
00188 return k*k;
00189 break;
00190 case -1:
00191 return 1./k;
00192 break;
00193 case -2:
00194 return 1./(k*k);
00195 break;
00196 default:
00197 if(__exp & 1)
00198 {
00199 if(k < 0)
00200 return -std::pow(-k, __exp);
00201 else
00202 return std::pow(k, __exp);
00203 }
00204 else
00205 return std::pow(fabs(k), __exp);
00206 break;
00207 }
00208 }
00209
00210 return 0.;
00211 }
00212
00215 void __calc_max(double h, const expression_node& __data)
00216 {
00217 if(h >= eval_data.r)
00218 {
00219 (*eval_data.d_data)[__data.node_num][eval_data.info] = 0;
00220 if(h > eval_data.r)
00221 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00222 __data.coeffs[eval_data.n];
00223 else
00224
00225 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00226 eval_data.r = h;
00227 eval_data.info = eval_data.n;
00228 }
00229 else
00230 {
00231 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00232 }
00233 }
00234
00235 public:
00241 func_d_eval(const std::vector<double>& __x, const variable_indicator& __v,
00242 const model& __m, std::vector<std::vector<double> >& __d,
00243 std::vector<double>* __c) : _Base()
00244 {
00245 eval_data.x = &__x;
00246 eval_data.f_cache = __c;
00247 eval_data.mod = &__m;
00248 eval_data.d_data = &__d;
00249 eval_data.n = 0;
00250 eval_data.r = 0;
00251 eval_data.u.d = 0;
00252 v_ind = &__v;
00253 }
00254
00256 func_d_eval(const func_d_eval& __x) : _Base(__x) {}
00257
00259 ~func_d_eval() {}
00260
00262 expression_const_walker short_cut_to(const expression_node& __data)
00263 { return eval_data.mod->node(0); }
00264
00268 void new_point(const std::vector<double>& __x, const variable_indicator& __v)
00269 {
00270 eval_data.x = &__x;
00271 v_ind = &__v;
00272 }
00273
00275
00276 void initialize() { return; }
00277
00278 int initialize(const expression_node& __data)
00279 {
00280 eval_data.n = 0;
00281 if(__data.ev != NULL && (*__data.ev)[FUNC_D_EVALUATOR] != NULL)
00282
00283 {
00284 eval_data.r =
00285 (*(func_d_evaluator)(*__data.ev)[FUNC_D_EVALUATOR])(eval_data.x,
00286 *v_ind, (*eval_data.d_data)[__data.node_num]);
00287 return 0;
00288 }
00289 else
00290 {
00291 switch(__data.operator_type)
00292 {
00293 case EXPRINFO_MAX:
00294 case EXPRINFO_MIN:
00295 eval_data.info = 0;
00296
00297 case EXPRINFO_SUM:
00298 case EXPRINFO_PROD:
00299 case EXPRINFO_INVERT:
00300 eval_data.r = __data.params.nd();
00301 break;
00302 case EXPRINFO_IN:
00303 case EXPRINFO_AND:
00304 case EXPRINFO_NOGOOD:
00305 eval_data.r = 1.;
00306 break;
00307 case EXPRINFO_ALLDIFF:
00308 eval_data.u.p = (void*) new std::vector<double>;
00309 ((std::vector<double>*)eval_data.u.p)->reserve(__data.n_children);
00310
00311 case EXPRINFO_MEAN:
00312 case EXPRINFO_IF:
00313 case EXPRINFO_OR:
00314 case EXPRINFO_NOT:
00315 case EXPRINFO_COUNT:
00316 case EXPRINFO_SCPROD:
00317 case EXPRINFO_LEVEL:
00318 eval_data.r = 0.;
00319 break;
00320 case EXPRINFO_NORM:
00321 eval_data.info = 0;
00322 eval_data.r = 0.;
00323 break;
00324 case EXPRINFO_DET:
00325 case EXPRINFO_PSD:
00326
00327 break;
00328 case EXPRINFO_COND:
00329 case EXPRINFO_FEM:
00330 case EXPRINFO_MPROD:
00331
00332 break;
00333 }
00334 return 1;
00335 }
00336 }
00337
00338 void calculate(const expression_node& __data)
00339 {
00340 if(__data.operator_type > 0)
00341 {
00342 eval_data.r = __data.f_evaluate(-1, __data.params.nn(), *eval_data.x,
00343 *v_ind, eval_data.r, 0,
00344 &((*eval_data.d_data)[__data.node_num]));
00345 }
00346 }
00347
00348 void retrieve_from_cache(const expression_node& __data)
00349 {
00350
00351 if(__data.operator_type == EXPRINFO_LIN)
00352 eval_data.r = linalg::linalg_dot(eval_data.mod->lin[__data.params.nn()],
00353 *eval_data.x,0.);
00354 else if(__data.operator_type == EXPRINFO_QUAD)
00355 {
00356 std::vector<double> irslt = *eval_data.x;
00357 unsigned int r = __data.params.m().nrows();
00358
00359
00360 irslt.push_back(0);
00361 linalg::linalg_matvec(__data.params.m(), irslt, irslt);
00362 irslt.pop_back();
00363
00364 eval_data.r = linalg::linalg_dot(__data.params.m()[r-1], *eval_data.x, 0.);
00365
00366 eval_data.r += linalg::linalg_dot(irslt,*eval_data.x,0.);
00367
00368 linalg::linalg_add(linalg_scale(__data.params.m()[r-1], 0.5), irslt,
00369 (*eval_data.d_data)[__data.node_num]);
00370 }
00371 else
00372 eval_data.r = (*eval_data.f_cache)[__data.node_num];
00373 }
00374
00375 int update(const double& __rval)
00376 {
00377 eval_data.r = __rval;
00378 return 0;
00379 }
00380
00381 int update(const expression_node& __data, const double& __rval)
00382 {
00383 int ret = 0;
00384 double __x;
00385 if(__data.operator_type < 0)
00386 {
00387 switch(__data.operator_type)
00388 {
00389 case EXPRINFO_CONSTANT:
00390 eval_data.r = __data.params.nd();
00391
00392 break;
00393 case EXPRINFO_VARIABLE:
00394 eval_data.r = (*eval_data.x)[__data.params.nn()];
00395
00396 break;
00397 case EXPRINFO_SUM:
00398 case EXPRINFO_MEAN:
00399 { double h = __data.coeffs[eval_data.n];
00400 eval_data.r += h*__rval;
00401 (*eval_data.d_data)[__data.node_num][eval_data.n++] = h;
00402 }
00403 break;
00404 case EXPRINFO_PROD:
00405 if(eval_data.n == 0)
00406 {
00407 eval_data.r *= __rval;
00408 (*eval_data.d_data)[__data.node_num][0] = __data.params.nd();
00409 }
00410 else
00411 {
00412 (*eval_data.d_data)[__data.node_num][eval_data.n] = eval_data.r;
00413 eval_data.r *= __rval;
00414 for(int i = eval_data.n-1; i >= 0; i--)
00415 (*eval_data.d_data)[__data.node_num][i] *= __rval;
00416 }
00417 ++eval_data.n;
00418 break;
00419 case EXPRINFO_MONOME:
00420 if(eval_data.n == 0)
00421 {
00422 int n = __data.params.n()[0];
00423 if(n != 0)
00424 {
00425 __x = __power(__data.coeffs[0], __rval, n-1)*__data.coeffs[0];
00426 eval_data.r = __x*__rval;
00427 (*eval_data.d_data)[__data.node_num][0] = n*__x;
00428 }
00429 else
00430 {
00431 (*eval_data.d_data)[__data.node_num][0] = 0;
00432 eval_data.r = 1.;
00433 }
00434 }
00435 else
00436 {
00437 int n = __data.params.n()[eval_data.n];
00438 if(n != 0)
00439 {
00440 __x = __power(__data.coeffs[eval_data.n], __rval, n-1)*
00441 __data.coeffs[eval_data.n];
00442 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00443 eval_data.r*n*__x;
00444 __x *= __rval;
00445 eval_data.r *= __x;
00446 for(int i = eval_data.n-1; i >= 0; i--)
00447 (*eval_data.d_data)[__data.node_num][i] *= __x;
00448 }
00449 else
00450 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00451 }
00452 ++eval_data.n;
00453 break;
00454 case EXPRINFO_MAX:
00455 __calc_max(__rval * __data.coeffs[eval_data.n], __data);
00456 ++eval_data.n;
00457 break;
00458 case EXPRINFO_MIN:
00459 { double h = __rval * __data.coeffs[eval_data.n];
00460 if(h <= eval_data.r)
00461 {
00462 (*eval_data.d_data)[__data.node_num][eval_data.info] = 0;
00463 if(h < eval_data.r)
00464 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00465 __data.coeffs[eval_data.n];
00466 else
00467
00468 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00469 eval_data.r = h;
00470 eval_data.info = eval_data.n;
00471 }
00472 else
00473 {
00474 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00475 }
00476 }
00477 ++eval_data.n;
00478 break;
00479 case EXPRINFO_SCPROD:
00480 { double h = __data.coeffs[eval_data.n]*__rval;
00481
00482
00483 if(eval_data.n & 1)
00484 {
00485 eval_data.r += eval_data.u.d*h;
00486 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00487 eval_data.u.d*__data.coeffs[eval_data.n-1];
00488 (*eval_data.d_data)[__data.node_num][eval_data.n-1] =
00489 h*__data.coeffs[eval_data.n];
00490 }
00491 else
00492 eval_data.u.d = h;
00493 }
00494 eval_data.n++;
00495 break;
00496 case EXPRINFO_NORM:
00497 if(__data.params.nd() == COCO_INF)
00498 __calc_max(fabs(__rval * __data.coeffs[eval_data.n]), __data);
00499 else
00500 {
00501 double h = __data.coeffs[eval_data.n]*fabs(__rval);
00502 double O = std::pow(h, __data.params.nd()-1);
00503 eval_data.r += O*h;
00504 (*eval_data.d_data)[__data.node_num][eval_data.n] = O;
00505 }
00506 eval_data.n++;
00507 if(eval_data.n == __data.n_children &&
00508 __data.params.nd() != COCO_INF)
00509 {
00510 double h = std::pow(eval_data.r,1./(__data.params.nd())-1.);
00511 for(unsigned int i = 0; i < eval_data.n; ++i)
00512 (*eval_data.d_data)[__data.node_num][eval_data.n] *= h;
00513 eval_data.r = std::pow(eval_data.r,1./(__data.params.nd()));
00514 }
00515 break;
00516 case EXPRINFO_INVERT:
00517 { double h = 1/__rval;
00518 eval_data.r *= h;
00519 (*eval_data.d_data)[__data.node_num][0] = -__data.params.nd()*h*h;
00520 }
00521 break;
00522 case EXPRINFO_DIV:
00523 if(eval_data.n++ == 0)
00524 eval_data.r = __rval;
00525 else
00526 {
00527 double h = 1/__rval;
00528 eval_data.r *=
00529 (*eval_data.d_data)[__data.node_num][0] = __data.params.nd()*h;
00530 (*eval_data.d_data)[__data.node_num][1] = -eval_data.r*h;
00531 }
00532 break;
00533 case EXPRINFO_SQUARE:
00534 { double h = __data.coeffs[0]*__rval+__data.params.nd();
00535 eval_data.r = h*h;
00536 (*eval_data.d_data)[__data.node_num][0] = 2*h*__data.coeffs[0];
00537 }
00538 break;
00539 case EXPRINFO_INTPOWER:
00540 { int hl = __data.params.nn();
00541 if(hl == 0)
00542 {
00543 eval_data.r = 1;
00544 (*eval_data.d_data)[__data.node_num][0] = 0;
00545 }
00546 else
00547 {
00548 double kl = __data.coeffs[0]*__rval;
00549 switch(hl)
00550 {
00551 case 1:
00552 eval_data.r = kl;
00553 (*eval_data.d_data)[__data.node_num][0] = __data.coeffs[0];
00554 break;
00555 case 2:
00556 eval_data.r = kl*kl;
00557 (*eval_data.d_data)[__data.node_num][0] =
00558 2*kl*__data.coeffs[0];
00559 break;
00560 case -1:
00561 { double h = 1/kl;
00562 eval_data.r = h;
00563 (*eval_data.d_data)[__data.node_num][0] =
00564 -h*h*__data.coeffs[0];
00565 }
00566 break;
00567 case -2:
00568 { double h = 1/kl;
00569 double k = h*h;
00570 eval_data.r = k;
00571 (*eval_data.d_data)[__data.node_num][0] =
00572 -2*h*k*__data.coeffs[0];
00573 }
00574 break;
00575 default:
00576 { double h;
00577 if(hl & 1)
00578 h = std::pow(fabs(kl), hl-1);
00579 else
00580 {
00581 if(kl < 0)
00582 h = -std::pow(-kl, hl-1);
00583 else
00584 h = std::pow(kl, hl-1);
00585 }
00586 eval_data.r = h*kl;
00587 (*eval_data.d_data)[__data.node_num][0] =
00588 hl*h*__data.coeffs[0];
00589 }
00590 break;
00591 }
00592 }
00593 }
00594 break;
00595 case EXPRINFO_SQROOT:
00596 { double h = std::sqrt(__data.coeffs[0]*__rval+__data.params.nd());
00597 eval_data.r = h;
00598 (*eval_data.d_data)[__data.node_num][0] = 0.5*__data.coeffs[0]/h;
00599 }
00600 break;
00601 case EXPRINFO_ABS:
00602 { double h = __data.coeffs[0]*__rval+__data.params.nd();
00603 eval_data.r = fabs(h);
00604 (*eval_data.d_data)[__data.node_num][0] =
00605 h > 0 ? __data.coeffs[0] : (h < 0 ? -__data.coeffs[0] : 0);
00606 }
00607 break;
00608 case EXPRINFO_POW:
00609 { double hh = __rval * __data.coeffs[eval_data.n];
00610 if(eval_data.n++ == 0)
00611 eval_data.r = hh+__data.params.nd();
00612 else
00613 {
00614 if(hh == 0)
00615 {
00616 (*eval_data.d_data)[__data.node_num][0] = 0;
00617 (*eval_data.d_data)[__data.node_num][1] =
00618 std::log(eval_data.r)*__data.coeffs[1];
00619 eval_data.r = 1;
00620 }
00621 else
00622 {
00623 double h = std::pow(eval_data.r, hh);
00624
00625 (*eval_data.d_data)[__data.node_num][0] =
00626 hh*std::pow(eval_data.r, hh-1)*__data.coeffs[0];
00627 (*eval_data.d_data)[__data.node_num][1] =
00628 std::log(eval_data.r)*h*__data.coeffs[1];
00629 eval_data.r = h;
00630 }
00631 }
00632 }
00633 break;
00634 case EXPRINFO_EXP:
00635 { double h = std::exp(__rval*__data.coeffs[0]+__data.params.nd());
00636 eval_data.r = h;
00637 (*eval_data.d_data)[__data.node_num][0] = h*__data.coeffs[0];
00638 }
00639 break;
00640 case EXPRINFO_LOG:
00641 { double h = __rval*__data.coeffs[0]+__data.params.nd();
00642 eval_data.r = std::log(h);
00643 (*eval_data.d_data)[__data.node_num][0] = __data.coeffs[0]/h;
00644 }
00645 break;
00646 case EXPRINFO_SIN:
00647 { double h = __rval*__data.coeffs[0]+__data.params.nd();
00648 eval_data.r = std::sin(h);
00649 (*eval_data.d_data)[__data.node_num][0] =
00650 __data.coeffs[0]*std::cos(h);
00651 }
00652 break;
00653 case EXPRINFO_COS:
00654 { double h = __rval*__data.coeffs[0]+__data.params.nd();
00655 eval_data.r = std::cos(h);
00656 (*eval_data.d_data)[__data.node_num][0] =
00657 -__data.coeffs[0]*std::sin(h);
00658 }
00659 break;
00660 case EXPRINFO_ATAN2:
00661 { double hh = __rval * __data.coeffs[eval_data.n];
00662 if(eval_data.n++ == 0)
00663 eval_data.r = hh;
00664 else
00665 { double h = eval_data.r;
00666 h *= h;
00667 h += hh*hh;
00668 (*eval_data.d_data)[__data.node_num][0] = __data.coeffs[0]*hh/h;
00669 (*eval_data.d_data)[__data.node_num][1] =
00670 -__data.coeffs[1]*eval_data.r/h;
00671 eval_data.r = std::atan2(eval_data.r,hh);
00672 }
00673 }
00674 break;
00675 case EXPRINFO_GAUSS:
00676 { double h = (__data.coeffs[0]*__rval-__data.params.d()[0])/
00677 __data.params.d()[1];
00678 double k = std::exp(-h*h);
00679 eval_data.r = k;
00680 (*eval_data.d_data)[__data.node_num][0] =
00681 -2*__data.coeffs[0]*k*h/__data.params.d()[1];
00682 }
00683 break;
00684 case EXPRINFO_POLY:
00685 throw nyi_exception("func_d_evaluator: POLY");
00686 break;
00687 case EXPRINFO_LIN:
00688 case EXPRINFO_QUAD:
00689
00690 break;
00691 case EXPRINFO_IN:
00692 {
00693 __x = __data.coeffs[eval_data.n]*__rval;
00694 const interval& i(__data.params.i()[eval_data.n]);
00695 if(eval_data.r != -1 && i.contains(__x))
00696 {
00697 if(eval_data.r == 1 && (__x == i.inf() || __x == i.sup()))
00698 eval_data.r = 0;
00699 }
00700 else
00701 {
00702 eval_data.r = -1;
00703 ret = -1;
00704 }
00705 }
00706
00707 if(eval_data.n == 0)
00708 {
00709 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00710 v.erase(v.begin(),v.end());
00711 v.insert(v.begin(),__data.n_children,0.);
00712 }
00713 eval_data.n++;
00714 break;
00715 case EXPRINFO_IF:
00716 __x = __rval * __data.coeffs[eval_data.n];
00717 if(eval_data.n == 0)
00718 {
00719 const interval& i(__data.params.ni());
00720 if(!i.contains(__x))
00721 {
00722 ret = 1;
00723 (*eval_data.d_data)[__data.node_num][1] = 0.;
00724 }
00725 else
00726 (*eval_data.d_data)[__data.node_num][2] = 0.;
00727 (*eval_data.d_data)[__data.node_num][0] = 0.;
00728 }
00729 else
00730 {
00731 eval_data.r = __x;
00732 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00733 __data.coeffs[eval_data.n];
00734 ret = -1;
00735 }
00736 eval_data.n += ret+1;
00737 break;
00738 case EXPRINFO_AND:
00739 { __x = __data.coeffs[eval_data.n]*__rval;
00740 const interval& i(__data.params.i()[eval_data.n]);
00741 if(eval_data.r == 1 && !i.contains(__x))
00742 {
00743 eval_data.r = 0;
00744 ret = -1;
00745 }
00746 }
00747
00748 if(eval_data.n == 0)
00749 {
00750 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00751 v.erase(v.begin(),v.end());
00752 v.insert(v.begin(),__data.n_children,0.);
00753 }
00754 eval_data.n++;
00755 break;
00756 case EXPRINFO_OR:
00757 { __x = __data.coeffs[eval_data.n]*__rval;
00758 const interval& i(__data.params.i()[eval_data.n]);
00759 if(eval_data.r == 0 && i.contains(__x))
00760 {
00761 eval_data.r = 1;
00762 ret = -1;
00763 }
00764 }
00765
00766 if(eval_data.n == 0)
00767 {
00768 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00769 v.erase(v.begin(),v.end());
00770 v.insert(v.begin(),__data.n_children,0.);
00771 }
00772 eval_data.n++;
00773 break;
00774 case EXPRINFO_NOT:
00775 { __x = __data.coeffs[0]*__rval;
00776 const interval& i(__data.params.ni());
00777 if(i.contains(__x))
00778 eval_data.r = 0;
00779 else
00780 eval_data.r = 1;
00781
00782 (*eval_data.d_data)[__data.node_num][0] = 0.;
00783 }
00784 break;
00785 case EXPRINFO_IMPLIES:
00786 { const interval& i(__data.params.i()[eval_data.n]);
00787 __x = __rval * __data.coeffs[eval_data.n];
00788 if(eval_data.n == 0)
00789 {
00790 if(!i.contains(__x))
00791 {
00792 eval_data.r = 1;
00793 ret = -1;
00794 }
00795
00796 (*eval_data.d_data)[__data.node_num][0] = 0.;
00797 (*eval_data.d_data)[__data.node_num][1] = 0.;
00798 }
00799 else
00800 eval_data.r = i.contains(__x) ? 1 : 0;
00801 ++eval_data.n;
00802 }
00803 break;
00804 case EXPRINFO_COUNT:
00805 { __x = __data.coeffs[eval_data.n]*__rval;
00806 const interval& i(__data.params.i()[eval_data.n]);
00807 if(i.contains(__x))
00808 eval_data.r += 1;
00809 }
00810
00811 if(eval_data.n == 0)
00812 {
00813 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00814 v.erase(v.begin(),v.end());
00815 v.insert(v.begin(),__data.n_children,0.);
00816 }
00817 eval_data.n++;
00818 break;
00819 case EXPRINFO_ALLDIFF:
00820 { __x = __data.coeffs[eval_data.n]*__rval;
00821 for(std::vector<double>::const_iterator _b =
00822 ((std::vector<double>*)eval_data.u.p)->begin();
00823 _b != ((std::vector<double>*)eval_data.u.p)->end(); ++_b)
00824 {
00825 if(fabs(__x-*_b) <= __data.params.nd())
00826 {
00827 eval_data.r = 0;
00828 ret = -1;
00829 break;
00830 }
00831 }
00832 if(ret != -1)
00833 ((std::vector<double>*) eval_data.u.p)->push_back(__x);
00834 }
00835
00836 if(eval_data.n == 0)
00837 {
00838 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00839 v.erase(v.begin(),v.end());
00840 v.insert(v.begin(),__data.n_children,0.);
00841 }
00842 eval_data.n++;
00843 if(eval_data.n == __data.n_children || ret == -1)
00844 delete (std::vector<double>*) eval_data.u.p;
00845 break;
00846 case EXPRINFO_HISTOGRAM:
00847 throw nyi_exception("func_d_evaluator: HISTOGRAM");
00848 break;
00849 case EXPRINFO_LEVEL:
00850 { int h = (int)eval_data.r;
00851 __x = __data.coeffs[eval_data.n]*__rval;
00852 interval _h;
00853
00854 if(h != INT_MAX)
00855 {
00856 while(h < __data.params.im().nrows())
00857 {
00858 _h = __data.params.im()[h][eval_data.n];
00859 if(_h.contains(__x))
00860 break;
00861 h++;
00862 }
00863 if(h == __data.params.im().nrows())
00864 {
00865 ret = -1;
00866 eval_data.r = INT_MAX;
00867 }
00868 else
00869 eval_data.r = h;
00870 }
00871 }
00872
00873 if(eval_data.n == 0)
00874 {
00875 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00876 v.erase(v.begin(),v.end());
00877 v.insert(v.begin(),__data.n_children,0.);
00878 }
00879 eval_data.n++;
00880 break;
00881 case EXPRINFO_NEIGHBOR:
00882 if(eval_data.n == 0)
00883 eval_data.r = __data.coeffs[0]*__rval;
00884 else
00885 {
00886 double h = eval_data.r;
00887 eval_data.r = 0;
00888 __x = __data.coeffs[1]*__rval;
00889 for(unsigned int i = 0; i < __data.params.n().size(); i+=2)
00890 {
00891 if(h == __data.params.n()[i] && __x == __data.params.n()[i+1])
00892 {
00893 eval_data.r = 1;
00894 break;
00895 }
00896 }
00897 }
00898
00899 if(eval_data.n == 0)
00900 {
00901 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00902 v.erase(v.begin(),v.end());
00903 v.insert(v.begin(),__data.n_children,0.);
00904 }
00905 eval_data.n++;
00906 break;
00907 case EXPRINFO_NOGOOD:
00908 {
00909 __x = __data.coeffs[eval_data.n]*__rval;
00910 if(eval_data.r == 0 || __data.params.n()[eval_data.n] != __x)
00911 {
00912 eval_data.r = 0;
00913 ret = -1;
00914 }
00915 }
00916
00917 if(eval_data.n == 0)
00918 {
00919 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00920 v.erase(v.begin(),v.end());
00921 v.insert(v.begin(),__data.n_children,0.);
00922 }
00923 eval_data.n++;
00924 break;
00925 case EXPRINFO_EXPECTATION:
00926 throw nyi_exception("func_d_evaluator: EXPECTATION");
00927 break;
00928 case EXPRINFO_INTEGRAL:
00929 throw nyi_exception("func_d_evaluator: INTEGRAL");
00930 break;
00931 case EXPRINFO_LOOKUP:
00932 case EXPRINFO_PWLIN:
00933 case EXPRINFO_SPLINE:
00934 case EXPRINFO_PWCONSTLC:
00935 case EXPRINFO_PWCONSTRC:
00936 throw nyi_exception("func_d_evaluator: Table Operations");
00937 break;
00938 case EXPRINFO_DET:
00939 case EXPRINFO_COND:
00940 case EXPRINFO_PSD:
00941 case EXPRINFO_MPROD:
00942 case EXPRINFO_FEM:
00943 throw nyi_exception("func_d_evaluator: Matrix Operations");
00944 break;
00945 case EXPRINFO_RE:
00946 case EXPRINFO_IM:
00947 case EXPRINFO_ARG:
00948 case EXPRINFO_CPLXCONJ:
00949 throw nyi_exception("func_d_evaluator: Complex Operations");
00950 break;
00951 case EXPRINFO_CMPROD:
00952 case EXPRINFO_CGFEM:
00953 throw nyi_exception("func_d_evaluator: Const Matrix Operations");
00954 break;
00955 default:
00956 throw api_exception(apiee_evaluator,
00957 std::string("func_d_evaluator: unknown function type ")+
00958 convert_to_str(__data.operator_type));
00959 break;
00960 }
00961 }
00962 else if(__data.operator_type > 0)
00963
00964 eval_data.r = __data.f_evaluate(eval_data.n++, __data.params.nn(),
00965 *eval_data.x, *v_ind, eval_data.r, __rval,
00966 &(*eval_data.d_data)[__data.node_num]);
00967
00968 if(eval_data.f_cache && __data.n_parents > 1 && __data.n_children > 0)
00969 (*eval_data.f_cache)[__data.node_num] = eval_data.r;
00970 return ret;
00971 }
00972
00973 double calculate_value(bool eval_all)
00974 {
00975 return eval_data.r;
00976 }
00978 };
00979
00981
00985 struct der_eval_type
00986 {
00987 std::vector<std::vector<double> >* d_data;
00988 std::vector<std::vector<double> >* d_cache;
00989 std::vector<double>* grad_vec;
00990 const model* mod;
00991 double mult;
00992 double mult_trans;
00993 unsigned int child_n;
00994 };
00995
00997
01002 class der_eval : public
01003 cached_backward_evaluator_base<der_eval_type,expression_node,bool,
01004 expression_const_walker>
01005 {
01006 private:
01008 typedef cached_backward_evaluator_base<der_eval_type,expression_node,
01009 bool,expression_const_walker> _Base;
01010
01011 protected:
01014 bool is_cached(const node_data_type& __data)
01015 {
01016 if(eval_data.d_cache && __data.n_parents > 1 && __data.n_children > 0
01017 && (*eval_data.d_cache)[__data.node_num].size() > 0 &&
01018 v_ind->match(__data.var_indicator()))
01019 {
01020 return true;
01021 }
01022 else
01023 return false;
01024 }
01025
01026 public:
01035 der_eval(std::vector<std::vector<double> >& __der_data, variable_indicator& __v,
01036 const model& __m, std::vector<std::vector<double > >* __d,
01037 std::vector<double>& __grad)
01038 {
01039 eval_data.d_data = &__der_data;
01040 eval_data.d_cache = __d;
01041 eval_data.mod = &__m;
01042 eval_data.grad_vec = &__grad;
01043 eval_data.mult_trans = 1;
01044 eval_data.mult = 0;
01045 v_ind = &__v;
01046 }
01047
01049 der_eval(const der_eval& __d) { eval_data = __d.eval_data; }
01050
01052 ~der_eval() {}
01053
01057 void new_point(std::vector<std::vector<double> >& __der_data,
01058 const variable_indicator& __v)
01059 {
01060 eval_data.d_data = &__der_data;
01061 v_ind = &__v;
01062 }
01063
01065 void new_result(std::vector<double>& __grad)
01066 {
01067 eval_data.grad_vec = &__grad;
01068 }
01069
01071 void set_mult(double scal)
01072 {
01073 eval_data.mult_trans = scal;
01074 }
01075 public:
01076
01078 expression_const_walker short_cut_to(const expression_node& __data)
01079 { return eval_data.mod->node(0); }
01080
01082
01083
01084 void initialize()
01085 {
01086 eval_data.child_n = 0;
01087 }
01088
01089
01090 int calculate(const expression_node& __data)
01091 {
01092 if(__data.operator_type == EXPRINFO_CONSTANT)
01093 return 0;
01094 else if(__data.operator_type == EXPRINFO_VARIABLE)
01095 {
01096
01097 (*eval_data.grad_vec)[__data.params.nn()] += eval_data.mult_trans;
01098 return 0;
01099 }
01100 else if(__data.operator_type == EXPRINFO_LIN)
01101 {
01102 linalg::linalg_add(linalg_scale(eval_data.mod->lin[__data.params.nn()], eval_data.mult_trans),
01103 *eval_data.grad_vec,*eval_data.grad_vec);
01104 return 0;
01105 }
01106 else if(__data.operator_type == EXPRINFO_QUAD)
01107 {
01108 linalg::linalg_ssum(*eval_data.grad_vec, 2*eval_data.mult_trans,
01109 (*eval_data.d_data)[__data.node_num]);
01110 return 0;
01111 }
01112 else if(__data.ev && (*__data.ev)[DER_EVALUATOR])
01113
01114 {
01115 linalg::linalg_ssum(*eval_data.grad_vec, eval_data.mult,
01116 (*(der_evaluator)(*__data.ev)[DER_EVALUATOR])(
01117 (*eval_data.d_data)[__data.node_num], *v_ind));
01118 return 0;
01119 }
01120 else if(eval_data.mult_trans == 0)
01121
01122 return 0;
01123 else
01124 {
01125 eval_data.child_n = 1;
01126 eval_data.mult = eval_data.mult_trans;
01127 if(__data.n_parents > 1 && __data.n_children > 0 && eval_data.d_cache)
01128 {
01129 eval_data.mult_trans = (*eval_data.d_data)[__data.node_num][0];
01130 }
01131 else
01132 eval_data.mult_trans *= (*eval_data.d_data)[__data.node_num][0];
01133 return 1;
01134 }
01135 }
01136
01137
01138 void cleanup(const expression_node& __data)
01139 {
01140
01141 if(__data.n_parents > 1 && __data.n_children > 0 && eval_data.d_cache
01142 && (*eval_data.d_cache)[__data.node_num].size() == 0)
01143 {
01144 (*eval_data.d_cache)[__data.node_num] = *eval_data.grad_vec;
01145 linalg::linalg_smult(*eval_data.grad_vec, eval_data.mult);
01146 }
01147 }
01148
01149 void retrieve_from_cache(const expression_node& __data)
01150 {
01151
01152 linalg::linalg_ssum(*eval_data.grad_vec, eval_data.mult_trans,
01153 (*eval_data.d_cache)[__data.node_num]);
01154 }
01155
01156 int update(const bool& __rval)
01157 {
01158 eval_data.child_n++;
01159 return 0;
01160 }
01161
01162
01163 int update(const expression_node& __data, const bool& __rval)
01164 {
01165 if(__data.n_children == 0)
01166 return 0;
01167 if(__data.n_parents > 1 && __data.n_children > 0 && eval_data.d_cache)
01168 {
01169 if(eval_data.child_n < __data.n_children)
01170 eval_data.mult_trans =
01171 (*eval_data.d_data)[__data.node_num][eval_data.child_n];
01172 }
01173 else if(eval_data.child_n < __data.n_children)
01174 {
01175 eval_data.mult_trans = eval_data.mult *
01176 (*eval_data.d_data)[__data.node_num][eval_data.child_n];
01177 }
01178 eval_data.child_n++;
01179 return 0;
01180 }
01181
01182 bool calculate_value(bool eval_all)
01183 {
01184 return true;
01185 }
01187 };
01188
01189 }
01190
01191 #endif