00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00028 #ifndef _DIFFI_EVALUATOR_H_
00029 #define _DIFFI_EVALUATOR_H_
00030
00031 #include <coconut_config.h>
00032 #include <evaluator.h>
00033 #include <expression.h>
00034 #include <model.h>
00035 #include <eval_main.h>
00036 #include <linalg.h>
00037 #include <math.h>
00038 #include <api_exception.h>
00039 #include <diffI.h>
00040
00041 using namespace vgtl;
00042
00043 namespace coco {
00044
00046 typedef diffI (*diffI_evaluator)(const std::vector<diffI>* __x,
00047 const variable_indicator& __v);
00048
00050
00055
00056
00057 struct diffI_eval_type
00058 {
00059 const std::vector<diffI>* x;
00060 std::vector<diffI>* cache;
00061 const model* mod;
00062 void *p;
00063 diffI r;
00064 unsigned int n;
00065 int maxdeg;
00066 };
00067
00069
00075 class diffI_eval :
00076 public cached_forward_evaluator_base<diffI_eval_type,expression_node,
00077 diffI,
00078 expression_const_walker>
00079 {
00080 private:
00081
00082 typedef cached_forward_evaluator_base<diffI_eval_type,expression_node,
00083 diffI, expression_const_walker> _Base;
00084
00085 protected:
00088 bool is_cached(const node_data_type& __data)
00089 { std::cout<<"diffI_evaluator.is_cached\n";
00090 if(__data.operator_type == EXPRINFO_LIN ||
00091 __data.operator_type == EXPRINFO_QUAD)
00092 return true;
00093 else
00094 return false;
00095 #if 0
00096 if(eval_data.cache && __data.n_parents > 1 && __data.n_children > 0)
00097 return true;
00098 else
00099 return false;
00100 #endif
00101 }
00102
00103 private:
00106 diffI __power(double __coeff, diffI __x, int __exp)
00107 {
00108 if(__exp == 0)
00109 return 1.;
00110 else
00111 {
00112 diffI k = __coeff*__x;
00113 switch(__exp)
00114 {
00115 case 1:
00116 return k;
00117 break;
00118 case 2:
00119 return k*k;
00120 break;
00121 case -1:
00122 return 1./k;
00123 break;
00124 case -2:
00125 return 1./(k*k);
00126 break;
00127 default:
00128
00129 break;
00130 }
00131 }
00132 return diffI(0,eval_data.maxdeg);
00133 }
00134
00135 public:
00143 diffI_eval(const std::vector<diffI>& __x, const variable_indicator& __v,
00144 const model& __m, std::vector<diffI>* __c, int __maxd) : _Base()
00145 {
00146 std::cout<<"diffI_eval\n";
00147 eval_data.x = &__x;
00148 eval_data.cache = __c;
00149 eval_data.mod = &__m;
00150 v_ind = &__v;
00151 eval_data.n = 0;
00152 eval_data.maxdeg=__maxd;
00153 eval_data.r.setDeg(__maxd);
00154 eval_data.r=diffI(0.,__maxd);
00155 }
00156
00158 diffI_eval(const diffI_eval& __v) : _Base(__v) {}
00159
00161 ~diffI_eval() {}
00162
00164 expression_const_walker short_cut_to(const expression_node& __data)
00165 {
00166 std::cout<<"short_cut_to\n";
00167 return eval_data.mod->node(0); }
00168
00172 void new_interval(const std::vector<diffI>& __x, const variable_indicator& __v)
00173 {
00174 eval_data.x = &__x;
00175 v_ind = &__v;
00176 }
00177
00179
00180 void initialize() { return; }
00181
00182 int initialize(const expression_node& __data)
00183 {
00184 std::cout<<"initialize\n";
00185 eval_data.n = 0;
00186 if(__data.ev && (*__data.ev)[FUNC_RANGE])
00187
00188 {
00189 eval_data.r =
00190 (*(diffI_evaluator)(*__data.ev)[FUNC_RANGE])(eval_data.x,
00191 *v_ind);
00192 return 0;
00193 }
00194 else
00195 {
00196 switch(__data.operator_type)
00197 {
00198 case EXPRINFO_VARIABLE:
00199 std::cout<<"variable "<<__data.params.nn()<<"\n";
00200 eval_data.r = (*eval_data.x)[__data.params.nn()];
00201 break;
00202 case EXPRINFO_SUM:
00203 case EXPRINFO_PROD:
00204 case EXPRINFO_MAX:
00205 case EXPRINFO_MIN:
00206 case EXPRINFO_INVERT:
00207 case EXPRINFO_DIV:
00208 eval_data.r = diffI(__data.params.nd(),eval_data.maxdeg);
00209 break;
00210 case EXPRINFO_MONOME:
00211 case EXPRINFO_IN:
00212 case EXPRINFO_AND:
00213 case EXPRINFO_NOGOOD:
00214 eval_data.r = diffI(1.,eval_data.maxdeg);
00215 break;
00216 case EXPRINFO_ALLDIFF:
00217
00218
00219
00220 case EXPRINFO_MEAN:
00221 case EXPRINFO_IF:
00222 case EXPRINFO_OR:
00223 case EXPRINFO_NOT:
00224 case EXPRINFO_COUNT:
00225 case EXPRINFO_SCPROD:
00226 case EXPRINFO_NORM:
00227 case EXPRINFO_LEVEL:
00228 eval_data.r = diffI(0.,eval_data.maxdeg);
00229 break;
00230 case EXPRINFO_DET:
00231 case EXPRINFO_PSD:
00232
00233 break;
00234 case EXPRINFO_COND:
00235 case EXPRINFO_FEM:
00236 case EXPRINFO_MPROD:
00237
00238 break;
00239 }
00240 return 1;
00241 }
00242 }
00243
00244 void calculate(const expression_node& __data)
00245 {
00246 if(__data.operator_type > 0){
00247 std::cout<<"diffI_evaluator.calculate\n";
00248 }
00249 }
00250
00251 void retrieve_from_cache(const expression_node& __data)
00252 {
00253 std::cout<<"diffI_evaluator.retrieve_from_cache\n";
00254 }
00255
00256 int update(const diffI& __rval)
00257 {
00258 std::cout<<"diffI_evaluator.update\n";
00259 eval_data.r = __rval;
00260 return 0;
00261 }
00262
00263 int update(const expression_node& __data, const diffI& __rval)
00264 {
00265 std::cout<<"diffI_evaluator.update2 ";
00266 std::cout<<__rval<<"\n";
00267
00268 diffI __x(0.,eval_data.maxdeg);
00269 int ret = 0;
00270 if(__data.operator_type < 0)
00271 {
00272 switch(__data.operator_type)
00273 {
00274 case EXPRINFO_CONSTANT:
00275 case EXPRINFO_VARIABLE:
00276 break;
00277 case EXPRINFO_SUM:
00278 case EXPRINFO_MEAN:
00279 std::cout<<"sum: "<<eval_data.r<<" -> ";
00280 eval_data.r = eval_data.r + __data.coeffs[eval_data.n++]*__rval;
00281 std::cout<<eval_data.r<<"\n";
00282 break;
00283 case EXPRINFO_PROD:
00284
00285 std::cout<<"prod";
00286 eval_data.r = eval_data.r * __rval;
00287
00288 break;
00289 case EXPRINFO_MONOME:
00290 std::cout<<"monome";
00291 eval_data.r = eval_data.r * __power(__data.coeffs[eval_data.n],
00292 __rval, __data.params.n()[eval_data.n]);
00293
00294 eval_data.n++;
00295 break;
00296
00297
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332 case EXPRINFO_INVERT:
00333 eval_data.r = eval_data.r / __rval;
00334 break;
00335 case EXPRINFO_DIV:
00336
00337
00338 if(eval_data.n == 0)
00339 {
00340 eval_data.r = eval_data.r * __rval;
00341
00342 }
00343 else
00344 eval_data.r = eval_data.r / __rval;
00345 ++eval_data.n;
00346 break;
00347 case EXPRINFO_SQUARE:
00348 std::cout<<"square :"<<__rval<<" -> ";
00349 __x = __data.coeffs[0]*__rval+__data.params.nd();
00350 eval_data.r = __x*__x;
00351 std::cout<<eval_data.r<<"\n";
00352 std::cout<<"coeffs: "<<__data.coeffs[0]<<" params: "<<__data.params.nd()<<"\n";
00353 break;
00354 case EXPRINFO_INTPOWER:
00355 std::cout<<"intpower \n";
00356 eval_data.r = __power(__data.coeffs[0], __rval, __data.params.nn());
00357 break;
00358 case EXPRINFO_SQROOT:
00359 eval_data.r = sqrt(__data.coeffs[0]*__rval+__data.params.nd());
00360 break;
00361
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371
00372 case EXPRINFO_EXP:
00373 eval_data.r = exp(__rval*__data.coeffs[0]+__data.params.nd());
00374 break;
00375
00376
00377
00378 case EXPRINFO_SIN:
00379 eval_data.r = sin(__rval*__data.coeffs[0]+__data.params.nd());
00380 break;
00381 case EXPRINFO_COS:
00382 eval_data.r = cos(__rval*__data.coeffs[0]+__data.params.nd());
00383 break;
00384 case EXPRINFO_ATAN2:
00385 __x = __rval * __data.coeffs[eval_data.n];
00386 if(eval_data.n == 0)
00387 eval_data.r = __x;
00388 else
00389 eval_data.r = atan2(eval_data.r, __x);
00390 ++eval_data.n;
00391 break;
00392
00393
00394
00395
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409
00410
00411
00412
00413
00414
00415
00416
00417
00418
00419
00420
00421
00422
00423
00424
00425
00426
00427
00428
00429
00430
00431
00432
00433
00434
00435
00436
00437
00438
00439
00440
00441
00442
00443
00444
00445
00446
00447
00448
00449
00450
00451
00452
00453
00454
00455
00456
00457
00458
00459
00460
00461
00462
00463
00464
00465
00466
00467
00468
00469
00470
00471
00472
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482
00483
00484
00485
00486
00487
00488
00489
00490
00491
00492
00493
00494
00495
00496
00497
00498
00499
00500
00501
00502
00503
00504
00505
00506
00507
00508
00509
00510
00511
00512
00513
00514
00515
00516
00517
00518
00519
00520
00521
00522
00523
00524
00525
00526
00527
00528
00529
00530
00531
00532
00533
00534
00535
00536
00537
00538
00539
00540
00541
00542
00543
00544
00545
00546
00547
00548
00549
00550
00551
00552
00553
00554
00555
00556
00557
00558
00559
00560
00561
00562
00563
00564
00565
00566
00567
00568
00569
00570
00571
00572
00573
00574
00575
00576
00577
00578
00579
00580
00581
00582
00583
00584
00585
00586
00587
00588
00589
00590
00591
00592
00593
00594
00595
00596
00597
00598
00599
00600
00601
00602 case EXPRINFO_EXPECTATION:
00603 throw nyi_exception("func_evaluator: EXPECTATION");
00604 break;
00605 case EXPRINFO_INTEGRAL:
00606 throw nyi_exception("func_evaluator: INTEGRAL");
00607 break;
00608 case EXPRINFO_LOOKUP:
00609 throw nyi_exception("func_evaluator: LOOKUP");
00610 break;
00611
00612
00613
00614
00615
00616
00617
00618
00619
00620
00621
00622
00623
00624
00625
00626
00627
00628
00629
00630
00631 case EXPRINFO_SPLINE:
00632 throw nyi_exception("func_evaluator: SPLINE");
00633 break;
00634
00635
00636
00637
00638
00639
00640
00641
00642
00643
00644
00645
00646
00647
00648
00649
00650
00651
00652
00653
00654
00655
00656
00657
00658
00659
00660
00661
00662 case EXPRINFO_DET:
00663 case EXPRINFO_COND:
00664 case EXPRINFO_PSD:
00665 case EXPRINFO_MPROD:
00666 case EXPRINFO_FEM:
00667 throw nyi_exception("func_evaluator: Matrix Operations");
00668 break;
00669 case EXPRINFO_RE:
00670 case EXPRINFO_IM:
00671 case EXPRINFO_ARG:
00672 case EXPRINFO_CPLXCONJ:
00673 throw nyi_exception("func_evaluator: Complex Operations");
00674 break;
00675 case EXPRINFO_CMPROD:
00676 case EXPRINFO_CGFEM:
00677 throw nyi_exception("func_evaluator: Const Matrix Operations");
00678 default:
00679 throw api_exception(apiee_evaluator,
00680 std::string("func_evaluator: unknown function type ")+
00681 convert_to_str(__data.operator_type));
00682 break;
00683 }
00684 }
00685 else if(__data.operator_type > 0)
00686
00687
00688
00689
00690
00691
00692 if(eval_data.cache && __data.n_parents > 1 && __data.n_children > 0)
00693 (*eval_data.cache)[__data.node_num] = eval_data.r;
00694 return ret;
00695 }
00696
00697 diffI calculate_value(bool eval_all)
00698 {
00699 std::cout<<"diffI_evaluator.calculate_value\n";
00700 return eval_data.r;
00701 }
00703 };
00704
00705 }
00706
00707 #endif