Fleet  0.0.9
Inference in the LOT
DiscreteDistribution.h
Go to the documentation of this file.
1 #pragma once
2 
#include <assert.h>

#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
9 
10 #include "Miscellaneous.h"
11 #include "Numerics.h"
12 #include "IO.h"
13 #include "Strings.h"
14 
15 
16 
24 template<typename T>
26 
27 public:
28 
29  // m here will store the map from values to log probabilities. We need to ensure, though
30  // that when T is either float or double, it correctly sorts NaN (which std::less does not defaulty do)
31  // so, here we have a conditional to handle that, what a nightmare
32  using my_map_t = typename std::conditional< std::is_same<T,double>::value or std::is_same<T,float>::value,
33  std::map<T,double,floating_point_compare<double>>,
34  std::map<T,double>>::type;
36 
38 
39  virtual T argmax() const {
40  T best{}; // Note defaults to this when there are none!
41  double bestv = -infinity;
42  for(auto a : m) {
43  if(a.second > bestv) {
44  bestv = a.second;
45  best = a.first;
46  }
47  }
48  return best;
49  }
50 
51 
52  void show(std::ostream& out, unsigned long nprint=0) const { out << this->string(nprint); }
53  void show(unsigned long nprint=0) const { show(std::cout, nprint); }
54 
55  void erase(const T& k) {
56  m.erase(k);
57  }
58 
59 
60  std::string string(unsigned long nprint=0) const {
67  // put the strings into a vector
68  std::vector<T> v;
69  double z = Z();
70 
71  for(const auto& a : m) {
72  v.push_back(a.first);
73  }
74 
75  // sort them to be increasing
76  std::sort(v.begin(), v.end(), [this](T a, T b) { return this->m.at(a) > this->m.at(b); });
77 
78 
79  std::string out = "{";
80  auto upper_bound = (nprint == 0 ? v.size() : std::min(v.size(), nprint));
81  for(size_t i=0;i<upper_bound;i++){
82  out += "'" + str(v[i]) + "':" + std::to_string(m.at(v[i]));
83  if(i < upper_bound-1) { out += ", "; }
84  }
85  out += "} [Z=" + str(z) + ", N=" + std::to_string(v.size()) + "]";
86 
87  return out;
88  }
89 
90 
91  void addmass(T x, double v) {
98  // We can't store NaN in this container ugh
99  if constexpr (std::is_same<T, double>::value or
100  std::is_same<T, float>::value){
101  assert((not std::isnan(x)) && "*** Cannot store NaNs here, sorry. They don't work with maps and you'll be in for an unholy nightmare");
102  }
103 
104  if(m.find(x) == m.end()) {
105  m[x] = v;
106  }
107  else {
108  m[x] = logplusexp(m[x], v);
109  }
110  }
111 
112  double get(T x, double v) const {
113  if(m.find(x) == m.end()) {
114  return v;
115  }
116  else {
117  return m.at(x);
118  }
119  }
120 
121  bool contains(const T& x) const {
122  return m.contains(x);
123  }
124 
125  auto begin() { return m.begin(); }
126  auto end() { return m.end(); }
127 
128  const std::map<T,double>& values() const {
133  return m;
134  }
135 
136  void operator<<(const DiscreteDistribution<T>& x) {
137  // adds all of x to me
138  for(auto a : x.values()) {
139  addmass(a.first, a.second);
140  }
141  }
142 
143  double Z() const {
144  double Z = -infinity; // add up the mass
145  for(const auto& a : m){
146  Z = logplusexp(Z, a.second);
147  }
148  return Z;
149  }
150 
151  double lp(const T& x) {
157  if(m.count(x)) {
158  return m[x]-Z();
159  }
160  else {
161  return -infinity;
162  }
163  }
164 
165 
166  std::vector<T> best(size_t n, bool include_equal) const {
174  std::vector<std::pair<T,double>> v(m.size());
175  std::copy(m.begin(), m.end(), v.begin());
176  std::sort(v.begin(), v.end(), [](auto x, auto y){ return x.second > y.second; }); // put the big stuff first
177 
178  std::vector<T> out;
179  auto until = n-1;
180  for(size_t i=0; (i<v.size()) and ((i<=until) or (include_equal and v[i].second == v[until].second));i++){
181  out.push_back(v[i].first);
182  }
183 
184  return out;
185  }
186 
187  std::vector<std::pair<T,double>> sorted(bool decreasing=false) const {
192  std::vector<std::pair<T,double>> v(m.size());
193  std::copy(m.begin(), m.end(), v.begin());
194  if(decreasing) std::sort(v.begin(), v.end(), [](auto x, auto y){ return x.second > y.second; }); // put the big stuff first
195  else std::sort(v.begin(), v.end(), [](auto x, auto y){ return x.second < y.second; }); // put the big stuff first
196  return v;
197  }
198 
199 
200  // inherit some interfaces
201  size_t count(T x) const { return m.count(x); }
202  size_t size() const { return m.size(); }
203  double operator[](const T x) {
204  if(m.count(x)) return m[x];
205  else return -infinity;
206  }
207 // double& operator[](const T x) {
208 // return m[x];
209 // }
210  double at(T x) const {
211  if(m.count(x)) return m.at(x);
212  else return -infinity;
213  }
214 
215 
222  auto operator<=>(const DiscreteDistribution<T>& x) const {
223  return m <=> x.m;
224  }
225 
226  bool operator==(const DiscreteDistribution<T>& other) const {
227  // we can compare elements -- NOTE we set a tolerance so they don't have to be exactly
228  // equal (I hope this doesn't cause you troubles)
229  const double threshold = 1e-6;
230 
231  std::set<T> keys;
232  std::transform(m.begin(), m.end(), std::inserter(keys, keys.end()), [](auto pair){ return pair.first; });
233  std::transform(other.m.begin(), other.m.end(), std::inserter(keys, keys.end()), [](auto pair){ return pair.first; });
234 
235  for(auto& k : keys) {
236  if(abs(at(k) - other.at(k)) > threshold)
237  return false;
238  }
239 
240  return true;
241  }
242 
243 };
244 
245 
246 template<typename T>
247 std::ostream& operator<<(std::ostream& o, const DiscreteDistribution<T>& x) {
248  o << x.string();
249  return o;
250 }
251 
252 template<typename T>
253 std::string str(const DiscreteDistribution<T>& a ){
254  return a.string();
255 }
my_map_t m
Definition: DiscreteDistribution.h:35
std::string string(unsigned long nprint=0) const
Definition: DiscreteDistribution.h:60
auto end()
Definition: DiscreteDistribution.h:126
double at(T x) const
Definition: DiscreteDistribution.h:210
virtual T argmax() const
Definition: DiscreteDistribution.h:39
Definition: DiscreteDistribution.h:25
DiscreteDistribution()
Definition: DiscreteDistribution.h:37
double operator[](const T x)
Definition: DiscreteDistribution.h:203
std::vector< std::pair< T, double > > sorted(bool decreasing=false) const
Definition: DiscreteDistribution.h:187
constexpr double infinity
Definition: Numerics.h:20
bool contains(const T &x) const
Definition: DiscreteDistribution.h:121
void show(unsigned long nprint=0) const
Definition: DiscreteDistribution.h:53
T logplusexp(const T a, const T b)
Definition: Numerics.h:131
void addmass(T x, double v)
Definition: DiscreteDistribution.h:91
size_t count(T x) const
Definition: DiscreteDistribution.h:201
double lp(const T &x)
Definition: DiscreteDistribution.h:151
bool operator==(const DiscreteDistribution< T > &other) const
Definition: DiscreteDistribution.h:226
void show(std::ostream &out, unsigned long nprint=0) const
Definition: DiscreteDistribution.h:52
const std::map< T, double > & values() const
Definition: DiscreteDistribution.h:128
size_t size() const
Definition: DiscreteDistribution.h:202
std::string str(const DiscreteDistribution< T > &a)
Definition: DiscreteDistribution.h:253
auto begin()
Definition: DiscreteDistribution.h:125
std::vector< T > best(size_t n, bool include_equal) const
Definition: DiscreteDistribution.h:166
void erase(const T &k)
Definition: DiscreteDistribution.h:55
double Z() const
Definition: DiscreteDistribution.h:143
typename std::conditional< std::is_same< T, double >::value or std::is_same< T, float >::value, std::map< T, double, floating_point_compare< double > >, std::map< T, double > >::type my_map_t
Definition: DiscreteDistribution.h:34