Skip to content

Commit f54725f

Browse files
committed
device: dense multiplication example for UMTensor
1 parent 4d7b6ee commit f54725f

File tree

2 files changed

+377
-1
lines changed

2 files changed

+377
-1
lines changed

examples/device/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP)
2727

28-
foreach(_exec device_task ta_dense_device ta_cc_abcd_device ta_vector_device ta_reduce_device)
28+
foreach(_exec device_task ta_dense_device ta_dense_um_tensor ta_cc_abcd_device ta_vector_device ta_reduce_device)
2929

3030
# Add executable
3131
add_ta_executable(${_exec} "${_exec}.cpp" "tiledarray")
Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
/*
2+
* This file is a part of TiledArray.
3+
* Copyright (C) 2025 Virginia Tech
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU General Public License as published by
7+
* the Free Software Foundation, either version 3 of the License, or
8+
* (at your option) any later version.
9+
*
10+
* This program is distributed in the hope that it will be useful,
11+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
* GNU General Public License for more details.
14+
*
15+
* You should have received a copy of the GNU General Public License
16+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
17+
*
18+
*/
19+
20+
// clang-format off
21+
#include <tiledarray.h>
22+
#include <TiledArray/device/um_tensor.h>
23+
// clang-format on
24+
25+
#ifdef TILEDARRAY_HAS_CUDA
26+
#include <cuda_profiler_api.h>
27+
#endif // TILEDARRAY_HAS_CUDA
28+
29+
template <typename T>
30+
void do_main_body(TiledArray::World& world, const long Nm, const long Bm,
31+
const long Nn, const long Bn, const long Nk, const long Bk,
32+
const long nrepeat) {
33+
using RT = TiledArray::detail::scalar_t<T>;
34+
constexpr auto complex_T = TiledArray::detail::is_complex_v<T>;
35+
36+
const std::int64_t nflops =
37+
(complex_T ? 8 : 2) // 1 multiply takes 6/1 flops for complex/real
38+
// 1 add takes 2/1 flops for complex/real
39+
* static_cast<std::int64_t>(Nn) * static_cast<std::int64_t>(Nm) *
40+
static_cast<std::int64_t>(Nk);
41+
42+
// Construct TiledRange
43+
std::vector<unsigned int> blocking_m;
44+
for (long i = 0l; i <= Nm; i += Bm) blocking_m.push_back(i);
45+
const std::size_t Tm = blocking_m.size() - 1;
46+
47+
std::vector<unsigned int> blocking_n;
48+
for (long i = 0l; i <= Nn; i += Bn) blocking_n.push_back(i);
49+
const std::size_t Tn = blocking_n.size() - 1;
50+
51+
std::vector<unsigned int> blocking_k;
52+
for (long i = 0l; i <= Nk; i += Bk) blocking_k.push_back(i);
53+
const std::size_t Tk = blocking_k.size();
54+
55+
if (world.rank() == 0)
56+
std::cout << "TiledArray: UMTensor dense matrix multiply test...\n"
57+
<< "Number of nodes = " << world.size()
58+
<< "\nSize of A = " << Nm << "x" << Nk << " ("
59+
<< double(Nm * Nk * sizeof(T)) / 1.0e9 << " GB)"
60+
<< "\nSize of (largest) A block = " << Bm << "x" << Bk
61+
<< "\nSize of B = " << Nk << "x" << Nn << " ("
62+
<< double(Nk * Nn * sizeof(T)) / 1.0e9 << " GB)"
63+
<< "\nSize of (largest) B block = " << Bk << "x" << Bn
64+
<< "\nSize of C = " << Nm << "x" << Nn << " ("
65+
<< double(Nm * Nn * sizeof(T)) / 1.0e9 << " GB)"
66+
<< "\nSize of (largest) C block = " << Bm << "x" << Bn
67+
<< "\n# of blocks of C = " << Tm * Tn
68+
<< "\nAverage # of blocks of C/node = "
69+
<< double(Tm * Tn) / double(world.size()) << "\n";
70+
71+
// Structure of c
72+
std::vector<TiledArray::TiledRange1> blocking_C;
73+
blocking_C.reserve(2);
74+
blocking_C.push_back(
75+
TiledArray::TiledRange1(blocking_m.begin(), blocking_m.end()));
76+
blocking_C.push_back(
77+
TiledArray::TiledRange1(blocking_n.begin(), blocking_n.end()));
78+
79+
// Structure of a
80+
std::vector<TiledArray::TiledRange1> blocking_A;
81+
blocking_A.reserve(2);
82+
blocking_A.push_back(
83+
TiledArray::TiledRange1(blocking_m.begin(), blocking_m.end()));
84+
blocking_A.push_back(
85+
TiledArray::TiledRange1(blocking_k.begin(), blocking_k.end()));
86+
87+
// Structure of b
88+
std::vector<TiledArray::TiledRange1> blocking_B;
89+
blocking_B.reserve(2);
90+
blocking_B.push_back(
91+
TiledArray::TiledRange1(blocking_k.begin(), blocking_k.end()));
92+
blocking_B.push_back(
93+
TiledArray::TiledRange1(blocking_n.begin(), blocking_n.end()));
94+
95+
TiledArray::TiledRange trange_c(blocking_C.begin(), blocking_C.end());
96+
97+
TiledArray::TiledRange trange_a(blocking_A.begin(), blocking_A.end());
98+
99+
TiledArray::TiledRange trange_b(blocking_B.begin(), blocking_B.end());
100+
101+
using DeviceTile = TA::UMTensor<T>;
102+
using DeviceMatrix = TA::DistArray<TA::Tile<DeviceTile>>;
103+
using HostTensor = TA::Tensor<T>;
104+
using HostMatrix = TA::DistArray<HostTensor>;
105+
106+
DeviceMatrix c(world, trange_c);
107+
auto val_a = 0.03;
108+
auto val_b = 0.02;
109+
110+
{
111+
// Construct and initialize arrays on host first
112+
HostMatrix a_host(world, trange_a);
113+
HostMatrix b_host(world, trange_b);
114+
115+
a_host.fill(val_a);
116+
b_host.fill(val_b);
117+
118+
// Convert to UMTensor arrays
119+
DeviceMatrix a(world, trange_a);
120+
DeviceMatrix b(world, trange_b);
121+
122+
// Copy data from host to device tensors
123+
// TODO: Wrap this into a reusable function
124+
for (auto it = a_host.begin(); it != a_host.end(); ++it) {
125+
const auto& index = it.index();
126+
const auto& host_tile_ref = *it;
127+
const auto& host_tile =
128+
host_tile_ref.get(); // Get actual tensor from reference
129+
130+
DeviceTile device_tile(host_tile.range());
131+
132+
std::copy(host_tile.data(), host_tile.data() + host_tile.size(),
133+
device_tile.data());
134+
TiledArray::detail::to_device(device_tile);
135+
136+
a.set(index, TA::Tile<DeviceTile>(std::move(device_tile)));
137+
}
138+
139+
for (auto it = b_host.begin(); it != b_host.end(); ++it) {
140+
const auto& index = it.index();
141+
const auto& host_tile_ref = *it;
142+
const auto& host_tile =
143+
host_tile_ref.get(); // Get actual tensor from reference
144+
DeviceTile device_tile(host_tile.range());
145+
146+
std::copy(host_tile.data(), host_tile.data() + host_tile.size(),
147+
device_tile.data());
148+
149+
TiledArray::detail::to_device(device_tile);
150+
151+
b.set(index, TA::Tile<DeviceTile>(std::move(device_tile)));
152+
}
153+
154+
world.gop.fence();
155+
156+
#ifdef TILEDARRAY_HAS_CUDA
157+
// start profiler
158+
cudaProfilerStart();
159+
#endif // TILEDARRAY_HAS_CUDA
160+
161+
double total_time = 0.0;
162+
double total_gflop_rate = 0.0;
163+
164+
// Do matrix multiplication
165+
for (int i = 0; i < nrepeat; ++i) {
166+
double iter_time_start = madness::wall_time();
167+
c("m,n") = a("m,k") * b("k,n");
168+
c.world().gop.fence(); // fence since GEMM can return early
169+
double iter_time_stop = madness::wall_time();
170+
const double iter_time = iter_time_stop - iter_time_start;
171+
total_time += iter_time;
172+
const double gflop_rate = double(nflops) / (iter_time * 1.e9);
173+
total_gflop_rate += gflop_rate;
174+
if (world.rank() == 0)
175+
std::cout << "Iteration " << i + 1 << " wall time: " << iter_time
176+
<< " sec\n";
177+
if (world.rank() == 0)
178+
std::cout << "Iteration " << i + 1 << " GFLOPS=" << gflop_rate
179+
<< "\n";
180+
}
181+
182+
#ifdef TILEDARRAY_HAS_CUDA
183+
// stop profiler
184+
cudaProfilerStop();
185+
#endif // TILEDARRAY_HAS_CUDA
186+
187+
if (world.rank() == 0)
188+
std::cout << "Average wall time = " << total_time / double(nrepeat)
189+
<< " sec\nAverage GFLOPS = "
190+
<< total_gflop_rate / double(nrepeat) << "\n";
191+
}
192+
193+
double threshold = std::numeric_limits<RT>::epsilon();
194+
auto dot_length = Nk;
195+
T result;
196+
if constexpr (complex_T) {
197+
result = T(dot_length * val_a * val_b, 0.);
198+
} else
199+
result = dot_length * val_a * val_b;
200+
201+
auto verify = [&world, &threshold, &result,
202+
&dot_length](TA::Tile<DeviceTile>& tile) {
203+
auto& um_tensor = tile.tensor();
204+
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(
205+
um_tensor, TiledArray::device::stream_for(um_tensor.range()));
206+
TiledArray::device::sync_madness_task_with(
207+
TiledArray::device::stream_for(um_tensor.range()));
208+
209+
auto n_elements = tile.size();
210+
for (std::size_t i = 0; i < n_elements; i++) {
211+
double abs_err = std::abs(tile[i] - result);
212+
double rel_err = abs_err / std::abs(result) / dot_length;
213+
if (rel_err > threshold) {
214+
auto to_string = [](const auto& v) {
215+
constexpr bool complex_T =
216+
TiledArray::detail::is_complex_v<std::decay_t<decltype(v)>>;
217+
if constexpr (complex_T) {
218+
std::string result;
219+
result = "{" + std::to_string(v.real()) + "," +
220+
std::to_string(v.imag()) + "}";
221+
return result;
222+
} else
223+
return std::to_string(v);
224+
};
225+
std::cout << "Node: " << world.rank() << " Tile: " << tile.range()
226+
<< " id: " << i
227+
<< std::string(" gpu: " + to_string(tile[i]) +
228+
" cpu: " + to_string(result) + "\n");
229+
break;
230+
}
231+
}
232+
};
233+
234+
for (auto iter = c.begin(); iter != c.end(); iter++) {
235+
world.taskq.add(verify, c.find(iter.index()));
236+
}
237+
238+
world.gop.fence();
239+
240+
if (world.rank() == 0) {
241+
std::cout << "Verification Passed" << std::endl;
242+
}
243+
}
244+
245+
int try_main(int argc, char** argv) {
246+
// Initialize runtime
247+
TiledArray::World& world = TA_SCOPED_INITIALIZE(argc, argv);
248+
249+
// Get command line arguments
250+
if (argc < 6) {
251+
std::cout
252+
<< "multiplies A(Nm,Nk) * B(Nk,Nn), with dimensions m, n, and k "
253+
"blocked by Bm, Bn, and Bk, respectively"
254+
<< std::endl
255+
<< "Usage: " << argv[0]
256+
<< " Nm Bm Nn Bn Nk Bk [# of repetitions = 5] [scalar = double]\n";
257+
return 0;
258+
}
259+
const long Nm = atol(argv[1]);
260+
const long Bm = atol(argv[2]);
261+
const long Nn = atol(argv[3]);
262+
const long Bn = atol(argv[4]);
263+
const long Nk = atol(argv[5]);
264+
const long Bk = atol(argv[6]);
265+
if (Nm <= 0 || Nn <= 0 || Nk <= 0) {
266+
std::cerr << "Error: dimensions must be greater than zero.\n";
267+
return 1;
268+
}
269+
if (Bm <= 0 || Bn <= 0 || Bk <= 0) {
270+
std::cerr << "Error: block sizes must be greater than zero.\n";
271+
return 1;
272+
}
273+
const long nrepeat = (argc >= 8 ? atol(argv[7]) : 5);
274+
if (nrepeat <= 0) {
275+
std::cerr << "Error: number of repetitions must be greater than zero.\n";
276+
return 1;
277+
}
278+
279+
const std::string scalar_type_str = (argc >= 9 ? argv[8] : "double");
280+
if (scalar_type_str != "double" && scalar_type_str != "float" &&
281+
scalar_type_str != "zdouble" && scalar_type_str != "zfloat") {
282+
std::cerr << "Error: invalid real type " << scalar_type_str << ".\n";
283+
std::cerr << " valid real types are \"double\", \"float\", "
284+
"\"zdouble\", and \"zfloat\".\n";
285+
return 1;
286+
}
287+
288+
std::cout << "Using TA::UMTensor<" << scalar_type_str << ">" << std::endl;
289+
290+
int driverVersion, runtimeVersion;
291+
auto error = TiledArray::device::driverVersion(&driverVersion);
292+
if (error != TiledArray::device::Success) {
293+
std::cout << "error(DriverGetVersion) = " << error << std::endl;
294+
}
295+
error = TiledArray::device::runtimeVersion(&runtimeVersion);
296+
if (error != TiledArray::device::Success) {
297+
std::cout << "error(RuntimeGetVersion) = " << error << std::endl;
298+
}
299+
std::cout << "device {driver,runtime} versions = " << driverVersion << ","
300+
<< runtimeVersion << std::endl;
301+
302+
{ // print device properties
303+
int num_devices = TA::deviceEnv::instance()->num_visible_devices();
304+
305+
if (num_devices <= 0) {
306+
throw std::runtime_error("No GPUs Found!\n");
307+
}
308+
309+
const int device_id = TA::deviceEnv::instance()->current_device_id();
310+
311+
int mpi_size = world.size();
312+
int mpi_rank = world.rank();
313+
314+
for (int i = 0; i < mpi_size; i++) {
315+
if (i == mpi_rank) {
316+
std::cout << "Device Information for MPI Process Rank: " << mpi_rank
317+
<< std::endl;
318+
TiledArray::device::deviceProp_t prop;
319+
auto error = TiledArray::device::getDeviceProperties(&prop, device_id);
320+
if (error != TiledArray::device::Success) {
321+
std::cout << "error(GetDeviceProperties) = " << error << std::endl;
322+
}
323+
std::cout << "Device #" << device_id << ": " << prop.name << std::endl
324+
<< " managedMemory = " << prop.managedMemory << std::endl;
325+
int result;
326+
error = TiledArray::device::deviceGetAttribute(
327+
&result, TiledArray::device::DevAttrUnifiedAddressing, device_id);
328+
std::cout << " attrUnifiedAddressing = " << result << std::endl;
329+
error = TiledArray::device::deviceGetAttribute(
330+
&result, TiledArray::device::DevAttrConcurrentManagedAccess,
331+
device_id);
332+
std::cout << " attrConcurrentManagedAccess = " << result << std::endl;
333+
error = TiledArray::device::setDevice(device_id);
334+
if (error != TiledArray::device::Success) {
335+
std::cout << "error(device::setDevice) = " << error << std::endl;
336+
}
337+
size_t free_mem, total_mem;
338+
error = TiledArray::device::memGetInfo(&free_mem, &total_mem);
339+
std::cout << " {total,free} memory = {" << total_mem << "," << free_mem
340+
<< "}" << std::endl;
341+
}
342+
world.gop.fence();
343+
}
344+
} // print device properties
345+
346+
if (scalar_type_str == "double")
347+
do_main_body<double>(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat);
348+
else if (scalar_type_str == "float")
349+
do_main_body<float>(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat);
350+
else if (scalar_type_str == "zdouble")
351+
do_main_body<std::complex<double>>(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat);
352+
else if (scalar_type_str == "zfloat")
353+
do_main_body<std::complex<float>>(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat);
354+
else {
355+
abort(); // unreachable
356+
}
357+
358+
return 0;
359+
}
360+
361+
int main(int argc, char* argv[]) {
362+
try {
363+
try_main(argc, argv);
364+
} catch (std::exception& ex) {
365+
std::cout << ex.what() << std::endl;
366+
367+
size_t free_mem, total_mem;
368+
auto result = TiledArray::device::memGetInfo(&free_mem, &total_mem);
369+
std::cout << "device memory stats: {total,free} = {" << total_mem << ","
370+
<< free_mem << "}" << std::endl;
371+
} catch (...) {
372+
std::cerr << "unknown exception" << std::endl;
373+
}
374+
375+
return 0;
376+
}

0 commit comments

Comments
 (0)