00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00028 #ifndef _DIFFNUMBER_EVALUATOR_H_
00029 #define _DIFFNUMBER_EVALUATOR_H_
00030
00031 #include <coconut_config.h>
00032 #include <evaluator.h>
00033 #include <expression.h>
00034 #include <model.h>
00035 #include <eval_main.h>
00036 #include <linalg.h>
00037 #include <math.h>
00038 #include <api_exception.h>
00039 #include <diffNumber.h>
00040
00041 using namespace vgtl;
00042
00043 namespace coco {
00044
00046 typedef diffNumber (*diffNumber_evaluator)(const std::vector<diffNumber>* __x,
00047 const variable_indicator& __v);
00048
00050
00055
00056
00057 struct diffNumber_eval_type
00058 {
00059 const std::vector<diffNumber>* x;
00060 std::vector<diffNumber>* cache;
00061 const model* mod;
00062 void *p;
00063 diffNumber r;
00064 unsigned int n;
00065 int maxdeg;
00066 };
00067
00069
00075 class diffNumber_eval :
00076 public cached_forward_evaluator_base<diffNumber_eval_type,expression_node,
00077 diffNumber,
00078 expression_const_walker>
00079 {
00080 private:
00081
00082 typedef cached_forward_evaluator_base<diffNumber_eval_type,expression_node,
00083 diffNumber, expression_const_walker> _Base;
00084
00085 protected:
00088 bool is_cached(const node_data_type& __data)
00089 {
00090 #ifdef DEBUG
00091 std::cout<<"diffNumber_evaluator.is_cached\n";
00092 #endif
00093 if(__data.operator_type == EXPRINFO_LIN ||
00094 __data.operator_type == EXPRINFO_QUAD)
00095 return true;
00096 else
00097 return false;
00098 #if 0
00099 if(eval_data.cache && __data.n_parents > 1 && __data.n_children > 0)
00100 return true;
00101 else
00102 return false;
00103 #endif
00104 }
00105
00106 private:
00109 diffNumber __power(double __coeff, diffNumber __x, int __exp)
00110 {
00111 if(__exp == 0)
00112 return 1.;
00113 else
00114 {
00115 diffNumber k = __coeff*__x;
00116 switch(__exp)
00117 {
00118 case 1:
00119 return k;
00120 break;
00121 case 2:
00122 return k*k;
00123 break;
00124 case -1:
00125 return 1./k;
00126 break;
00127 case -2:
00128 return 1./(k*k);
00129 break;
00130 default:
00131
00132 break;
00133 }
00134 }
00135 return diffNumber(0,eval_data.maxdeg);
00136 }
00137
00138 public:
00146 diffNumber_eval(const std::vector<diffNumber>& __x, const variable_indicator& __v,
00147 const model& __m, std::vector<diffNumber>* __c, int __maxd) : _Base()
00148 {
00149 #ifdef DEBUG
00150 std::cout<<"diffNumber_eval\n";
00151 #endif
00152 eval_data.x = &__x;
00153 eval_data.cache = __c;
00154 eval_data.mod = &__m;
00155 v_ind = &__v;
00156 eval_data.n = 0;
00157 eval_data.maxdeg=__maxd;
00158 eval_data.r.setDeg(__maxd);
00159 eval_data.r=diffNumber(0.,__maxd);
00160 }
00161
00163 diffNumber_eval(const diffNumber_eval& __v) : _Base(__v) {}
00164
00166 ~diffNumber_eval() {}
00167
00169 expression_const_walker short_cut_to(const expression_node& __data)
00170 {
00171 #ifdef DEBUG
00172 std::cout<<"short_cut_to\n";
00173 #endif
00174 return eval_data.mod->node(0); }
00175
00179 void new_point(const std::vector<diffNumber>& __x, const variable_indicator& __v)
00180 {
00181 eval_data.x = &__x;
00182 v_ind = &__v;
00183 }
00184
00186
00187 void initialize() { return; }
00188
00189 int initialize(const expression_node& __data)
00190 {
00191 #ifdef DEBUG
00192 std::cout<<"initialize\n";
00193 #endif
00194 eval_data.n = 0;
00195 if(__data.ev && (*__data.ev)[FUNC_RANGE])
00196
00197 {
00198 eval_data.r =
00199 (*(diffNumber_evaluator)(*__data.ev)[FUNC_RANGE])(eval_data.x,
00200 *v_ind);
00201 return 0;
00202 }
00203 else
00204 {
00205 switch(__data.operator_type)
00206 {
00207 case EXPRINFO_VARIABLE:
00208 #ifdef DEBUG
00209 std::cout<<"variable "<<__data.params.nn()<<"\n";
00210 #endif
00211 eval_data.r = (*eval_data.x)[__data.params.nn()];
00212 break;
00213 case EXPRINFO_SUM:
00214 case EXPRINFO_PROD:
00215 case EXPRINFO_MAX:
00216 case EXPRINFO_MIN:
00217 case EXPRINFO_INVERT:
00218 case EXPRINFO_DIV:
00219 eval_data.r = diffNumber(__data.params.nd(),eval_data.maxdeg);
00220 break;
00221 case EXPRINFO_MONOME:
00222 case EXPRINFO_IN:
00223 case EXPRINFO_AND:
00224 case EXPRINFO_NOGOOD:
00225 eval_data.r = diffNumber(1.,eval_data.maxdeg);
00226 break;
00227 case EXPRINFO_ALLDIFF:
00228
00229
00230
00231 case EXPRINFO_MEAN:
00232 case EXPRINFO_IF:
00233 case EXPRINFO_OR:
00234 case EXPRINFO_NOT:
00235 case EXPRINFO_COUNT:
00236 case EXPRINFO_SCPROD:
00237 case EXPRINFO_NORM:
00238 case EXPRINFO_LEVEL:
00239 eval_data.r = diffNumber(0.,eval_data.maxdeg);
00240 break;
00241 case EXPRINFO_DET:
00242 case EXPRINFO_PSD:
00243
00244 break;
00245 case EXPRINFO_COND:
00246 case EXPRINFO_FEM:
00247 case EXPRINFO_MPROD:
00248
00249 break;
00250 }
00251 return 1;
00252 }
00253 }
00254
00255 void calculate(const expression_node& __data)
00256 {
00257 if(__data.operator_type > 0){
00258 #ifdef DEBUG
00259 std::cout<<"diffNumber_evaluator.calculate\n";
00260 #endif
00261 }
00262 }
00263
00264 void retrieve_from_cache(const expression_node& __data)
00265 {
00266 #ifdef DEBUG
00267 std::cout<<"diffNumber_evaluator.retrieve_from_cache\n";
00268 #endif
00269 }
00270
00271 int update(const diffNumber& __rval)
00272 {
00273 #ifdef DEBUG
00274 std::cout<<"diffNumber_evaluator.update\n";
00275 #endif
00276 eval_data.r = __rval;
00277 return 0;
00278 }
00279
00280 int update(const expression_node& __data, const diffNumber& __rval)
00281 {
00282 #ifdef DEBUG
00283 std::cout<<"diffNumber_evaluator.update2 ";
00284 std::cout<<__rval<<"\n";
00285 #endif
00286 diffNumber __x(0.,eval_data.maxdeg);
00287 int ret = 0;
00288 if(__data.operator_type < 0)
00289 {
00290 switch(__data.operator_type)
00291 {
00292 case EXPRINFO_CONSTANT:
00293 case EXPRINFO_VARIABLE:
00294 break;
00295 case EXPRINFO_SUM:
00296 case EXPRINFO_MEAN:
00297 #ifdef DEBUG
00298 std::cout<<"sum: "<<eval_data.r<<" -> ";
00299 #endif
00300 eval_data.r = eval_data.r + __data.coeffs[eval_data.n++]*__rval;
00301 #ifdef DEBUG
00302 std::cout<<eval_data.r<<"\n";
00303 #endif
00304 break;
00305 case EXPRINFO_PROD:
00306
00307 #ifdef DEBUG
00308 std::cout<<"prod";
00309 #endif
00310 eval_data.r = eval_data.r * __rval;
00311
00312 break;
00313 case EXPRINFO_MONOME:
00314 #ifdef DEBUG
00315 std::cout<<"monome";
00316 #endif
00317 eval_data.r = eval_data.r * __power(__data.coeffs[eval_data.n],
00318 __rval, __data.params.n()[eval_data.n]);
00319
00320 eval_data.n++;
00321 break;
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344
00345
00346
00347
00348
00349
00350
00351
00352
00353
00354
00355
00356
00357
00358 case EXPRINFO_INVERT:
00359 eval_data.r = eval_data.r / __rval;
00360 break;
00361 case EXPRINFO_DIV:
00362
00363
00364 if(eval_data.n == 0)
00365 {
00366 eval_data.r = eval_data.r * __rval;
00367
00368 }
00369 else
00370 eval_data.r = eval_data.r / __rval;
00371 ++eval_data.n;
00372 break;
00373 case EXPRINFO_SQUARE:
00374 #ifdef DEBUG
00375 std::cout<<"square :"<<__rval<<" -> ";
00376 #endif
00377 __x = __data.coeffs[0]*__rval+__data.params.nd();
00378 eval_data.r = __x*__x;
00379 #ifdef DEBUG
00380 std::cout<<eval_data.r<<"\n";
00381 std::cout<<"coeffs: "<<__data.coeffs[0]<<" params: "<<__data.params.nd()<<"\n";
00382 #endif
00383 break;
00384 case EXPRINFO_INTPOWER:
00385 #ifdef DEBUG
00386 std::cout<<"intpower \n";
00387 #endif
00388 eval_data.r = __power(__data.coeffs[0], __rval, __data.params.nn());
00389 break;
00390 case EXPRINFO_SQROOT:
00391 eval_data.r = sqrt(__data.coeffs[0]*__rval+__data.params.nd());
00392 break;
00393
00394
00395
00396
00397
00398
00399
00400
00401
00402
00403
00404 case EXPRINFO_EXP:
00405 eval_data.r = exp(__rval*__data.coeffs[0]+__data.params.nd());
00406 break;
00407
00408
00409
00410 case EXPRINFO_SIN:
00411 eval_data.r = sin(__rval*__data.coeffs[0]+__data.params.nd());
00412 break;
00413 case EXPRINFO_COS:
00414 eval_data.r = cos(__rval*__data.coeffs[0]+__data.params.nd());
00415 break;
00416 case EXPRINFO_ATAN2:
00417 __x = __rval * __data.coeffs[eval_data.n];
00418 if(eval_data.n == 0)
00419 eval_data.r = __x;
00420 else
00421 eval_data.r = atan2(eval_data.r, __x);
00422 ++eval_data.n;
00423 break;
00424
00425
00426
00427
00428
00429
00430
00431
00432
00433
00434
00435
00436
00437
00438
00439
00440
00441
00442
00443
00444
00445
00446
00447
00448
00449
00450
00451
00452
00453
00454
00455
00456
00457
00458
00459
00460
00461
00462
00463
00464
00465
00466
00467
00468
00469
00470
00471
00472
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482
00483
00484
00485
00486
00487
00488
00489
00490
00491
00492
00493
00494
00495
00496
00497
00498
00499
00500
00501
00502
00503
00504
00505
00506
00507
00508
00509
00510
00511
00512
00513
00514
00515
00516
00517
00518
00519
00520
00521
00522
00523
00524
00525
00526
00527
00528
00529
00530
00531
00532
00533
00534
00535
00536
00537
00538
00539
00540
00541
00542
00543
00544
00545
00546
00547
00548
00549
00550
00551
00552
00553
00554
00555
00556
00557
00558
00559
00560
00561
00562
00563
00564
00565
00566
00567
00568
00569
00570
00571
00572
00573
00574
00575
00576
00577
00578
00579
00580
00581
00582
00583
00584
00585
00586
00587
00588
00589
00590
00591
00592
00593
00594
00595
00596
00597
00598
00599
00600
00601
00602
00603
00604
00605
00606
00607
00608
00609
00610
00611
00612
00613
00614
00615
00616
00617
00618
00619
00620
00621
00622
00623
00624
00625
00626
00627
00628
00629
00630
00631
00632
00633 case EXPRINFO_EXPECTATION:
00634 throw nyi_exception("func_evaluator: EXPECTATION");
00635 break;
00636 case EXPRINFO_INTEGRAL:
00637 throw nyi_exception("func_evaluator: INTEGRAL");
00638 break;
00639 case EXPRINFO_LOOKUP:
00640 throw nyi_exception("func_evaluator: LOOKUP");
00641 break;
00642
00643
00644
00645
00646
00647
00648
00649
00650
00651
00652
00653
00654
00655
00656
00657
00658
00659
00660
00661
00662 case EXPRINFO_SPLINE:
00663 throw nyi_exception("func_evaluator: SPLINE");
00664 break;
00665
00666
00667
00668
00669
00670
00671
00672
00673
00674
00675
00676
00677
00678
00679
00680
00681
00682
00683
00684
00685
00686
00687
00688
00689
00690
00691
00692
00693 case EXPRINFO_DET:
00694 case EXPRINFO_COND:
00695 case EXPRINFO_PSD:
00696 case EXPRINFO_MPROD:
00697 case EXPRINFO_FEM:
00698 throw nyi_exception("func_evaluator: Matrix Operations");
00699 break;
00700 case EXPRINFO_RE:
00701 case EXPRINFO_IM:
00702 case EXPRINFO_ARG:
00703 case EXPRINFO_CPLXCONJ:
00704 throw nyi_exception("func_evaluator: Complex Operations");
00705 break;
00706 case EXPRINFO_CMPROD:
00707 case EXPRINFO_CGFEM:
00708 throw nyi_exception("func_evaluator: Const Matrix Operations");
00709 default:
00710 throw api_exception(apiee_evaluator,
00711 std::string("func_evaluator: unknown function type ")+
00712 convert_to_str(__data.operator_type));
00713 break;
00714 }
00715 }
00716 else if(__data.operator_type > 0)
00717
00718
00719
00720
00721
00722
00723 if(eval_data.cache && __data.n_parents > 1 && __data.n_children > 0)
00724 (*eval_data.cache)[__data.node_num] = eval_data.r;
00725 return ret;
00726 }
00727
00728 diffNumber calculate_value(bool eval_all)
00729 {
00730 #ifdef DEBUG
00731 std::cout<<"diffNumber_evaluator.calculate_value\n";
00732 #endif
00733 return eval_data.r;
00734 }
00736 };
00737
00738 }
00739
00740 #endif