00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00027 #ifndef _DER_EVALUATOR_H_
00028 #define _DER_EVALUATOR_H_
00029
00030 #include <coconut_config.h>
00031 #include <evaluator.h>
00032 #include <expression.h>
00033 #include <model.h>
00034 #include <eval_main.h>
00035 #include <linalg.h>
00036 #include <math.h>
00037
00038 using namespace vgtl;
00039
00040 typedef bool (*prep_d_evaluator)();
00041 typedef double (*func_d_evaluator)(const std::vector<double>* __x,
00042 const variable_indicator& __v,
00043 std::vector<double>& __d_data);
00044 typedef std::vector<double>& (*der_evaluator)(const std::vector<double>& __d_dat,
00045 const variable_indicator& __v);
00046
00047 class prep_d_eval : public
00048 cached_forward_evaluator_base<std::vector<std::vector<double> >*, expression_node, bool, model::walker>
00049 {
00050 private:
00051 typedef cached_forward_evaluator_base<std::vector<std::vector<double> >*,
00052 expression_node,bool,model::walker> _Base;
00053
00054 public:
00055 prep_d_eval(std::vector<std::vector<double> >& __d,
00056 unsigned int _num_of_nodes)
00057 {
00058 eval_data = &__d;
00059 if((*eval_data).size() < _num_of_nodes)
00060 (*eval_data).insert((*eval_data).end(),
00061 _num_of_nodes-(*eval_data).size(), std::vector<double>());
00062 }
00063
00064 prep_d_eval(const prep_d_eval& __x) { eval_data = __x.eval_data; }
00065
00066 ~prep_d_eval() {}
00067
00068 void initialize() { return; }
00069
00070 bool is_cached(const expression_node& __data)
00071 {
00072 return (*eval_data)[__data.node_num].size() > 0;
00073 }
00074
00075 void retrieve_from_cache(const expression_node& __data) { return; }
00076
00077 int initialize(const expression_node& __data)
00078 {
00079 (*eval_data)[__data.node_num].insert((*eval_data)[__data.node_num].end(),
00080 __data.n_children, 0);
00081 return 1;
00082 }
00083
00084 void calculate(const expression_node& __data) { return; }
00085
00086 int update(bool __rval) { return 0; }
00087
00088 int update(const expression_node& __data, bool __rval)
00089 { return 0; }
00090
00091 bool calculate_value(bool eval_all) { return true; }
00092 };
00093
00094
00095
00096
00097 struct func_d_eval_type
00098 {
00099 const std::vector<double>* x;
00100 std::vector<double>* f_cache;
00101 std::vector<std::vector<double> >* d_data;
00102 const model* mod;
00103 union { void* p; double d; } u;
00104 double r;
00105 unsigned int n,
00106 info;
00107 };
00108
00109 class func_d_eval : public
00110 cached_forward_evaluator_base<func_d_eval_type, expression_node,
00111 double, model::walker>
00112 {
00113 private:
00114 typedef cached_forward_evaluator_base<func_d_eval_type,expression_node,
00115 double, model::walker> _Base;
00116
00117 protected:
00118 bool is_cached(const node_data_type& __data)
00119 {
00120 if(__data.operator_type == EXPRINFO_LIN ||
00121 __data.operator_type == EXPRINFO_QUAD)
00122 return true;
00123 if(eval_data.f_cache && __data.n_parents > 1 && __data.n_children > 0 &&
00124 v_ind->match(__data.var_indicator()))
00125 return true;
00126 else
00127 return false;
00128 }
00129
00130 private:
00131 double __power(double __coeff, double __x, int __exp)
00132 {
00133 if(__exp == 0)
00134 return 1.;
00135 else
00136 {
00137 double k = __coeff*__x;
00138 switch(__exp)
00139 {
00140 case 1:
00141 return k;
00142 break;
00143 case 2:
00144 return k*k;
00145 break;
00146 case -1:
00147 return 1./k;
00148 break;
00149 case -2:
00150 return 1./(k*k);
00151 break;
00152 default:
00153 if(__exp & 1)
00154 {
00155 if(k < 0)
00156 return -pow(-k, __exp);
00157 else
00158 return pow(k, __exp);
00159 }
00160 else
00161 return pow(fabs(k), __exp);
00162 break;
00163 }
00164 }
00165 }
00166
00167 void __calc_max(double h, const expression_node& __data)
00168 {
00169 if(h >= eval_data.r)
00170 {
00171 (*eval_data.d_data)[__data.node_num][eval_data.info] = 0;
00172 if(h > eval_data.r)
00173 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00174 __data.coeffs[eval_data.n];
00175 else
00176
00177 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00178 eval_data.r = h;
00179 eval_data.info = eval_data.n;
00180 }
00181 else
00182 {
00183 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00184 }
00185 }
00186
00187 public:
00188 func_d_eval(const std::vector<double>& __x, const variable_indicator& __v,
00189 const model& __m, std::vector<std::vector<double> >& __d,
00190 std::vector<double>* __c) : _Base()
00191 {
00192 eval_data.x = &__x;
00193 eval_data.f_cache = __c;
00194 eval_data.mod = &__m;
00195 eval_data.d_data = &__d;
00196 eval_data.n = 0;
00197 eval_data.r = 0;
00198 eval_data.u.d = 0;
00199 v_ind = &__v;
00200 }
00201
00202 func_d_eval(const func_d_eval& __x) : _Base(__x) {}
00203
00204 ~func_d_eval() {}
00205
00206 model::walker short_cut_to(const expression_node& __data)
00207 { return eval_data.mod->node(0); }
00208
00209 void new_point(const std::vector<double>& __x, const variable_indicator& __v)
00210 {
00211 eval_data.x = &__x;
00212 v_ind = &__v;
00213 }
00214
00215 void initialize() { return; }
00216
00217 int initialize(const expression_node& __data)
00218 {
00219 eval_data.n = 0;
00220 if(__data.ev != NULL && (*__data.ev)[FUNC_D_EVALUATOR] != NULL)
00221
00222 {
00223 eval_data.r =
00224 (*(func_d_evaluator)(*__data.ev)[FUNC_D_EVALUATOR])(eval_data.x,
00225 *v_ind, (*eval_data.d_data)[__data.node_num]);
00226 return 0;
00227 }
00228 else
00229 {
00230 switch(__data.operator_type)
00231 {
00232 case EXPRINFO_MAX:
00233 case EXPRINFO_MIN:
00234 eval_data.info = 0;
00235
00236 case EXPRINFO_SUM:
00237 case EXPRINFO_PROD:
00238 case EXPRINFO_INVERT:
00239 eval_data.r = __data.params.nd();
00240 break;
00241 case EXPRINFO_IN:
00242 case EXPRINFO_AND:
00243 case EXPRINFO_NOGOOD:
00244 eval_data.r = 1.;
00245 break;
00246 case EXPRINFO_ALLDIFF:
00247 eval_data.u.p = (void*) new std::vector<double>;
00248 ((std::vector<double>*)eval_data.u.p)->reserve(__data.n_children);
00249
00250 case EXPRINFO_MEAN:
00251 case EXPRINFO_IF:
00252 case EXPRINFO_OR:
00253 case EXPRINFO_NOT:
00254 case EXPRINFO_COUNT:
00255 case EXPRINFO_SCPROD:
00256 case EXPRINFO_LEVEL:
00257 eval_data.r = 0.;
00258 break;
00259 case EXPRINFO_NORM:
00260 eval_data.info = 0;
00261 eval_data.r = 0.;
00262 break;
00263 case EXPRINFO_DET:
00264 case EXPRINFO_PSD:
00265
00266 break;
00267 case EXPRINFO_COND:
00268 case EXPRINFO_FEM:
00269 case EXPRINFO_MPROD:
00270
00271 break;
00272 }
00273 return 1;
00274 }
00275 }
00276
00277 void calculate(const expression_node& __data)
00278 {
00279 if(__data.operator_type > 0)
00280 {
00281 eval_data.r = __data.f_evaluate(-1, __data.params.nn(), *eval_data.x,
00282 *v_ind, eval_data.r, 0,
00283 &((*eval_data.d_data)[__data.node_num]));
00284 }
00285 }
00286
00287 void retrieve_from_cache(const expression_node& __data)
00288 {
00289
00290 if(__data.operator_type == EXPRINFO_LIN)
00291 eval_data.r = linalg_dot(eval_data.mod->lin[__data.params.nn()],
00292 *eval_data.x,0.);
00293 else if(__data.operator_type == EXPRINFO_QUAD)
00294 {
00295 std::vector<double> irslt = *eval_data.x;
00296 unsigned int r = __data.params.m().nrows();
00297
00298
00299 irslt.push_back(0);
00300 linalg_matvec(__data.params.m(), irslt, irslt);
00301 irslt.pop_back();
00302
00303 eval_data.r = linalg_dot(__data.params.m()[r-1], *eval_data.x, 0.);
00304
00305 eval_data.r += linalg_dot(irslt,*eval_data.x,0.);
00306
00307 linalg_add(linalg_scale(__data.params.m()[r-1], 0.5), irslt,
00308 (*eval_data.d_data)[__data.node_num]);
00309 }
00310 else
00311 eval_data.r = (*eval_data.f_cache)[__data.node_num];
00312 }
00313
00314 int update(const double& __rval)
00315 {
00316 eval_data.r = __rval;
00317 return 0;
00318 }
00319
00320 int update(const expression_node& __data, const double& __rval)
00321 {
00322 int ret = 0;
00323 double __x;
00324 if(__data.operator_type < 0)
00325 {
00326 switch(__data.operator_type)
00327 {
00328 case EXPRINFO_CONSTANT:
00329 eval_data.r = __data.params.nd();
00330
00331 break;
00332 case EXPRINFO_VARIABLE:
00333 eval_data.r = (*eval_data.x)[__data.params.nn()];
00334
00335 break;
00336 case EXPRINFO_SUM:
00337 case EXPRINFO_MEAN:
00338 { double h = __data.coeffs[eval_data.n];
00339 eval_data.r += h*__rval;
00340 (*eval_data.d_data)[__data.node_num][eval_data.n++] = h;
00341 }
00342 break;
00343 case EXPRINFO_PROD:
00344 if(eval_data.n == 0)
00345 {
00346 eval_data.r *= __rval;
00347 (*eval_data.d_data)[__data.node_num][0] = __data.params.nd();
00348 }
00349 else
00350 {
00351 (*eval_data.d_data)[__data.node_num][eval_data.n] = eval_data.r;
00352 eval_data.r *= __rval;
00353 for(int i = eval_data.n-1; i >= 0; i--)
00354 (*eval_data.d_data)[__data.node_num][i] *= __rval;
00355 }
00356 ++eval_data.n;
00357 break;
00358 case EXPRINFO_MONOME:
00359 if(eval_data.n == 0)
00360 {
00361 int n = __data.params.n()[0];
00362 if(n != 0)
00363 {
00364 __x = __power(__data.coeffs[0], __rval, n-1)*__data.coeffs[0];
00365 eval_data.r = __x*__rval;
00366 (*eval_data.d_data)[__data.node_num][0] = n*__x;
00367 }
00368 else
00369 {
00370 (*eval_data.d_data)[__data.node_num][0] = 0;
00371 eval_data.r = 1.;
00372 }
00373 }
00374 else
00375 {
00376 int n = __data.params.n()[eval_data.n];
00377 if(n != 0)
00378 {
00379 __x = __power(__data.coeffs[eval_data.n], __rval, n-1)*
00380 __data.coeffs[eval_data.n];
00381 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00382 eval_data.r*n*__x;
00383 __x *= __rval;
00384 eval_data.r *= __x;
00385 for(int i = eval_data.n-1; i >= 0; i--)
00386 (*eval_data.d_data)[__data.node_num][i] *= __x;
00387 }
00388 else
00389 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00390 }
00391 ++eval_data.n;
00392 break;
00393 case EXPRINFO_MAX:
00394 __calc_max(__rval * __data.coeffs[eval_data.n], __data);
00395 ++eval_data.n;
00396 break;
00397 case EXPRINFO_MIN:
00398 { double h = __rval * __data.coeffs[eval_data.n];
00399 if(h <= eval_data.r)
00400 {
00401 (*eval_data.d_data)[__data.node_num][eval_data.info] = 0;
00402 if(h < eval_data.r)
00403 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00404 __data.coeffs[eval_data.n];
00405 else
00406
00407 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00408 eval_data.r = h;
00409 eval_data.info = eval_data.n;
00410 }
00411 else
00412 {
00413 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00414 }
00415 }
00416 ++eval_data.n;
00417 break;
00418 case EXPRINFO_SCPROD:
00419 { double h = __data.coeffs[eval_data.n]*__rval;
00420
00421
00422 if(eval_data.n & 1)
00423 {
00424 eval_data.r += eval_data.u.d*h;
00425 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00426 eval_data.u.d*__data.coeffs[eval_data.n-1];
00427 (*eval_data.d_data)[__data.node_num][eval_data.n-1] =
00428 h*__data.coeffs[eval_data.n];
00429 }
00430 else
00431 eval_data.u.d = h;
00432 }
00433 eval_data.n++;
00434 break;
00435 case EXPRINFO_NORM:
00436 if(__data.params.nd() == INFINITY)
00437 __calc_max(fabs(__rval * __data.coeffs[eval_data.n]), __data);
00438 else
00439 {
00440 double h = __data.coeffs[eval_data.n]*fabs(__rval);
00441 double O = pow(h, __data.params.nd()-1);
00442 eval_data.r += O*h;
00443 (*eval_data.d_data)[__data.node_num][eval_data.n] = O;
00444 }
00445 eval_data.n++;
00446 if(eval_data.n == __data.n_children &&
00447 __data.params.nd() != INFINITY)
00448 {
00449 double h = pow(eval_data.r,1./(__data.params.nd())-1.);
00450 for(unsigned int i = 0; i < eval_data.n; ++i)
00451 (*eval_data.d_data)[__data.node_num][eval_data.n] *= h;
00452 eval_data.r = pow(eval_data.r,1./(__data.params.nd()));
00453 }
00454 break;
00455 case EXPRINFO_INVERT:
00456 { double h = 1/__rval;
00457 eval_data.r *= h;
00458 (*eval_data.d_data)[__data.node_num][0] = -__data.params.nd()*h*h;
00459 }
00460 break;
00461 case EXPRINFO_DIV:
00462 if(eval_data.n++ == 0)
00463 eval_data.r = __rval;
00464 else
00465 {
00466 double h = 1/__rval;
00467 eval_data.r *=
00468 (*eval_data.d_data)[__data.node_num][0] = __data.params.nd()*h;
00469 (*eval_data.d_data)[__data.node_num][1] = -eval_data.r*h;
00470 }
00471 break;
00472 case EXPRINFO_SQUARE:
00473 { double h = __data.coeffs[0]*__rval+__data.params.nd();
00474 eval_data.r = h*h;
00475 (*eval_data.d_data)[__data.node_num][0] = 2*h*__data.coeffs[0];
00476 }
00477 break;
00478 case EXPRINFO_INTPOWER:
00479 { int hl = __data.params.nn();
00480 if(hl == 0)
00481 {
00482 eval_data.r = 1;
00483 (*eval_data.d_data)[__data.node_num][0] = 0;
00484 }
00485 else
00486 {
00487 double kl = __data.coeffs[0]*__rval;
00488 switch(hl)
00489 {
00490 case 1:
00491 eval_data.r = kl;
00492 (*eval_data.d_data)[__data.node_num][0] = __data.coeffs[0];
00493 break;
00494 case 2:
00495 eval_data.r = kl*kl;
00496 (*eval_data.d_data)[__data.node_num][0] =
00497 2*kl*__data.coeffs[0];
00498 break;
00499 case -1:
00500 { double h = 1/kl;
00501 eval_data.r = h;
00502 (*eval_data.d_data)[__data.node_num][0] =
00503 -h*h*__data.coeffs[0];
00504 }
00505 break;
00506 case -2:
00507 { double h = 1/kl;
00508 double k = h*h;
00509 eval_data.r = k;
00510 (*eval_data.d_data)[__data.node_num][0] =
00511 -2*h*k*__data.coeffs[0];
00512 }
00513 break;
00514 default:
00515 { double h;
00516 if(hl & 1)
00517 h = pow(fabs(kl), hl-1);
00518 else
00519 {
00520 if(kl < 0)
00521 h = -pow(-kl, hl-1);
00522 else
00523 h = pow(kl, hl-1);
00524 }
00525 eval_data.r = h*kl;
00526 (*eval_data.d_data)[__data.node_num][0] =
00527 hl*h*__data.coeffs[0];
00528 }
00529 break;
00530 }
00531 }
00532 }
00533 break;
00534 case EXPRINFO_SQROOT:
00535 { double h = sqrt(__data.coeffs[0]*__rval+__data.params.nd());
00536 eval_data.r = h;
00537 (*eval_data.d_data)[__data.node_num][0] = 0.5*__data.coeffs[0]/h;
00538 }
00539 break;
00540 case EXPRINFO_ABS:
00541 { double h = __data.coeffs[0]*__rval+__data.params.nd();
00542 eval_data.r = fabs(h);
00543 (*eval_data.d_data)[__data.node_num][0] =
00544 h > 0 ? __data.coeffs[0] : (h < 0 ? -__data.coeffs[0] : 0);
00545 }
00546 break;
00547 case EXPRINFO_POW:
00548 { double hh = __rval * __data.coeffs[eval_data.n];
00549 if(eval_data.n++ == 0)
00550 eval_data.r = hh+__data.params.nd();
00551 else
00552 {
00553 if(hh == 0)
00554 {
00555 (*eval_data.d_data)[__data.node_num][0] = 0;
00556 (*eval_data.d_data)[__data.node_num][1] =
00557 log(eval_data.r)*__data.coeffs[1];
00558 eval_data.r = 1;
00559 }
00560 else
00561 {
00562 double h = pow(eval_data.r, hh);
00563
00564 (*eval_data.d_data)[__data.node_num][0] =
00565 hh*pow(eval_data.r, hh-1)*__data.coeffs[0];
00566 (*eval_data.d_data)[__data.node_num][1] =
00567 log(eval_data.r)*h*__data.coeffs[1];
00568 eval_data.r = h;
00569 }
00570 }
00571 }
00572 break;
00573 case EXPRINFO_EXP:
00574 { double h = exp(__rval*__data.coeffs[0]+__data.params.nd());
00575 eval_data.r = h;
00576 (*eval_data.d_data)[__data.node_num][0] = h*__data.coeffs[0];
00577 }
00578 break;
00579 case EXPRINFO_LOG:
00580 { double h = __rval*__data.coeffs[0]+__data.params.nd();
00581 eval_data.r = log(h);
00582 (*eval_data.d_data)[__data.node_num][0] = __data.coeffs[0]/h;
00583 }
00584 break;
00585 case EXPRINFO_SIN:
00586 { double h = __rval*__data.coeffs[0]+__data.params.nd();
00587 eval_data.r = sin(h);
00588 (*eval_data.d_data)[__data.node_num][0] = __data.coeffs[0]*cos(h);
00589 }
00590 break;
00591 case EXPRINFO_COS:
00592 { double h = __rval*__data.coeffs[0]+__data.params.nd();
00593 eval_data.r = cos(h);
00594 (*eval_data.d_data)[__data.node_num][0] = -__data.coeffs[0]*sin(h);
00595 }
00596 break;
00597 case EXPRINFO_ATAN2:
00598 { double hh = __rval * __data.coeffs[eval_data.n];
00599 if(eval_data.n++ == 0)
00600 eval_data.r = hh;
00601 else
00602 { double h = eval_data.r;
00603 h *= h;
00604 h += hh*hh;
00605 (*eval_data.d_data)[__data.node_num][0] = __data.coeffs[0]*hh/h;
00606 (*eval_data.d_data)[__data.node_num][1] =
00607 -__data.coeffs[1]*eval_data.r/h;
00608 eval_data.r = atan2(eval_data.r,hh);
00609 }
00610 }
00611 break;
00612 case EXPRINFO_GAUSS:
00613 { double h = (__data.coeffs[0]*__rval-__data.params.d()[0])/
00614 __data.params.d()[1];
00615 double k = exp(-h*h);
00616 eval_data.r = k;
00617 (*eval_data.d_data)[__data.node_num][0] =
00618 -2*__data.coeffs[0]*k*h/__data.params.d()[1];
00619 }
00620 break;
00621 case EXPRINFO_POLY:
00622 std::cerr << "func_d_evaluator: Polynomes NYI" << std::endl;
00623 throw "NYI";
00624 break;
00625 case EXPRINFO_LIN:
00626 case EXPRINFO_QUAD:
00627
00628 break;
00629 case EXPRINFO_IN:
00630 {
00631 __x = __data.coeffs[eval_data.n]*__rval;
00632 const interval& i(__data.params.i()[eval_data.n]);
00633 if(eval_data.r != -1 && i.contains(__x))
00634 {
00635 if(eval_data.r == 1 && (__x == i.inf() || __x == i.sup()))
00636 eval_data.r = 0;
00637 }
00638 else
00639 {
00640 eval_data.r = -1;
00641 ret = -1;
00642 }
00643 }
00644
00645 if(eval_data.n == 0)
00646 {
00647 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00648 v.erase(v.begin(),v.end());
00649 v.insert(v.begin(),__data.n_children,0.);
00650 }
00651 eval_data.n++;
00652 break;
00653 case EXPRINFO_IF:
00654 __x = __rval * __data.coeffs[eval_data.n];
00655 if(eval_data.n == 0)
00656 {
00657 const interval& i(__data.params.ni());
00658 if(!i.contains(__x))
00659 {
00660 ret = 1;
00661 (*eval_data.d_data)[__data.node_num][1] = 0.;
00662 }
00663 else
00664 (*eval_data.d_data)[__data.node_num][2] = 0.;
00665 (*eval_data.d_data)[__data.node_num][0] = 0.;
00666 }
00667 else
00668 {
00669 eval_data.r = __x;
00670 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00671 __data.coeffs[eval_data.n];
00672 ret = -1;
00673 }
00674 eval_data.n += ret+1;
00675 break;
00676 case EXPRINFO_AND:
00677 { __x = __data.coeffs[eval_data.n]*__rval;
00678 const interval& i(__data.params.i()[eval_data.n]);
00679 if(eval_data.r == 1 && !i.contains(__x))
00680 {
00681 eval_data.r = 0;
00682 ret = -1;
00683 }
00684 }
00685
00686 if(eval_data.n == 0)
00687 {
00688 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00689 v.erase(v.begin(),v.end());
00690 v.insert(v.begin(),__data.n_children,0.);
00691 }
00692 eval_data.n++;
00693 break;
00694 case EXPRINFO_OR:
00695 { __x = __data.coeffs[eval_data.n]*__rval;
00696 const interval& i(__data.params.i()[eval_data.n]);
00697 if(eval_data.r == 0 && i.contains(__x))
00698 {
00699 eval_data.r = 1;
00700 ret = -1;
00701 }
00702 }
00703
00704 if(eval_data.n == 0)
00705 {
00706 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00707 v.erase(v.begin(),v.end());
00708 v.insert(v.begin(),__data.n_children,0.);
00709 }
00710 eval_data.n++;
00711 break;
00712 case EXPRINFO_NOT:
00713 { __x = __data.coeffs[0]*__rval;
00714 const interval& i(__data.params.ni());
00715 if(i.contains(__x))
00716 eval_data.r = 0;
00717 else
00718 eval_data.r = 1;
00719
00720 (*eval_data.d_data)[__data.node_num][0] = 0.;
00721 }
00722 break;
00723 case EXPRINFO_IMPLIES:
00724 { const interval& i(__data.params.i()[eval_data.n]);
00725 __x = __rval * __data.coeffs[eval_data.n];
00726 if(eval_data.n == 0)
00727 {
00728 if(!i.contains(__x))
00729 {
00730 eval_data.r = 1;
00731 ret = -1;
00732 }
00733
00734 (*eval_data.d_data)[__data.node_num][0] = 0.;
00735 (*eval_data.d_data)[__data.node_num][1] = 0.;
00736 }
00737 else
00738 eval_data.r = i.contains(__x) ? 1 : 0;
00739 ++eval_data.n;
00740 }
00741 break;
00742 case EXPRINFO_COUNT:
00743 { __x = __data.coeffs[eval_data.n]*__rval;
00744 const interval& i(__data.params.i()[eval_data.n]);
00745 if(i.contains(__x))
00746 eval_data.r += 1;
00747 }
00748
00749 if(eval_data.n == 0)
00750 {
00751 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00752 v.erase(v.begin(),v.end());
00753 v.insert(v.begin(),__data.n_children,0.);
00754 }
00755 eval_data.n++;
00756 break;
00757 case EXPRINFO_ALLDIFF:
00758 { __x = __data.coeffs[eval_data.n]*__rval;
00759 for(std::vector<double>::const_iterator _b =
00760 ((std::vector<double>*)eval_data.u.p)->begin();
00761 _b != ((std::vector<double>*)eval_data.u.p)->end(); ++_b)
00762 {
00763 if(fabs(__x-*_b) <= __data.params.nd())
00764 {
00765 eval_data.r = 0;
00766 ret = -1;
00767 break;
00768 }
00769 }
00770 if(ret != -1)
00771 ((std::vector<double>*) eval_data.u.p)->push_back(__x);
00772 }
00773
00774 if(eval_data.n == 0)
00775 {
00776 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00777 v.erase(v.begin(),v.end());
00778 v.insert(v.begin(),__data.n_children,0.);
00779 }
00780 eval_data.n++;
00781 if(eval_data.n == __data.n_children || ret == -1)
00782 delete (std::vector<double>*) eval_data.u.p;
00783 break;
00784 case EXPRINFO_HISTOGRAM:
00785 std::cerr << "func_d_evaluator: histogram NYI" << std::endl;
00786 throw "NYI";
00787 break;
00788 case EXPRINFO_LEVEL:
00789 { int h = (int)eval_data.r;
00790 __x = __data.coeffs[eval_data.n]*__rval;
00791 interval _h;
00792
00793 if(h != INT_MAX)
00794 {
00795 while(h < __data.params.im().nrows())
00796 {
00797 _h = __data.params.im()[h][eval_data.n];
00798 if(_h.contains(__x))
00799 break;
00800 h++;
00801 }
00802 if(h == __data.params.im().nrows())
00803 {
00804 ret = -1;
00805 eval_data.r = INT_MAX;
00806 }
00807 else
00808 eval_data.r = h;
00809 }
00810 }
00811
00812 if(eval_data.n == 0)
00813 {
00814 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00815 v.erase(v.begin(),v.end());
00816 v.insert(v.begin(),__data.n_children,0.);
00817 }
00818 eval_data.n++;
00819 break;
00820 case EXPRINFO_NEIGHBOR:
00821 if(eval_data.n == 0)
00822 eval_data.r = __data.coeffs[0]*__rval;
00823 else
00824 {
00825 double h = eval_data.r;
00826 eval_data.r = 0;
00827 __x = __data.coeffs[1]*__rval;
00828 for(unsigned int i = 0; i < __data.params.n().size(); i+=2)
00829 {
00830 if(h == __data.params.n()[i] && __x == __data.params.n()[i+1])
00831 {
00832 eval_data.r = 1;
00833 break;
00834 }
00835 }
00836 }
00837
00838 if(eval_data.n == 0)
00839 {
00840 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00841 v.erase(v.begin(),v.end());
00842 v.insert(v.begin(),__data.n_children,0.);
00843 }
00844 eval_data.n++;
00845 break;
00846 case EXPRINFO_NOGOOD:
00847 {
00848 __x = __data.coeffs[eval_data.n]*__rval;
00849 if(eval_data.r == 0 || __data.params.n()[eval_data.n] != __x)
00850 {
00851 eval_data.r = 0;
00852 ret = -1;
00853 }
00854 }
00855
00856 if(eval_data.n == 0)
00857 {
00858 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00859 v.erase(v.begin(),v.end());
00860 v.insert(v.begin(),__data.n_children,0.);
00861 }
00862 eval_data.n++;
00863 break;
00864 case EXPRINFO_EXPECTATION:
00865 std::cerr << "func_d_evaluator: E NYI" << std::endl;
00866 throw "NYI";
00867 break;
00868 case EXPRINFO_INTEGRAL:
00869 std::cerr << "func_d_evaluator: INT NYI" << std::endl;
00870 throw "NYI";
00871 break;
00872 case EXPRINFO_LOOKUP:
00873 case EXPRINFO_PWLIN:
00874 case EXPRINFO_SPLINE:
00875 case EXPRINFO_PWCONSTLC:
00876 case EXPRINFO_PWCONSTRC:
00877 std::cerr << "func_d_evaluator: Table Operations NYI" << std::endl;
00878 throw "NYI";
00879 break;
00880 case EXPRINFO_DET:
00881 case EXPRINFO_COND:
00882 case EXPRINFO_PSD:
00883 case EXPRINFO_MPROD:
00884 case EXPRINFO_FEM:
00885 std::cerr << "func_d_evaluator: Matrix Operations NYI" << std::endl;
00886 throw "NYI";
00887 break;
00888 case EXPRINFO_RE:
00889 case EXPRINFO_IM:
00890 case EXPRINFO_ARG:
00891 case EXPRINFO_CPLXCONJ:
00892 std::cerr << "func_d_evaluator: Complex Numbers NYI" << std::endl;
00893 throw "NYI";
00894 break;
00895 case EXPRINFO_CMPROD:
00896 case EXPRINFO_CGFEM:
00897 std::cerr << "func_d_evaluator: Const Matrix Operations NYI" << std::endl;
00898 throw "NYI";
00899 break;
00900 default:
00901 std::cerr << "func_d_evaluator: unknown function type " <<
00902 __data.operator_type << std::endl;
00903 throw "Programming error";
00904 break;
00905 }
00906 }
00907 else if(__data.operator_type > 0)
00908
00909 eval_data.r = __data.f_evaluate(eval_data.n++, __data.params.nn(),
00910 *eval_data.x, *v_ind, eval_data.r, __rval,
00911 &(*eval_data.d_data)[__data.node_num]);
00912
00913 if(eval_data.f_cache && __data.n_parents > 1 && __data.n_children > 0)
00914 (*eval_data.f_cache)[__data.node_num] = eval_data.r;
00915 return ret;
00916 }
00917
00918 double calculate_value(bool eval_all)
00919 {
00920 return eval_data.r;
00921 }
00922 };
00923
00924 struct der_eval_type
00925 {
00926 std::vector<std::vector<double> >* d_data;
00927 std::vector<std::vector<double> >* d_cache;
00928 std::vector<double>* grad_vec;
00929 const model* mod;
00930 double mult;
00931 double mult_trans;
00932 unsigned int child_n;
00933 };
00934
00935 class der_eval : public
00936 cached_backward_evaluator_base<der_eval_type,expression_node,bool,
00937 model::walker>
00938 {
00939 private:
00940 typedef cached_backward_evaluator_base<der_eval_type,expression_node,
00941 bool,model::walker> _Base;
00942
00943 protected:
00944 bool is_cached(const node_data_type& __data)
00945 {
00946 if(eval_data.d_cache && __data.n_parents > 1 && __data.n_children > 0
00947 && (*eval_data.d_cache)[__data.node_num].size() > 0 &&
00948 v_ind->match(__data.var_indicator()))
00949 {
00950 return true;
00951 }
00952 else
00953 return false;
00954 }
00955
00956 public:
00957 der_eval(std::vector<std::vector<double> >& __der_data, variable_indicator& __v,
00958 const model& __m, std::vector<std::vector<double > >* __d,
00959 std::vector<double>& __grad)
00960 {
00961 eval_data.d_data = &__der_data;
00962 eval_data.d_cache = __d;
00963 eval_data.mod = &__m;
00964 eval_data.grad_vec = &__grad;
00965 eval_data.mult_trans = 1;
00966 eval_data.mult = 0;
00967 v_ind = &__v;
00968 }
00969
00970 der_eval(const der_eval& __d) { eval_data = __d.eval_data; }
00971
00972 ~der_eval() {}
00973
00974 void new_point(std::vector<std::vector<double> >& __der_data,
00975 const variable_indicator& __v)
00976 {
00977 eval_data.d_data = &__der_data;
00978 v_ind = &__v;
00979 }
00980
00981 void new_result(std::vector<double>& __grad)
00982 {
00983 eval_data.grad_vec = &__grad;
00984 }
00985
00986 void set_mult(double scal)
00987 {
00988 eval_data.mult_trans = scal;
00989 }
00990 public:
00991
00992 model::walker short_cut_to(const expression_node& __data)
00993 { return eval_data.mod->node(0); }
00994
00995
00996 void initialize()
00997 {
00998 eval_data.child_n = 0;
00999 }
01000
01001
01002 int calculate(const expression_node& __data)
01003 {
01004 if(__data.operator_type == EXPRINFO_CONSTANT)
01005 return 0;
01006 else if(__data.operator_type == EXPRINFO_VARIABLE)
01007 {
01008
01009 (*eval_data.grad_vec)[__data.params.nn()] += eval_data.mult_trans;
01010 return 0;
01011 }
01012 else if(__data.operator_type == EXPRINFO_LIN)
01013 {
01014 linalg_add(linalg_scale(eval_data.mod->lin[__data.params.nn()], eval_data.mult_trans),
01015 *eval_data.grad_vec,*eval_data.grad_vec);
01016 return 0;
01017 }
01018 else if(__data.operator_type == EXPRINFO_QUAD)
01019 {
01020 linalg_ssum(*eval_data.grad_vec, 2*eval_data.mult_trans,
01021 (*eval_data.d_data)[__data.node_num]);
01022 return 0;
01023 }
01024 else if(__data.ev && (*__data.ev)[DER_EVALUATOR])
01025
01026 {
01027 linalg_ssum(*eval_data.grad_vec, eval_data.mult,
01028 (*(der_evaluator)(*__data.ev)[DER_EVALUATOR])(
01029 (*eval_data.d_data)[__data.node_num], *v_ind));
01030 return 0;
01031 }
01032 else if(eval_data.mult_trans == 0)
01033
01034 return 0;
01035 else
01036 {
01037 eval_data.child_n = 1;
01038 eval_data.mult = eval_data.mult_trans;
01039 if(__data.n_parents > 1 && __data.n_children > 0 && eval_data.d_cache)
01040 {
01041 eval_data.mult_trans = (*eval_data.d_data)[__data.node_num][0];
01042 }
01043 else
01044 eval_data.mult_trans *= (*eval_data.d_data)[__data.node_num][0];
01045 return 1;
01046 }
01047 }
01048
01049
01050 void cleanup(const expression_node& __data)
01051 {
01052
01053 if(__data.n_parents > 1 && __data.n_children > 0 && eval_data.d_cache
01054 && (*eval_data.d_cache)[__data.node_num].size() == 0)
01055 {
01056 (*eval_data.d_cache)[__data.node_num] = *eval_data.grad_vec;
01057 linalg_smult(*eval_data.grad_vec, eval_data.mult);
01058 }
01059 }
01060
01061 void retrieve_from_cache(const expression_node& __data)
01062 {
01063
01064 linalg_ssum(*eval_data.grad_vec, eval_data.mult_trans,
01065 (*eval_data.d_cache)[__data.node_num]);
01066 }
01067
01068 int update(const bool& __rval)
01069 {
01070 eval_data.child_n++;
01071 return 0;
01072 }
01073
01074
01075 int update(const expression_node& __data, const bool& __rval)
01076 {
01077 if(__data.n_children == 0)
01078 return 0;
01079 if(__data.n_parents > 1 && __data.n_children > 0 && eval_data.d_cache)
01080 {
01081 if(eval_data.child_n < __data.n_children)
01082 eval_data.mult_trans =
01083 (*eval_data.d_data)[__data.node_num][eval_data.child_n];
01084 }
01085 else if(eval_data.child_n < __data.n_children)
01086 {
01087 eval_data.mult_trans = eval_data.mult *
01088 (*eval_data.d_data)[__data.node_num][eval_data.child_n];
01089 }
01090 eval_data.child_n++;
01091 return 0;
01092 }
01093
01094 bool calculate_value(bool eval_all)
01095 {
01096 return true;
01097 }
01098 };
01099
01100 #endif