Crombie Tools
TmvaClassifier.cc
Go to the documentation of this file.
1 #include <iostream>
2 #include <fstream>
3 #include <stdio.h>
4 
5 #include "TFile.h"
6 #include "TGraph.h"
7 #include "TTreeFormula.h"
8 #include "TH1D.h"
9 
10 #include "PlotHists.h"
11 #include "TmvaClassifier.h"
12 #include "TreeContainer.h"
13 
14 #include "TMVA/Factory.h"
15 #include "TMVA/Reader.h"
16 
17 //--------------------------------------------------------------------
// NOTE(review): the line that opens this definition (listing line 18) is
// missing from this extract; what follows is its member-initialiser list
// and body. The in-body comment labels it as the constructor.
//
// Defaults established here: empty signal/background cuts, weight "1",
// job name "TmvaClassifier", method name "BDT", empty BDT option string,
// output branch name "bdt_test", training output file "TMVA.root",
// no config file, no uniform variable, empty application directory,
// application tree "Events", application output directory "WithBDT",
// and a report frequency of 0.
19 fSignalCut(""),
20  fBackgroundCut(""),
21  fWeight("1"),
22  fJobName("TmvaClassifier"),
23  fMethodName("BDT"),
24  fBDTDef(""),
25  fBDTName("bdt_test"),
26  fOutputName("TMVA.root"),
27  fConfigFile(""),
28  fUniformVariable(""),
29  fApplicationDirectory(""),
30  fApplicationTree("Events"),
31  fApplicationOutput("WithBDT"),
32  fReportFrequency(0)
33 {
34  // Constructor
    // Start with empty file and tree name lists for signal and background.
35  fSignalFileNames.resize(0);
36  fSignalTreeNames.resize(0);
37  fBackgroundFileNames.resize(0);
38  fBackgroundTreeNames.resize(0);
39 
    // Start with no training variables, variable types or spectators.
40  fVariables.resize(0);
41  fVarTypes.resize(0);
42  fSpectatorVariables.resize(0);
43 }
44 
45 //--------------------------------------------------------------------
// NOTE(review): the signature line for this definition (listing line 46) is
// missing from this extract. The body is empty, so no work is done here.
47 {
48 }
49 
50 //--------------------------------------------------------------------
// Loads the classifier configuration from a whitespace-separated text file.
//
// NOTE(review): the line carrying the function name and parameter list
// (listing line 52) is missing from this extract; the body reads a parameter
// called `name`, which is stored in fConfigFile and used as the file path.
//
// The first two tokens of the file are read into fBDTName and
// fUniformVariable. Every following non-empty token is handed to
// AddVariable().
51 void
53 {
54  fConfigFile = name;
55 
    // Not used anywhere below.
56  TString BDTName;
57  std::ifstream configFile;
    // NOTE(review): the stream state is not checked after open(); confirm how
    // the eof()-driven loop below behaves if the file cannot be opened.
58  configFile.open(fConfigFile.Data());
59  TString tempFormula;
60 
    // Not used anywhere below.
61  std::vector<TString> Strings;
62 
63  configFile >> fBDTName >> fUniformVariable;
64 
65  while(!configFile.eof()){
66  configFile >> tempFormula;
    // Presumably guards against an empty token from the final read - confirm.
67  if(tempFormula != ""){
68  AddVariable(tempFormula);
69  }
70  }
71 
72 }
73 
74 //--------------------------------------------------------------------
// Trains, tests and evaluates a BDT with a TMVA::Factory, writing the TMVA
// output to the file named by fOutputName.
//
// NOTE(review): the line carrying the function name and parameter list
// (listing line 76) is missing from this extract.
//
// Steps: register variables and spectators, add one signal tree per entry of
// fSignalFileNames and one background tree per entry of fBackgroundFileNames,
// set the weight expression, book the method and run train/test/evaluate.
75 void
77 {
78  TFile *TMVAOutput = new TFile(fOutputName, "RECREATE");
79  TMVA::Factory *factory = new TMVA::Factory(fJobName, TMVAOutput,
80  "V:!Silent:Color:DrawProgressBar:Transformations=I;N");
81 
82  for (UInt_t iVar = 0; iVar != fVariables.size(); ++iVar)
83  factory->AddVariable(fVariables[iVar]);
84 
85  for (UInt_t iSpec = 0; iSpec != fSpectatorVariables.size(); ++iSpec)
86  factory->AddSpectator(fSpectatorVariables[iSpec]);
87 
    // The uniform variable, when set, is registered as one more spectator
    // after the entries of fSpectatorVariables.
88  if (fUniformVariable != "")
89  factory->AddSpectator(fUniformVariable);
90 
91  std::vector<TreeContainer*> SignalTrees;
92  TreeContainer *tempTree;
    // fSignalTreeNames is indexed with the file index, so it is assumed to
    // hold one tree name per signal file - TODO confirm against the setters.
93  for (UInt_t iTree = 0; iTree != fSignalFileNames.size(); ++iTree) {
94  tempTree = new TreeContainer(fSignalFileNames[iTree]);
95  if (fSignalCut != "")
96  tempTree->SetSkimmingCut(fSignalCut);
97 
    // Each signal tree is added with a global weight of 1.0.
98  factory->AddSignalTree(tempTree->ReturnTree(fSignalTreeNames[iTree]), 1.0);
99  SignalTrees.push_back(tempTree); // Save TreeContainer for deleting at the end
100  }
101 
102  std::vector<TreeContainer*> BackgroundTrees;
    // Same pattern as for signal, using the background file, tree and cut members.
103  for (UInt_t iTree = 0; iTree != fBackgroundFileNames.size(); ++iTree) {
104  tempTree = new TreeContainer(fBackgroundFileNames[iTree]);
105  if (fBackgroundCut != "")
106  tempTree->SetSkimmingCut(fBackgroundCut);
107 
108  factory->AddBackgroundTree(tempTree->ReturnTree(fBackgroundTreeNames[iTree]), 1.0);
109  BackgroundTrees.push_back(tempTree); // Save TreeContainer for deleting at the end
110  }
111 
112  factory->SetWeightExpression(fWeight.GetTitle());
113  factory->PrepareTrainingAndTestTree("1","SplitMode=Alternate:NormMode=NumEvents:V");
114 
    // A single method of type kBDT is booked, named fMethodName, with fBDTDef
    // passed as its option string.
115  factory->BookMethod(TMVA::Types::kBDT,fMethodName,fBDTDef);
116  factory->TrainAllMethods();
117  factory->TestAllMethods();
118  factory->EvaluateAllMethods();
    // NOTE(review): TMVAOutput is closed here but never deleted in this body.
119  TMVAOutput->Close();
120 
121  delete factory;
122 
    // Release the containers that were kept alive while the factory used their trees.
123  for (UInt_t iTree = 0; iTree != SignalTrees.size(); ++iTree)
124  delete SignalTrees[iTree];
125 
126  for (UInt_t iTree = 0; iTree != BackgroundTrees.size(); ++iTree)
127  delete BackgroundTrees[iTree];
128 
129  SignalTrees.resize(0);
130  BackgroundTrees.resize(0);
131 }
132 
133 //--------------------------------------------------------------------
// Convenience entry point that forwards to Apply(1, 0., 0.), i.e. a single
// bin whose lower and upper edge are both 0.
//
// NOTE(review): the line carrying the function name and parameter list
// (listing line 135) is missing from this extract. Presumably meant for the
// case where no uniform variable is configured - confirm against the header.
134 void
136 {
137  Apply(1, 0., 0.);
138 }
139 
140 //--------------------------------------------------------------------
141 void
142 TmvaClassifier::Apply(Int_t NumBins, Double_t VarMin, Double_t VarMax, Int_t NumMapPoints)
143 {
144  Double_t binWidth = (VarMax - VarMin)/NumBins;
145  Double_t VarVals[NumBins+1];
146  for (Int_t i0 = 0; i0 != NumBins + 1; ++i0)
147  VarVals[i0] = VarMin + i0 * binWidth;
148 
149  Apply(NumBins, VarVals, NumMapPoints);
150 }
151 
152 //--------------------------------------------------------------------
153 void
154 TmvaClassifier::Apply(Int_t NumBins, Double_t *VarVals, Int_t NumMapPoints)
155 {
156  // Now, look into setting up uniform distribution
157 
158  std::vector<TGraph*> transformGraphs;
159  TGraph *tempGraph;
160 
161  if (fUniformVariable != "") {
162  // First scale the BDT to be uniform
163  // Make the background shape
164 
165  TreeContainer *TrainingTreeContainer = new TreeContainer(fOutputName);
166  TrainingTreeContainer->SetSkimmingCut("classID == 1");
167 
168  TTree *BackgroundTree = TrainingTreeContainer->ReturnTree("TrainTree");
169 
170  PlotHists *HistPlotter = new PlotHists();
171  HistPlotter->SetDefaultTree(BackgroundTree);
172  HistPlotter->SetDefaultExpr(fMethodName);
173 
174  Double_t binWidth = 2.0/(NumMapPoints - 1);
175  Double_t BDTBins[NumMapPoints];
176  for (Int_t i0 = 0; i0 != NumMapPoints; ++i0)
177  BDTBins[i0] = i0 * binWidth - 1;
178 
179  for (Int_t iVarBin = 0; iVarBin != NumBins; ++iVarBin) {
180  char buffer [1023];
181  sprintf(buffer, "(%s>=%f&&%s<%f)",
182  fUniformVariable.Data(), VarVals[iVarBin],
183  fUniformVariable.Data(), VarVals[iVarBin+1]);
184  TCut BinCut = TCut(buffer);
185  HistPlotter->AddWeight(fWeight * "classID == 1" + BinCut);
186  }
187 
188  std::cout << "Making hists." << std::endl;
189 
190  std::vector<TH1D*> BDTHists = HistPlotter->MakeHists(NumMapPoints,-1,1);
191 
192  std::cout << "Finished hists." << std::endl;
193 
194 // delete BackgroundTree;
195  delete TrainingTreeContainer;
196 
197  std::cout << "Deleted containers." << std::endl;
198 
199  for (Int_t iVarBin = 0; iVarBin != NumBins; ++iVarBin) {
200  tempGraph = new TGraph(NumMapPoints);
201  transformGraphs.push_back(tempGraph);
202  Double_t FullIntegral = BDTHists[iVarBin]->Integral();
203  for (Int_t iMapPoint = 0; iMapPoint != NumMapPoints; ++iMapPoint) {
204  transformGraphs[iVarBin]->SetPoint(iMapPoint, BDTBins[iMapPoint],
205  BDTHists[iVarBin]->Integral(0,iMapPoint)/FullIntegral);
206  }
207  }
208 
209  std::cout << "Got map points." << std::endl;
210 
211  for (UInt_t iHist = 0; iHist != BDTHists.size(); ++iHist)
212  delete BDTHists[iHist];
213  }
214 
215  std::cout << "About to apply." << std::endl;
216 
217  // Then apply the BDT
218  TMVA::Reader* reader = new TMVA::Reader("Color:!Silent");
219 
220  std::vector<Int_t> discreteVars;
221  std::vector<Float_t> continuousVars;
222  std::vector<TTreeFormula*> discreteFormulae;
223  std::vector<TTreeFormula*> continuousFormulae;
224 
225  Int_t numDiscrete = 0;
226  Int_t numContinuous = 0;
227 
228  TTreeFormula* tempFormula;
229 
230  TreeContainer *ApplicationTreesContainer = new TreeContainer();
231  ApplicationTreesContainer->AddDirectory(fApplicationDirectory);
232  std::vector<TTree*> ApplicationTrees = ApplicationTreesContainer->ReturnTreeList(fApplicationTree);
233  std::vector<TTree*> CopiedTrees;
234  std::vector<TString> FileList = ApplicationTreesContainer->ReturnFileNames();
235 
236  Float_t UniformVar = 0.;
237  TTreeFormula *UniformFormula = new TTreeFormula(fUniformVariable, fUniformVariable, ApplicationTrees[0]);
238 
239  if (fUniformVariable != "")
240  reader->AddSpectator(fUniformVariable, &UniformVar);
241 
242  for (UInt_t iVar = 0; iVar != fVariables.size(); ++iVar) {
243  if (fVarTypes[iVar] == 'I') {
244  discreteVars.push_back(0);
245  tempFormula = new TTreeFormula(fVariables[iVar],fVariables[iVar],ApplicationTrees[0]);
246  discreteFormulae.push_back(tempFormula);
247  reader->AddVariable(fVariables[iVar], &discreteVars[numDiscrete]);
248  ++numDiscrete;
249  }
250  else if (fVarTypes[iVar] == 'F') {
251  continuousVars.push_back(0);
252  tempFormula = new TTreeFormula(fVariables[iVar],fVariables[iVar],ApplicationTrees[0]);
253  continuousFormulae.push_back(tempFormula);
254  reader->AddVariable(fVariables[iVar], &continuousVars[numContinuous]);
255  ++numContinuous;
256  }
257  }
258 
259  for (UInt_t iVar = 0; iVar != fSpectatorVariables.size(); ++iVar) {
260  continuousVars.push_back(0);
261  tempFormula = new TTreeFormula(fSpectatorVariables[iVar],fSpectatorVariables[iVar],ApplicationTrees[0]);
262  continuousFormulae.push_back(tempFormula);
263  reader->AddSpectator(fSpectatorVariables[iVar], &continuousVars[numContinuous]);
264  ++numContinuous;
265  }
266 
267  reader->BookMVA(fMethodName,TString("weights/") + fJobName + TString("_") + fMethodName +".weights.xml");
268 
269  for (UInt_t iTree = 0; iTree != ApplicationTrees.size(); ++iTree) {
270  TFile *newFile = new TFile(fApplicationOutput+"/"+FileList[iTree], "RECREATE");
271  TTree *tempTree = ApplicationTrees[iTree]->CloneTree();
272 
273  Float_t BDTOutput = 0.;
274  tempTree->Branch(fBDTName,&BDTOutput,fBDTName+"/F");
275 
276  for (UInt_t iForm = 0; iForm != discreteFormulae.size(); ++iForm)
277  discreteFormulae[iForm]->SetTree(ApplicationTrees[iTree]);
278 
279  Int_t NEntries = ApplicationTrees[iTree]->GetEntriesFast();
280  for (Int_t iEntry = 0; iEntry != NEntries; ++iEntry) {
281  if (fReportFrequency > 0 && iEntry % fReportFrequency == 0)
282  std::cout << "Processing event... " << iEntry << ": " << float(iEntry)/NEntries * 100 << "%" << std::endl;
283 
284  ApplicationTrees[iTree]->GetEntry(iEntry);
285 
286  for (UInt_t iForm = 0; iForm != discreteFormulae.size(); ++iForm)
287  discreteVars[iForm] = discreteFormulae[iForm]->EvalInstance();
288 
289  for (UInt_t iForm = 0; iForm != continuousFormulae.size(); ++iForm)
290  continuousVars[iForm] = continuousFormulae[iForm]->EvalInstance();
291 
292  BDTOutput = reader->EvaluateMVA(fMethodName);
293 
294  if (fUniformVariable != "") {
295  UniformVar = UniformFormula->EvalInstance();
296 
297  if (UniformVar >= VarVals[0] && UniformVar < VarVals[NumBins]) {
298  for (Int_t iBin = 0; iBin != NumBins; ++iBin) {
299  if (UniformVar < VarVals[iBin + 1]) {
300  BDTOutput = transformGraphs[iBin]->Eval(BDTOutput);
301  break;
302  }
303  }
304  }
305  }
306  tempTree->Fill();
307  }
308  tempTree->Write(0, TObject::kOverwrite);
309  newFile->Close();
310  }
311 
312  if (fUniformVariable != "")
313  delete UniformFormula;
314 
315  for (UInt_t iForm = 0; iForm != discreteFormulae.size(); ++iForm)
316  delete discreteFormulae[iForm];
317 
318  for (UInt_t iForm = 0; iForm != continuousFormulae.size(); ++iForm)
319  delete continuousFormulae[iForm];
320 
321  delete ApplicationTreesContainer;
322 
323  delete reader;
324 
325  for (Int_t iGraph = 0; iGraph != NumBins; ++iGraph)
326  delete transformGraphs[iGraph];
327 }