00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00027 #ifndef _FUNC_EVALUATOR_H_
00028 #define _FUNC_EVALUATOR_H_
00029
00030 #include <coconut_config.h>
00031 #include <evaluator.h>
00032 #include <expression.h>
00033 #include <model.h>
00034 #include <eval_main.h>
00035 #include <linalg.h>
00036 #include <math.h>
00037
00038 using namespace vgtl;
00039
00040 typedef double (*func_evaluator)(const std::vector<double>* __x,
00041 const variable_indicator& __v);
00042
00043 struct func_eval_type
00044 {
00045 const std::vector<double>* x;
00046 std::vector<double>* cache;
00047 const model* mod;
00048 union { void *p; double d; } u;
00049 double r;
00050 unsigned int n;
00051 };
00052
00053 class func_eval :
00054 public cached_forward_evaluator_base<func_eval_type,expression_node,double,
00055 model::walker>
00056 {
00057 private:
00058 typedef cached_forward_evaluator_base<func_eval_type,expression_node,
00059 double, model::walker> _Base;
00060
00061 protected:
00062 bool is_cached(const node_data_type& __data)
00063 {
00064 if(__data.operator_type == EXPRINFO_LIN ||
00065 __data.operator_type == EXPRINFO_QUAD)
00066 return true;
00067 if(eval_data.cache && __data.n_parents > 1 && __data.n_children > 0 &&
00068 v_ind->match(__data.var_indicator()))
00069 return true;
00070 else
00071 return false;
00072 }
00073
00074 private:
// Returns (__coeff*__x) raised to the integer power __exp.
//
// Exponent 0 yields 1 regardless of the base.  Exponents +-1 and +-2 are
// computed with plain multiplications/divisions; every other exponent goes
// through pow() on the magnitude of the base, with the sign restored for
// odd exponents.
double __power(double __coeff, double __x, int __exp)
{
  if(__exp == 0)
    return 1.;

  const double base = __coeff*__x;

  if(__exp == 1)
    return base;
  if(__exp == 2)
    return base*base;
  if(__exp == -1)
    return 1./base;
  if(__exp == -2)
    return 1./(base*base);

  const bool odd = (__exp & 1) != 0;
  if(!odd)
    return pow(fabs(base), __exp);
  if(base < 0)
    return -pow(-base, __exp);
  return pow(base, __exp);
}
00110
00111 public:
00112 func_eval(const std::vector<double>& __x, const variable_indicator& __v,
00113 const model& __m, std::vector<double>* __c) : _Base()
00114 {
00115 eval_data.x = &__x;
00116 eval_data.cache = __c;
00117 eval_data.mod = &__m;
00118 eval_data.u.d = 0.;
00119 v_ind = &__v;
00120 eval_data.n = 0;
00121 }
00122
00123 func_eval(const func_eval& __v) : _Base(__v) {}
00124
00125 ~func_eval() {}
00126
00127 model::walker short_cut_to(const expression_node& __data)
00128 { return eval_data.mod->node(0); }
00129
00130
00131 void new_point(const std::vector<double>& __x, const variable_indicator& __v)
00132 {
00133 eval_data.x = &__x;
00134 v_ind = &__v;
00135 }
00136
// Argument-free initialization hook: there is nothing to set up.
void initialize()
{
}
00138
00139 int initialize(const expression_node& __data)
00140 {
00141 eval_data.n = 0;
00142 if(__data.ev != NULL && (*__data.ev)[FUNC_EVALUATOR] != NULL)
00143
00144 {
00145 eval_data.r = (*(func_evaluator)(*__data.ev)[FUNC_EVALUATOR])(eval_data.x,
00146 *v_ind);
00147 return 0;
00148 }
00149 else
00150 {
00151 switch(__data.operator_type)
00152 {
00153 case EXPRINFO_VARIABLE:
00154 eval_data.r = (*eval_data.x)[__data.params.nn()];
00155 break;
00156 case EXPRINFO_SUM:
00157 case EXPRINFO_PROD:
00158 case EXPRINFO_MAX:
00159 case EXPRINFO_MIN:
00160 case EXPRINFO_INVERT:
00161 case EXPRINFO_DIV:
00162 case EXPRINFO_CONSTANT:
00163 eval_data.r = __data.params.nd();
00164 break;
00165 case EXPRINFO_MONOME:
00166 case EXPRINFO_IN:
00167 case EXPRINFO_AND:
00168 case EXPRINFO_NOGOOD:
00169 eval_data.r = 1.;
00170 break;
00171 case EXPRINFO_ALLDIFF:
00172 eval_data.u.p = (void*) new std::vector<double>;
00173 ((std::vector<double>*)eval_data.u.p)->reserve(__data.n_children);
00174
00175 case EXPRINFO_MEAN:
00176 case EXPRINFO_IF:
00177 case EXPRINFO_OR:
00178 case EXPRINFO_NOT:
00179 case EXPRINFO_COUNT:
00180 case EXPRINFO_SCPROD:
00181 case EXPRINFO_NORM:
00182 case EXPRINFO_LEVEL:
00183 eval_data.r = 0.;
00184 break;
00185 case EXPRINFO_DET:
00186 case EXPRINFO_PSD:
00187
00188 break;
00189 case EXPRINFO_COND:
00190 case EXPRINFO_FEM:
00191 case EXPRINFO_MPROD:
00192
00193 break;
00194 }
00195 return 1;
00196 }
00197 }
00198
00199 void calculate(const expression_node& __data)
00200 {
00201 if(__data.operator_type > 0)
00202 {
00203 eval_data.r = __data.f_evaluate(-1, __data.params.nn(), *eval_data.x,
00204 *v_ind, eval_data.r, 0, NULL);
00205 }
00206 }
00207
00208 void retrieve_from_cache(const expression_node& __data)
00209 {
00210 if(__data.operator_type == EXPRINFO_LIN)
00211 eval_data.r = linalg_dot(eval_data.mod->lin[__data.params.nn()],
00212 *eval_data.x,0.);
00213 else if(__data.operator_type == EXPRINFO_QUAD)
00214 {
00215 std::vector<double> irslt;
00216
00217 linalg_matvec(__data.params.m(), *eval_data.x, irslt);
00218 eval_data.r = irslt.back();
00219 irslt.pop_back();
00220 eval_data.r += linalg_dot(irslt,*eval_data.x,0.);
00221 }
00222 else
00223 eval_data.r = (*eval_data.cache)[__data.node_num];
00224 }
00225
00226 int update(const double& __rval)
00227 {
00228 eval_data.r = __rval;
00229 return 0;
00230 }
00231
00232 int update(const expression_node& __data, const double& __rval)
00233 {
00234 double __x;
00235 int ret = 0;
00236 if(__data.operator_type < 0)
00237 {
00238 switch(__data.operator_type)
00239 {
00240 case EXPRINFO_CONSTANT:
00241 case EXPRINFO_VARIABLE:
00242 break;
00243 case EXPRINFO_SUM:
00244 case EXPRINFO_MEAN:
00245 eval_data.r += __data.coeffs[eval_data.n++]*__rval;
00246 break;
00247 case EXPRINFO_PROD:
00248
00249 eval_data.r *= __rval;
00250 if(eval_data.r == 0) ret=-1;
00251 break;
00252 case EXPRINFO_MONOME:
00253 eval_data.r *= __power(__data.coeffs[eval_data.n],
00254 __rval, __data.params.n()[eval_data.n]);
00255 if(eval_data.r == 0) ret=-1;
00256 eval_data.n++;
00257 break;
00258 case EXPRINFO_MAX:
00259 __x = __rval * __data.coeffs[eval_data.n++];
00260 if(__x > eval_data.r)
00261 eval_data.r = __x;
00262 break;
00263 case EXPRINFO_MIN:
00264 __x = __rval * __data.coeffs[eval_data.n++];
00265 if(__x < eval_data.r)
00266 eval_data.r = __x;
00267 break;
00268 case EXPRINFO_SCPROD:
00269 { double h = __data.coeffs[eval_data.n]*__rval;
00270
00271
00272 if(eval_data.n & 1)
00273 eval_data.r += __data.coeffs[eval_data.n]*__rval*h;
00274 else
00275 eval_data.u.d = h;
00276 }
00277 eval_data.n++;
00278 break;
00279 case EXPRINFO_NORM:
00280 if(__data.params.nd() == INFINITY)
00281 {
00282 __x = fabs(__rval*__data.coeffs[eval_data.n]);
00283 if(__x > eval_data.r)
00284 eval_data.r = __x;
00285 }
00286 else
00287 { __x = fabs(__data.coeffs[eval_data.n]*__rval);
00288 eval_data.r += pow(__x, __data.params.nd());
00289 }
00290 eval_data.n++;
00291 if(eval_data.n == __data.n_children && __data.params.nd() != INFINITY)
00292 eval_data.r = pow(eval_data.r, 1./__data.params.nd());
00293 break;
00294 case EXPRINFO_INVERT:
00295 eval_data.r /= __rval;
00296 break;
00297 case EXPRINFO_DIV:
00298
00299
00300 if(eval_data.n == 0)
00301 {
00302 eval_data.r *= __rval;
00303 if(eval_data.r == 0) ret=-1;
00304 }
00305 else
00306 eval_data.r /= __rval;
00307 ++eval_data.n;
00308 break;
00309 case EXPRINFO_SQUARE:
00310 __x = __data.coeffs[0]*__rval+__data.params.nd();
00311 eval_data.r = __x*__x;
00312 break;
00313 case EXPRINFO_INTPOWER:
00314 eval_data.r = __power(__data.coeffs[0], __rval, __data.params.nn());
00315 break;
00316 case EXPRINFO_SQROOT:
00317 eval_data.r = sqrt(__data.coeffs[0]*__rval+__data.params.nd());
00318 break;
00319 case EXPRINFO_ABS:
00320 eval_data.r = fabs(__data.coeffs[0]*__rval+__data.params.nd());
00321 break;
00322 case EXPRINFO_POW:
00323 __x = __rval * __data.coeffs[eval_data.n];
00324 if(eval_data.n == 0)
00325 eval_data.r = __x+__data.params.nd();
00326 else
00327 eval_data.r = pow(eval_data.r, __x);
00328 ++eval_data.n;
00329 break;
00330 case EXPRINFO_EXP:
00331 eval_data.r = exp(__rval*__data.coeffs[0]+__data.params.nd());
00332 break;
00333 case EXPRINFO_LOG:
00334 eval_data.r = log(__rval*__data.coeffs[0]+__data.params.nd());
00335 break;
00336 case EXPRINFO_SIN:
00337 eval_data.r = sin(__rval*__data.coeffs[0]+__data.params.nd());
00338 break;
00339 case EXPRINFO_COS:
00340 eval_data.r = cos(__rval*__data.coeffs[0]+__data.params.nd());
00341 break;
00342 case EXPRINFO_ATAN2:
00343 __x = __rval * __data.coeffs[eval_data.n];
00344 if(eval_data.n == 0)
00345 eval_data.r = __x;
00346 else
00347 eval_data.r = atan2(eval_data.r, __x);
00348 ++eval_data.n;
00349 break;
00350 case EXPRINFO_GAUSS:
00351 __x = (__data.coeffs[0]*__rval-__data.params.d()[0])/
00352 __data.params.d()[1];
00353 eval_data.r = exp(-__x*__x);
00354 break;
00355 case EXPRINFO_POLY:
00356 std::cerr << "Polynomes NYI" << std::endl;
00357 throw "NYI";
00358 break;
00359 case EXPRINFO_LIN:
00360 case EXPRINFO_QUAD:
00361
00362 break;
00363 case EXPRINFO_IN:
00364 {
00365 __x = __data.coeffs[eval_data.n]*__rval;
00366 const interval& i(__data.params.i()[eval_data.n]);
00367 if(eval_data.r != -1 && i.contains(__x))
00368 {
00369 if(eval_data.r == 1 && (__x == i.inf() || __x == i.sup()))
00370 eval_data.r = 0;
00371 }
00372 else
00373 {
00374 eval_data.r = -1;
00375 ret = -1;
00376 }
00377 }
00378 eval_data.n++;
00379 break;
00380 case EXPRINFO_IF:
00381 __x = __rval * __data.coeffs[eval_data.n];
00382 if(eval_data.n == 0)
00383 {
00384 const interval& i(__data.params.ni());
00385 if(!i.contains(__x))
00386 ret = 1;
00387 }
00388 else
00389 {
00390 eval_data.r = __x;
00391 ret = -1;
00392 }
00393 eval_data.n += ret+1;
00394 break;
00395 case EXPRINFO_AND:
00396 { __x = __data.coeffs[eval_data.n]*__rval;
00397 const interval& i(__data.params.i()[eval_data.n]);
00398 if(eval_data.r == 1 && !i.contains(__x))
00399 {
00400 eval_data.r = 0;
00401 ret = -1;
00402 }
00403 }
00404 eval_data.n++;
00405 break;
00406 case EXPRINFO_OR:
00407 { __x = __data.coeffs[eval_data.n]*__rval;
00408 const interval& i(__data.params.i()[eval_data.n]);
00409 if(eval_data.r == 0 && i.contains(__x))
00410 {
00411 eval_data.r = 1;
00412 ret = -1;
00413 }
00414 }
00415 eval_data.n++;
00416 break;
00417 case EXPRINFO_NOT:
00418 { __x = __data.coeffs[0]*__rval;
00419 const interval& i(__data.params.ni());
00420 if(i.contains(__x))
00421 eval_data.r = 0;
00422 else
00423 eval_data.r = 1;
00424 }
00425 break;
00426 case EXPRINFO_IMPLIES:
00427 { const interval& i(__data.params.i()[eval_data.n]);
00428 __x = __rval * __data.coeffs[eval_data.n];
00429 if(eval_data.n == 0)
00430 {
00431 if(!i.contains(__x))
00432 {
00433 eval_data.r = 1;
00434 ret = -1;
00435 }
00436 }
00437 else
00438 eval_data.r = i.contains(__x) ? 1 : 0;
00439 ++eval_data.n;
00440 }
00441 break;
00442 case EXPRINFO_COUNT:
00443 { __x = __data.coeffs[eval_data.n]*__rval;
00444 const interval& i(__data.params.i()[eval_data.n]);
00445 if(i.contains(__x))
00446 eval_data.r += 1;
00447 }
00448 eval_data.n++;
00449 break;
00450 case EXPRINFO_ALLDIFF:
00451 { __x = __data.coeffs[eval_data.n]*__rval;
00452 for(std::vector<double>::const_iterator _b =
00453 ((std::vector<double>*)eval_data.u.p)->begin();
00454 _b != ((std::vector<double>*)eval_data.u.p)->end(); ++_b)
00455 {
00456 if(fabs(__x-*_b) <= __data.params.nd())
00457 {
00458 eval_data.r = 0;
00459 ret = -1;
00460 break;
00461 }
00462 }
00463 if(ret != -1)
00464 ((std::vector<double>*) eval_data.u.p)->push_back(__x);
00465 }
00466 eval_data.n++;
00467 if(eval_data.n == __data.n_children || ret == -1)
00468 delete (std::vector<double>*) eval_data.u.p;
00469 break;
00470 case EXPRINFO_HISTOGRAM:
00471 std::cerr << "func_evaluator: histogram NYI" << std::endl;
00472 throw "NYI";
00473 break;
00474 case EXPRINFO_LEVEL:
00475 { int h = (int)eval_data.r;
00476 __x = __data.coeffs[eval_data.n]*__rval;
00477 interval _h;
00478
00479 if(h != INT_MAX)
00480 {
00481 while(h < __data.params.im().nrows())
00482 {
00483 _h = __data.params.im()[h][eval_data.n];
00484 if(_h.contains(__x))
00485 break;
00486 h++;
00487 }
00488 if(h == __data.params.im().nrows())
00489 {
00490 ret = -1;
00491 eval_data.r = INT_MAX;
00492 }
00493 else
00494 eval_data.r = h;
00495 }
00496 }
00497 eval_data.n++;
00498 break;
00499 case EXPRINFO_NEIGHBOR:
00500 if(eval_data.n == 0)
00501 eval_data.r = __data.coeffs[0]*__rval;
00502 else
00503 {
00504 double h = eval_data.r;
00505 eval_data.r = 0;
00506 __x = __data.coeffs[1]*__rval;
00507 for(unsigned int i = 0; i < __data.params.n().size(); i+=2)
00508 {
00509 if(h == __data.params.n()[i] && __x == __data.params.n()[i+1])
00510 {
00511 eval_data.r = 1;
00512 break;
00513 }
00514 }
00515 }
00516 eval_data.n++;
00517 break;
00518 case EXPRINFO_NOGOOD:
00519 {
00520 __x = __data.coeffs[eval_data.n]*__rval;
00521 if(eval_data.r == 0 || __data.params.n()[eval_data.n] != __x)
00522 {
00523 eval_data.r = 0;
00524 ret = -1;
00525 }
00526 }
00527 eval_data.n++;
00528 break;
00529 case EXPRINFO_EXPECTATION:
00530 std::cerr << "func_evaluator: E NYI" << std::endl;
00531 throw "NYI";
00532 break;
00533 case EXPRINFO_INTEGRAL:
00534 std::cerr << "func_evaluator: INT NYI" << std::endl;
00535 throw "NYI";
00536 break;
00537 case EXPRINFO_LOOKUP:
00538 case EXPRINFO_PWLIN:
00539 case EXPRINFO_SPLINE:
00540 case EXPRINFO_PWCONSTLC:
00541 case EXPRINFO_PWCONSTRC:
00542 std::cerr << "func_evaluator: Table Operations NYI" << std::endl;
00543 throw "NYI";
00544 break;
00545 case EXPRINFO_DET:
00546 case EXPRINFO_COND:
00547 case EXPRINFO_PSD:
00548 case EXPRINFO_MPROD:
00549 case EXPRINFO_FEM:
00550 std::cerr << "func_evaluator: Matrix Operations NYI" << std::endl;
00551 throw "NYI";
00552 break;
00553 case EXPRINFO_RE:
00554 case EXPRINFO_IM:
00555 case EXPRINFO_ARG:
00556 case EXPRINFO_CPLXCONJ:
00557 std::cerr << "func_evaluator: Complex Numbers NYI" << std::endl;
00558 throw "NYI";
00559 break;
00560 case EXPRINFO_CMPROD:
00561 case EXPRINFO_CGFEM:
00562 std::cerr << "func_evaluator: Const Matrix Operations NYI" << std::endl;
00563 throw "NYI";
00564 break;
00565 default:
00566 std::cerr << "func_evaluator: unknown function type " <<
00567 __data.operator_type << std::endl;
00568 throw "Programming error";
00569 break;
00570 }
00571 }
00572 else if(__data.operator_type > 0)
00573
00574 eval_data.r = __data.f_evaluate(eval_data.n++, __data.params.nn(),
00575 *eval_data.x, *v_ind, eval_data.r,
00576 __rval, NULL);
00577
00578 if(eval_data.cache && __data.n_parents > 1 && __data.n_children > 0)
00579 (*eval_data.cache)[__data.node_num] = eval_data.r;
00580 return ret;
00581 }
00582
00583 double calculate_value(bool eval_all)
00584 {
00585 return eval_data.r;
00586 }
00587 };
00588
00589 #endif