Fleet  0.0.9
Inference in the LOT
TNormalVariable.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include "EigenLib.h"
4 #include "Interfaces/MCMCable.h"
5 
// TNormalVariable: a single float that can be run through MCMC. It stores an
// untransformed `value` plus `fvalue = f(value)`, where f is the function-pointer
// template argument. The prior is computed on the untransformed value.
// NOTE(review): this listing omits some original lines (19, 26 and 45 among them), so
// `self_t` and `can_propose` are used below without a visible declaration. The member
// index at the end of this page places `bool can_propose` at TNormalVariable.h:26;
// `self_t` is presumably declared on an omitted line or in the base -- confirm.
16 template<float (*f)(float) > // transform set at compile time
17 class TNormalVariable : public MCMCable<TNormalVariable<f>, void*> {
18 public:
    // Re-export the data type of the MCMCable base.
20  using data_t = typename MCMCable<TNormalVariable<f>, void*>::data_t;
21 
    // Parameters of the normal prior on the untransformed value (see compute_prior).
22  float MEAN = 0.0;
23  float SD = 1.0;
    // Multiplier on the random_normal() step taken in propose().
24  float PROPOSAL_SCALE = 0.250;
25 
27 
    // Untransformed value and its transform f(value).
    // NOTE(review): neither has an initializer in the visible lines and the constructor
    // body is empty, yet set_untransformed() reads `value` before writing it -- confirm
    // whether an omitted line initializes them.
28  float value;
29  float fvalue; // transformed variable
30 
    // Construct with proposals enabled; the initial set(0.0) call is commented out.
31  TNormalVariable() : can_propose(true) {
32  //set(0.0); // just so it doesn't start insane
33  }
34 
35  // Set the value (via float or double)
36  // NOTE: This sets the PRE-transformed value; fvalue is recomputed as f(v).
    // Nothing is written when v compares equal to the current value.
37  template<typename T>
38  void set_untransformed(const T v) {
39  if(v != value) {
40  value = v;
41  fvalue = f(v);
42  }
43  }
44 
    // NOTE(review): line 45 is missing from this listing; the member index places
    // `float get_untransformed()` at TNormalVariable.h:45, so the two lines below are
    // presumably its body (returning the untransformed value) -- confirm.
46  return value;
47  }
48 
    // Return the transformed value f(value) as cached in fvalue.
52  float get() const {
53  return fvalue;
54  }
55 
    // Compute the log prior of the untransformed value, store it in this->prior and return it.
56  virtual double compute_prior() override {
57  // With the default MEAN=0 and SD=1 this is a unit normal
58  return this->prior = normal_lpdf(value, MEAN, SD);
59  }
60 
61 
    // Always returns 0.0; both `data` and `breakout` are ignored.
62  virtual double compute_likelihood(const data_t& data, const double breakout=-infinity) override {
63  return 0.0; // this is here so we can run MCMC with no data
64 // throw YouShouldNotBeHereError("*** Should not call likelihood here");
65  }
66 
    // Propose a copy whose untransformed value is moved by PROPOSAL_SCALE*random_normal().
    // Returns an empty optional when can_propose is false; otherwise the pair (proposal, 0.0).
67  virtual std::optional<std::pair<self_t,double>> propose() const override {
68  if(not can_propose) return {};
69 
70  self_t out = *this;
71  out.set_untransformed(value + PROPOSAL_SCALE*random_normal());
72  return std::make_pair(out, 0.0); // everything is symmetrical so fb=0
73 
74  }
75 
    // Return a copy; when can_propose is true its untransformed value is redrawn as
    // MEAN + SD*random_normal(), otherwise the copy is returned unchanged.
76  virtual self_t restart() const override {
77  self_t out = *this;
78  if(can_propose) {
79  out.set_untransformed(MEAN + SD*random_normal());
80  }
81  else {
82  // should already have just copied it anyways
83  }
84  return out;
85  }
86 
    // Hash of the untransformed value; the float is converted to double for std::hash<double>.
87  virtual size_t hash() const override {
88  return std::hash<double>{}(value);
89  }
90 
    // Equality compares only the untransformed value (MEAN, SD, etc. are not compared).
91  virtual bool operator==(const self_t& h) const override {
92  return value == h.value;
93  }
94 
    // Render as prefix + "TN<" + str(value) + ">", using the untransformed value.
95  virtual std::string string(std::string prefix="") const override {
96  return prefix+"TN<"+str(value)+">";
97  }
98 
99 };
100 
101 
// TNormalVariable whose transform is normal_cdf<float>(x, 0.0, 1.0); the unary + converts
// the captureless lambda to the function pointer the template parameter requires.
// NOTE(review): presumably normal_cdf is the Gaussian CDF, which would make get() uniform
// on (0,1) under the default unit-normal prior -- normal_cdf is not shown here; confirm.
102 using UniformVariable = TNormalVariable< +[](float x)->float { return normal_cdf<float>(x, 0.0, 1.0); }>;
103 
// TNormalVariable whose transform is -log(normal_cdf<float>(-x, 0.0, 1.0)); the unary +
// converts the captureless lambda to the function pointer the template parameter requires.
// NOTE(review): presumably this is the negative log of a uniform draw, i.e. a rate-1
// exponential under the default unit-normal prior -- normal_cdf is not shown here; confirm.
104 using ExponentialVariable = TNormalVariable< +[](float x)->float { return -log(normal_cdf<float>(-x, 0.0, 1.0)); }>;
typename MCMCable< TNormalVariable< f >, void * >::data_t data_t
Definition: TNormalVariable.h:20
float PROPOSAL_SCALE
Definition: TNormalVariable.h:24
float get_untransformed()
Definition: TNormalVariable.h:45
virtual double compute_likelihood(const data_t &data, const double breakout=-infinity) override
Definition: TNormalVariable.h:62
Definition: MCMCable.h:14
double prior
Definition: Bayesable.h:42
virtual bool operator==(const self_t &h) const override
Definition: TNormalVariable.h:91
float MEAN
Definition: TNormalVariable.h:22
Definition: TNormalVariable.h:17
std::string str(BindingTree *t)
Definition: BindingTree.h:195
void set_untransformed(const T v)
Definition: TNormalVariable.h:38
double random_normal(double mu=0, double sd=1.0)
Definition: Random.h:39
constexpr double infinity
Definition: Numerics.h:20
virtual self_t restart() const override
Definition: TNormalVariable.h:76
virtual double compute_prior() override
Definition: TNormalVariable.h:56
virtual std::optional< std::pair< self_t, double > > propose() const override
Definition: TNormalVariable.h:67
float fvalue
Definition: TNormalVariable.h:29
virtual std::string string(std::string prefix="") const override
Definition: TNormalVariable.h:95
float value
Definition: TNormalVariable.h:28
std::vector< Args... > data_t
Definition: Bayesable.h:39
bool can_propose
Definition: TNormalVariable.h:26
float SD
Definition: TNormalVariable.h:23
TNormalVariable()
Definition: TNormalVariable.h:31
A class is MCMCable if it is Bayesable and lets us propose, restart, and check equality (which MCMC d...
T normal_lpdf(T x, T mu=0.0, T sd=1.0)
Definition: Random.h:45
virtual size_t hash() const override
Definition: TNormalVariable.h:87