Fleet  0.0.9
Inference in the LOT
MCMCChain.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <utility>
4 #include <functional>
5 #include "SampleStreams.h"
6 #include "MCMCChain.h"
7 #include "FiniteHistory.h"
8 #include "Control.h"
9 #include "OrderedLock.h"
10 #include "FleetStatistics.h"
11 
12 //#define DEBUG_MCMC 1
13 
22 template<typename _HYP>
23 class MCMCChain {
24 
25 public:
26  using HYP = _HYP;
27 
29 
30  // It's a little important that we use an OrderedLock, because otherwise we have
31  // no guarantees about chains accessing in a FIFO order. Non-FIFO is especially
32  // bad for ParallelTempering, where there are threads doing the adaptation etc.
34 
35  typename HYP::data_t* data;
36 
37  // this stores the maximum found since we've restarted
38  // (not the overall max)
39  double maxval;
40 
41  unsigned long samples; // total number of samples we've done
42  unsigned long proposals;
43  unsigned long acceptances;
44  unsigned long steps_since_improvement;
45 
46  std::atomic<double> temperature; // make atomic b/c ParallelTempering may try to change
47 
49 
50  MCMCChain(HYP& h0, typename HYP::data_t* d) :
51  current(h0), data(d), maxval(-infinity),
52  samples(0), proposals(0), acceptances(0), steps_since_improvement(0),
53  temperature(1.0), history(100) {
54  runOnCurrent();
55  }
56 
57  MCMCChain(HYP&& h0, typename HYP::data_t* d) :
58  current(h0), data(d), maxval(-infinity),
59  samples(0), proposals(0), acceptances(0), steps_since_improvement(0),
60  temperature(1.0), history(100) {
61  runOnCurrent();
62  }
63 
64  MCMCChain(const MCMCChain& m) :
65  current(m.current), data(m.data), maxval(m.maxval),
66  samples(m.samples), proposals(m.proposals), acceptances(m.acceptances),
67  steps_since_improvement(m.steps_since_improvement) {
68  temperature = m.temperature.load();
69  history = m.history;
70 
71  }
73  current = m.current;
74  data = m.data;
75  maxval = m.maxval;
76  samples = m.samples;
77  proposals = m.proposals;
78  acceptances = m.acceptances;
79  steps_since_improvement = m.steps_since_improvement;
80 
81  temperature = m.temperature.load();
82  history = std::move(m.history);
83  }
84 
85  virtual ~MCMCChain() { }
86 
92  void set_data(typename HYP::data_t* d, bool recompute_posterior=true) {
93  data = d;
94  if(recompute_posterior) {
95  current.compute_posterior(*data);
96  }
97  }
98 
104  return current;
105  }
106 
107  const HYP& getCurrent() const {
112  return current;
113  }
114 
115  void runOnCurrent() {
120  std::lock_guard guard(current_mutex);
121  current.compute_posterior(*data);
122  // NOTE: We do NOT count this as a "sample" since it is not yielded
123  }
124 
125 
126  const HYP& getMax() {
127  return maxval;
128  }
129 
130  void restart() {
131 
132  current = current.restart();
133  current.compute_posterior(*data);
134 
135  steps_since_improvement = 0; // reset the couter
136  maxval = current.posterior; // and the new max
137  }
138 
144  virtual bool check(HYP& p) {
145  return true;
146  }
147 
154 
155  assert(ctl.nthreads == 1 && "*** You seem to have called MCMCChain with nthreads>1. This is not how you parallel. Check out ChainPool");
156 
157  #ifdef DEBUG_MCMC
158  DEBUG("# Starting MCMC Chain on\t", current.posterior, current.prior, current.likelihood, current.string());
159  #endif
160 
161  // I may have copied its start time from somewhere else, so change that here
162  ctl.start();
163  while(true) {
164 
165  if(not ctl.running())
166  break;
167 
168  std::lock_guard guard(current_mutex);
169 
170  if(current.posterior > maxval) { // if we improve, store it
171  maxval = current.posterior;
172  steps_since_improvement = 0;
173  }
174  else { // else keep track of how long
176  }
177 
178  // if we haven't improved
179  if(ctl.restart>0 and steps_since_improvement > ctl.restart){
180  [[unlikely]];
181  restart();
182  //print("RESTARTING (from no improvement)", current.string());
183  }
184  else if (std::isnan(current.posterior) or std::isinf(current.posterior)) { // either inf is a restart
185  [[unlikely]];
186  //print("RESTARTING (from -inf)", current.string());
187  // This is a special case where we just propose from restarting
188  restart();
189 
190  // Should we count in history? // Hmm maybe not.
191  }
192  else {
193  // normally we go here and do a proper proposal
194 
195  #ifdef DEBUG_MCMC
196  DEBUG("# Current", current.posterior, current.prior, current.likelihood, current.string());
197  #endif
198 
199  // propose, but restart if we're -infinity
200  auto p = current.propose();
201  if(not p) { continue; }// proposal failed
202 
203  auto [proposal, fb] = p.value();
204 
205  ++proposals;
206 
207  // A lot of proposals end up with the same function, so if so, save time by not
208  // computing the posterior
209  if(proposal == current) {
210  // copy all the properties
211  // NOTE: This is necessary because == might just check value, but operator= will copy everything else
212  proposal = current;
213 
214  // we treat this as an accept
215  history << true;
216  ++acceptances;
217 
218  #ifdef DEBUG_MCMC
219  // they are equal but we just use current here
220  DEBUG("# Proposed(eq)", current.posterior, current.prior, current.likelihood, current.string(), "fb="+str(fb));
221  #endif
222 
224  co_yield current; // must be done with lock
225  }
226  }
227  else {
228 
229  // we add a subroutine "check" here that can reject proposals right away
230  // this is useful for enforcing some constraints on the proposals
231  // defaultly, check does nothing. NOTE: it is important to the shibbholeth sampler that
232  // this happens before we compute posteriors
233  if(not check(proposal)) {
234  history << false;
235  continue;
236  }
237 
238 
239  // here we actually need to compute, but we can do so at the breakout
240  // TODO: This is a little inefficient in that we compute log(uniform()) even
241  // when we are special (and it is thus unused)
242  const double u = log(uniform());
243 
244  // NOTE: The above is NOT right because the prior is not at temperature, so
245  // instead of multiplying by temperature we have to do something smarter to fix the fact that
246  // its only on the likelihood. Reverting now to breakout=-infinity but keeping the rest of code in place
247  // for when this is fixed
248 // const auto breakoutpair = std::make_pair(-infinity, 1.0);
249 
250  // ok we will accept if u < proposal.at_temperature(temperature) - current.at_temperature(temperature) - fb;
251  // or u + current.at_temperature(temperature) + fb < proposal.at_temperature(temperature)
252  // or (u + current.at_temperature(temperature) + fb - PRIOR)*temperature < LIKELIHOOD
253  // NOTE then that in compute_posterior and compute_likelihood, we must NOT take into
254  // account tempearture
255  // NOTE: This breakout is on *posteriors* but in compute_posterior it is converted to one
256  // on likelihoods for compute_likelihood
257  const auto breakoutpair = std::make_pair(u + current.at_temperature(temperature) + fb, (double)temperature);
258 
259  proposal.compute_posterior(*data, breakoutpair);
260 // proposal.compute_posterior(*data);
261 
262 // #ifdef DEBUG_MCMC
263 // DEBUG("# Proposed", proposal.posterior, proposal.prior, proposal.likelihood, proposal.string(), "fb="+str(fb));
264 // #endif
265  if(FleetArgs::print_proposals != 0) [[unlikely]] {
266  print("#Proposed", proposal.posterior, proposal.prior, proposal.likelihood, proposal.string(), "fb="+str(fb));
267  }
268 
269  const double ratio = proposal.at_temperature(temperature) - current.at_temperature(temperature) - fb;
270 
271  // this is just a little debugging/checking code to see that we are making the same decision as
272  // without breakout. It should be commented out unless we're check
273 // assert( u < ratio == proposal.at_temperature(temperature) - current.at_temperature(temperature) - fb
274 
275  if((not std::isnan(proposal.posterior)) and u < ratio) {
276  [[unlikely]];
277 
278  #ifdef DEBUG_MCMC
279  DEBUG("# ACCEPT");
280  #endif
281 
282  current = std::move(proposal);
283 
284  history << true;
285  ++acceptances;
286 
287  // we always yield accepts
288  co_yield current; // must be done with lock
289  }
290  else {
291  history << false;
292 
293  #ifdef DEBUG_MCMC
294  DEBUG("# REJECT");
295  #endif
296 
297  // only yield rejects when not MCMCYieldOnlyChanges
299  co_yield current;
300  }
301  }
302 
303 
304  }
305  }
306 
307  ++samples;
309 
310 
311 
312  }
313  }
314 
315  void run() {
319  run(Control(0,0));
320  }
321 
322  double acceptance_ratio() {
327  return history.mean();
328  }
329 
330  double at_temperature(double t){
336  return current.at_temperature(t);
337  }
338 
339 };
Definition: OrderedLock.h:16
unsigned long samples
Definition: MCMCChain.h:41
std::atomic< uintmax_t > global_sample_count(0)
virtual ~MCMCChain()
Definition: MCMCChain.h:85
unsigned long proposals
Definition: MCMCChain.h:42
double acceptance_ratio()
Definition: MCMCChain.h:322
void run()
Definition: MCMCChain.h:315
HYP::data_t * data
Definition: MCMCChain.h:35
unsigned long steps_since_improvement
Definition: MCMCChain.h:44
void restart()
Definition: MCMCChain.h:130
FiniteHistory< bool > history
Definition: MCMCChain.h:48
double uniform()
Definition: Random.h:11
const HYP & getCurrent() const
Definition: MCMCChain.h:107
bool print_proposals
Definition: FleetArgs.h:36
_HYP HYP
Definition: MCMCChain.h:26
unsigned long acceptances
Definition: MCMCChain.h:43
MCMCChain(HYP &&h0, typename HYP::data_t *d)
Definition: MCMCChain.h:57
void runOnCurrent()
Definition: MCMCChain.h:115
Definition: MCMCChain.h:23
HYP current
Definition: MCMCChain.h:28
std::string str(BindingTree *t)
Definition: BindingTree.h:195
MCMCChain(HYP &h0, typename HYP::data_t *d)
Definition: MCMCChain.h:50
void start()
Definition: Control.h:54
void print(FIRST f, ARGS... args)
Lock output_lock and print to std::cout.
Definition: IO.h:53
double maxval
Definition: MCMCChain.h:39
Definition: Control.h:23
A FiniteHistory stores the previous N examples of something of type T. This is used e...
virtual bool check(HYP &p)
This allows us to overwrite/enforce stuff about proposals in subclasses of MCMCChain.
Definition: MCMCChain.h:144
constexpr double infinity
Definition: Numerics.h:20
A FIFO mutex (from stackoverflow) https://stackoverflow.com/questions/14792016/creating-a-lock-that-p...
Definition: generator.hpp:21
HYP & getCurrent()
Definition: MCMCChain.h:99
double at_temperature(double t)
Definition: MCMCChain.h:330
size_t nthreads
Definition: Control.h:26
bool MCMCYieldOnlyChanges
Definition: FleetArgs.h:50
void DEBUG(FIRST f, ARGS... args)
Print to std::cout with debugging info.
Definition: IO.h:73
bool running()
Definition: Control.h:63
std::atomic< double > temperature
Definition: MCMCChain.h:46
OrderedLock current_mutex
Definition: MCMCChain.h:33
void set_data(typename HYP::data_t *d, bool recompute_posterior=true)
Set the data pointer, optionally recomputing the current hypothesis's posterior.
Definition: MCMCChain.h:92
const HYP & getMax()
Definition: MCMCChain.h:126
generator< HYP & > run(Control ctl)
Run MCMC according to the control parameters passed in. NOTE: ctl cannot be passed by reference...
Definition: MCMCChain.h:153
MCMCChain(const MCMCChain &m)
Definition: MCMCChain.h:64
MCMCChain(MCMCChain &&m)
Definition: MCMCChain.h:72
unsigned long restart
Definition: Control.h:27
double mean()
Compute the average.
Definition: FiniteHistory.h:94
This represents an MCMC chain on a hypothesis of type HYP. It uses HYP::propose and HYP::compute_poste...
This class has all the information for running MCMC or MCTS in a little package. By default it constru...