muda
temp_buffer.h
1 #pragma once
2 #include <cuda_runtime.h>
3 #include <muda/check/check.h>
4 namespace muda::details
5 {
6 template <typename T>
7 class TempBuffer
8 {
9  public:
10  TempBuffer() {}
11 
12  TempBuffer(size_t size) { resize(size); }
13 
14  ~TempBuffer()
15  {
16  if(m_data)
17  {
18  // we don't check the error here to prevent exception when app is shutting down
19  cudaFree(m_data);
20  }
21  }
22 
23  TempBuffer(TempBuffer&& other) noexcept
24  {
25  m_size = other.m_size;
26  m_capacity = other.m_capacity;
27  m_data = other.m_data;
28  other.m_size = 0;
29  other.m_capacity = 0;
30  other.m_data = nullptr;
31  }
32 
33  TempBuffer& operator=(TempBuffer&& other) noexcept
34  {
35  if(this == &other)
36  {
37  return *this;
38  }
39  m_size = other.m_size;
40  m_capacity = other.m_capacity;
41  m_data = other.m_data;
42  other.m_size = 0;
43  other.m_capacity = 0;
44  other.m_data = nullptr;
45  return *this;
46  }
47 
48  // no change on copy
49  TempBuffer(const TempBuffer&) noexcept {}
50  // no change on copy
51  TempBuffer& operator=(const TempBuffer&) noexcept { return *this; }
52 
53  void copy_to(std::vector<T>& vec, cudaStream_t stream = nullptr) const
54  {
55  vec.resize(m_size);
56  checkCudaErrors(cudaMemcpyAsync(
57  vec.data(), m_data, m_size * sizeof(T), cudaMemcpyDeviceToHost, stream));
58  }
59 
60  void copy_from(TempBuffer<T>& other, cudaStream_t stream = nullptr)
61  {
62  resize(other.size());
63  checkCudaErrors(cudaMemcpyAsync(
64  m_data, other.data(), other.size() * sizeof(T), cudaMemcpyDeviceToDevice, stream));
65  }
66 
67  void copy_from(const std::vector<T>& vec, cudaStream_t stream = nullptr)
68  {
69  resize(vec.size());
70  checkCudaErrors(cudaMemcpyAsync(
71  m_data, vec.data(), vec.size() * sizeof(T), cudaMemcpyHostToDevice, stream));
72  }
73 
74  TempBuffer(const std::vector<T>& vec) { copy_from(vec); }
75 
76  TempBuffer& operator=(const std::vector<T>& vec)
77  {
78  copy_from(vec);
79  return *this;
80  }
81 
82  void reserve(size_t new_cap, cudaStream_t stream = nullptr)
83  {
84  if(new_cap <= m_capacity)
85  {
86  return;
87  }
88  T* new_data = nullptr;
89  checkCudaErrors(cudaMalloc(&new_data, new_cap * sizeof(T)));
90  if(m_data)
91  {
92  checkCudaErrors(cudaFree(m_data));
93  }
94  m_data = new_data;
95  m_capacity = new_cap;
96  }
97 
98  void resize(size_t size, cudaStream_t stream = nullptr)
99  {
100  if(size <= m_capacity)
101  {
102  m_size = size;
103  return;
104  }
105  reserve(size, stream);
106  m_size = size;
107  }
108 
109  void free() noexcept
110  {
111  m_size = 0;
112  m_capacity = 0;
113  if(m_data)
114  {
115  checkCudaErrors(cudaFree(m_data));
116  m_data = nullptr;
117  }
118  }
119 
120  auto size() const noexcept { return m_size; }
121  auto data() const noexcept { return m_data; }
122  auto capacity() const noexcept { return m_capacity; }
123 
124  private:
125  size_t m_size = 0;
126  size_t m_capacity = 0;
127  T* m_data = nullptr;
128 };
129 
131 } // namespace muda::details
Definition: temp_buffer.h:7