blob: d9dfc3b8c0a2f57a881c31c016296511cd7142ae [file] [log] [blame]
//
// Copyright (c) 2020 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#ifndef SUBGROUPCOMMONTEMPLATES_H
#define SUBGROUPCOMMONTEMPLATES_H
#include "typeWrappers.h"
#include "CL/cl_half.h"
#include "subhelpers.h"
#include <set>
#include <algorithm>
// DESCRIPTION :
// sub_group_broadcast - each work_item registers it's own value.
// All work_items in subgroup takes one value from only one (any) work_item
// sub_group_broadcast_first - same as type 0. All work_items in
// subgroup takes only one value from only one chosen (the smallest subgroup ID)
// work_item
// sub_group_non_uniform_broadcast - same as type 0 but
// only 4 work_items from subgroup enter the code (are active)
template <typename Ty, SubgroupsBroadcastOp operation> struct BC
{
static void log_test(const WorkGroupParams &test_params,
const char *extra_text)
{
log_info(" sub_group_%s(%s)...%s\n", operation_names(operation),
TypeManager<Ty>::name(), extra_text);
}
static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
{
int i, ii, j, k, n;
int ng = test_params.global_workgroup_size;
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int nj = (nw + ns - 1) / ns;
int d = ns > 100 ? 100 : ns;
int non_uniform_size = ng % nw;
ng = ng / nw;
int last_subgroup_size = 0;
ii = 0;
if (non_uniform_size)
{
ng++;
}
for (k = 0; k < ng; ++k)
{ // for each work_group
if (non_uniform_size && k == ng - 1)
{
set_last_workgroup_params(non_uniform_size, nj, ns, nw,
last_subgroup_size);
}
for (j = 0; j < nj; ++j)
{ // for each subgroup
ii = j * ns;
if (last_subgroup_size && j == nj - 1)
{
n = last_subgroup_size;
}
else
{
n = ii + ns > nw ? nw - ii : ns;
}
int bcast_if = 0;
int bcast_elseif = 0;
int bcast_index = (int)(genrand_int32(gMTdata) & 0x7fffffff)
% (d > n ? n : d);
// l - calculate subgroup local id from which value will be
// broadcasted (one the same value for whole subgroup)
if (operation != SubgroupsBroadcastOp::broadcast)
{
// reduce brodcasting index in case of non_uniform and
// last workgroup last subgroup
if (last_subgroup_size && j == nj - 1
&& last_subgroup_size < NR_OF_ACTIVE_WORK_ITEMS)
{
bcast_if = bcast_index % last_subgroup_size;
bcast_elseif = bcast_if;
}
else
{
bcast_if = bcast_index % NR_OF_ACTIVE_WORK_ITEMS;
bcast_elseif = NR_OF_ACTIVE_WORK_ITEMS
+ bcast_index % (n - NR_OF_ACTIVE_WORK_ITEMS);
}
}
for (i = 0; i < n; ++i)
{
if (operation == SubgroupsBroadcastOp::broadcast)
{
int midx = 4 * ii + 4 * i + 2;
m[midx] = (cl_int)bcast_index;
}
else
{
if (i < NR_OF_ACTIVE_WORK_ITEMS)
{
// index of the third
// element int the vector.
int midx = 4 * ii + 4 * i + 2;
// storing information about
// broadcasting index -
// earlier calculated
m[midx] = (cl_int)bcast_if;
}
else
{ // index of the third
// element int the vector.
int midx = 4 * ii + 4 * i + 3;
m[midx] = (cl_int)bcast_elseif;
}
}
// calculate value for broadcasting
cl_ulong number = genrand_int64(gMTdata);
set_value(t[ii + i], number);
}
}
// Now map into work group using map from device
for (j = 0; j < nw; ++j)
{ // for each element in work_group
// calculate index as number of subgroup
// plus subgroup local id
x[j] = t[j];
}
x += nw;
m += 4 * nw;
}
}
static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
const WorkGroupParams &test_params)
{
int ii, i, j, k, l, n;
int ng = test_params.global_workgroup_size;
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int nj = (nw + ns - 1) / ns;
Ty tr, rr;
int non_uniform_size = ng % nw;
ng = ng / nw;
int last_subgroup_size = 0;
if (non_uniform_size) ng++;
for (k = 0; k < ng; ++k)
{ // for each work_group
if (non_uniform_size && k == ng - 1)
{
set_last_workgroup_params(non_uniform_size, nj, ns, nw,
last_subgroup_size);
}
for (j = 0; j < nw; ++j)
{ // inside the work_group
mx[j] = x[j]; // read host inputs for work_group
my[j] = y[j]; // read device outputs for work_group
}
for (j = 0; j < nj; ++j)
{ // for each subgroup
ii = j * ns;
if (last_subgroup_size && j == nj - 1)
{
n = last_subgroup_size;
}
else
{
n = ii + ns > nw ? nw - ii : ns;
}
// Check result
if (operation == SubgroupsBroadcastOp::broadcast_first)
{
int lowest_active_id = -1;
for (i = 0; i < n; ++i)
{
lowest_active_id = i < NR_OF_ACTIVE_WORK_ITEMS
? 0
: NR_OF_ACTIVE_WORK_ITEMS;
// findout if broadcasted
// value is the same
tr = mx[ii + lowest_active_id];
// findout if broadcasted to all
rr = my[ii + i];
if (!compare(rr, tr))
{
log_error(
"ERROR: sub_group_broadcast_first(%s) "
"mismatch "
"for local id %d in sub group %d in group "
"%d\n",
TypeManager<Ty>::name(), i, j, k);
return TEST_FAIL;
}
}
}
else
{
for (i = 0; i < n; ++i)
{
if (operation == SubgroupsBroadcastOp::broadcast)
{
int midx = 4 * ii + 4 * i + 2;
l = (int)m[midx];
tr = mx[ii + l];
}
else
{
if (i < NR_OF_ACTIVE_WORK_ITEMS)
{ // take index of array where info
// which work_item will be
// broadcast its value is stored
int midx = 4 * ii + 4 * i + 2;
// take subgroup local id of
// this work_item
l = (int)m[midx];
// take value generated on host
// for this work_item
tr = mx[ii + l];
}
else
{
int midx = 4 * ii + 4 * i + 3;
l = (int)m[midx];
tr = mx[ii + l];
}
}
rr = my[ii + i]; // read device outputs for
// work_item in the subgroup
if (!compare(rr, tr))
{
log_error("ERROR: sub_group_%s(%s) "
"mismatch for local id %d in sub "
"group %d in group %d - %s\n",
operation_names(operation),
TypeManager<Ty>::name(), i, j, k,
print_expected_obtained(tr, rr).c_str());
return TEST_FAIL;
}
}
}
}
x += nw;
y += nw;
m += 4 * nw;
}
return TEST_PASS;
}
};
static float to_float(subgroups::cl_half x) { return cl_half_to_float(x.data); }
static subgroups::cl_half to_half(float x)
{
subgroups::cl_half value;
value.data = cl_half_from_float(x, g_rounding_mode);
return value;
}
// for integer types
template <typename Ty> inline Ty calculate(Ty a, Ty b, ArithmeticOp operation)
{
switch (operation)
{
case ArithmeticOp::add_: return a + b;
case ArithmeticOp::max_: return a > b ? a : b;
case ArithmeticOp::min_: return a < b ? a : b;
case ArithmeticOp::mul_: return a * b;
case ArithmeticOp::and_: return a & b;
case ArithmeticOp::or_: return a | b;
case ArithmeticOp::xor_: return a ^ b;
case ArithmeticOp::logical_and: return a && b;
case ArithmeticOp::logical_or: return a || b;
case ArithmeticOp::logical_xor: return !a ^ !b;
default: log_error("Unknown operation request\n"); break;
}
return 0;
}
// Specialize for floating points.
template <>
inline cl_double calculate(cl_double a, cl_double b, ArithmeticOp operation)
{
switch (operation)
{
case ArithmeticOp::add_: {
return a + b;
}
case ArithmeticOp::max_: {
return a > b ? a : b;
}
case ArithmeticOp::min_: {
return a < b ? a : b;
}
case ArithmeticOp::mul_: {
return a * b;
}
default: log_error("Unknown operation request\n"); break;
}
return 0;
}
template <>
inline cl_float calculate(cl_float a, cl_float b, ArithmeticOp operation)
{
switch (operation)
{
case ArithmeticOp::add_: {
return a + b;
}
case ArithmeticOp::max_: {
return a > b ? a : b;
}
case ArithmeticOp::min_: {
return a < b ? a : b;
}
case ArithmeticOp::mul_: {
return a * b;
}
default: log_error("Unknown operation request\n"); break;
}
return 0;
}
template <>
inline subgroups::cl_half calculate(subgroups::cl_half a, subgroups::cl_half b,
ArithmeticOp operation)
{
switch (operation)
{
case ArithmeticOp::add_: return to_half(to_float(a) + to_float(b));
case ArithmeticOp::max_:
return to_float(a) > to_float(b) || is_half_nan(b.data) ? a : b;
case ArithmeticOp::min_:
return to_float(a) < to_float(b) || is_half_nan(b.data) ? a : b;
case ArithmeticOp::mul_: return to_half(to_float(a) * to_float(b));
default: log_error("Unknown operation request\n"); break;
}
return to_half(0);
}
template <typename Ty> bool is_floating_point()
{
return std::is_floating_point<Ty>::value
|| std::is_same<Ty, subgroups::cl_half>::value;
}
template <typename Ty, ArithmeticOp operation>
void generate_inputs(Ty *x, Ty *t, cl_int *m, int ns, int nw, int ng)
{
int nj = (nw + ns - 1) / ns;
std::vector<cl_ulong> safe_values;
if (operation == ArithmeticOp::mul_ || operation == ArithmeticOp::add_)
{
fill_and_shuffle_safe_values(safe_values, ns);
}
for (int k = 0; k < ng; ++k)
{
for (int j = 0; j < nj; ++j)
{
int ii = j * ns;
int n = ii + ns > nw ? nw - ii : ns;
for (int i = 0; i < n; ++i)
{
cl_ulong out_value;
if (operation == ArithmeticOp::mul_
|| operation == ArithmeticOp::add_)
{
out_value = safe_values[i];
}
else
{
out_value = genrand_int64(gMTdata) % (32 * n);
if ((operation == ArithmeticOp::logical_and
|| operation == ArithmeticOp::logical_or
|| operation == ArithmeticOp::logical_xor)
&& ((out_value >> 32) & 1) == 0)
out_value = 0; // increase probability of false
}
set_value(t[ii + i], out_value);
}
}
// Now map into work group using map from device
for (int j = 0; j < nw; ++j)
{
x[j] = t[j];
}
x += nw;
m += 4 * nw;
}
}
template <typename Ty, ShuffleOp operation> struct SHF
{
static void log_test(const WorkGroupParams &test_params,
const char *extra_text)
{
log_info(" sub_group_%s(%s)...%s\n", operation_names(operation),
TypeManager<Ty>::name(), extra_text);
}
static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
{
int i, ii, j, k, n;
cl_uint l;
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
int nj = (nw + ns - 1) / ns;
ii = 0;
ng = ng / nw;
for (k = 0; k < ng; ++k)
{ // for each work_group
for (j = 0; j < nj; ++j)
{ // for each subgroup
ii = j * ns;
n = ii + ns > nw ? nw - ii : ns;
for (i = 0; i < n; ++i)
{
int midx = 4 * ii + 4 * i + 2;
l = (((cl_uint)(genrand_int32(gMTdata) & 0x7fffffff) + 1)
% (ns * 2 + 1))
- 1;
switch (operation)
{
case ShuffleOp::shuffle:
case ShuffleOp::shuffle_xor:
case ShuffleOp::shuffle_up:
case ShuffleOp::shuffle_down:
// storing information about shuffle index/delta
m[midx] = (cl_int)l;
break;
case ShuffleOp::rotate:
case ShuffleOp::clustered_rotate:
// Storing information about rotate delta.
// The delta must be the same for each thread in
// the subgroup.
if (i == 0)
{
m[midx] = (cl_int)l;
}
else
{
m[midx] = m[midx - 4];
}
break;
default: break;
}
cl_ulong number = genrand_int64(gMTdata);
set_value(t[ii + i], number);
}
}
// Now map into work group using map from device
for (j = 0; j < nw; ++j)
{ // for each element in work_group
x[j] = t[j];
}
x += nw;
m += 4 * nw;
}
}
static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
const WorkGroupParams &test_params)
{
int ii, k;
size_t n;
cl_uint l;
size_t nw = test_params.local_workgroup_size;
size_t ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
size_t nj = (nw + ns - 1) / ns;
Ty tr, rr;
ng = ng / nw;
for (k = 0; k < ng; ++k)
{ // for each work_group
for (size_t j = 0; j < nw; ++j)
{ // inside the work_group
mx[j] = x[j]; // read host inputs for work_group
my[j] = y[j]; // read device outputs for work_group
}
for (size_t j = 0; j < nj; ++j)
{ // for each subgroup
ii = j * ns;
n = ii + ns > nw ? nw - ii : ns;
for (size_t i = 0; i < n; ++i)
{ // inside the subgroup
// shuffle index storage
int midx = 4 * ii + 4 * i + 2;
l = m[midx];
rr = my[ii + i];
cl_uint tr_idx;
bool skip = false;
switch (operation)
{
// shuffle basic - treat l as index
case ShuffleOp::shuffle: tr_idx = l; break;
// shuffle xor - treat l as mask
case ShuffleOp::shuffle_xor: tr_idx = i ^ l; break;
// shuffle up - treat l as delta
case ShuffleOp::shuffle_up:
if (l >= ns) skip = true;
tr_idx = i - l;
break;
// shuffle down - treat l as delta
case ShuffleOp::shuffle_down:
if (l >= ns) skip = true;
tr_idx = i + l;
break;
// rotate - treat l as delta
case ShuffleOp::rotate:
tr_idx = (i + l) % test_params.subgroup_size;
break;
case ShuffleOp::clustered_rotate: {
tr_idx = ((i & ~(test_params.cluster_size - 1))
+ ((i + l) % test_params.cluster_size));
break;
}
default: break;
}
if (!skip && tr_idx < n)
{
tr = mx[ii + tr_idx];
if (!compare(rr, tr))
{
log_error("ERROR: sub_group_%s(%s) mismatch for "
"local id %d in sub group %d in group "
"%d\n",
operation_names(operation),
TypeManager<Ty>::name(), i, j, k);
return TEST_FAIL;
}
}
}
}
x += nw;
y += nw;
m += 4 * nw;
}
return TEST_PASS;
}
};
template <typename Ty, ArithmeticOp operation> struct SCEX_NU
{
static void log_test(const WorkGroupParams &test_params,
const char *extra_text)
{
std::string func_name = (test_params.all_work_item_masks.size() > 0
? "sub_group_non_uniform_scan_exclusive"
: "sub_group_scan_exclusive");
log_info(" %s_%s(%s)...%s\n", func_name.c_str(),
operation_names(operation), TypeManager<Ty>::name(),
extra_text);
}
static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
{
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
ng = ng / nw;
generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
}
static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
const WorkGroupParams &test_params)
{
int ii, i, j, k, n;
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
bs128 work_items_mask = test_params.work_items_mask;
int nj = (nw + ns - 1) / ns;
Ty tr, rr;
ng = ng / nw;
std::string func_name = (test_params.all_work_item_masks.size() > 0
? "sub_group_non_uniform_scan_exclusive"
: "sub_group_scan_exclusive");
// for uniform case take into consideration all workitems
if (!work_items_mask.any())
{
work_items_mask.set();
}
for (k = 0; k < ng; ++k)
{ // for each work_group
// Map to array indexed to array indexed by local ID and sub group
for (j = 0; j < nw; ++j)
{ // inside the work_group
mx[j] = x[j]; // read host inputs for work_group
my[j] = y[j]; // read device outputs for work_group
}
for (j = 0; j < nj; ++j)
{
ii = j * ns;
n = ii + ns > nw ? nw - ii : ns;
std::set<int> active_work_items;
for (i = 0; i < n; ++i)
{
if (work_items_mask.test(i))
{
active_work_items.insert(i);
}
}
if (active_work_items.empty())
{
continue;
}
else
{
tr = TypeManager<Ty>::identify_limits(operation);
for (const int &active_work_item : active_work_items)
{
rr = my[ii + active_work_item];
if (!compare_ordered(rr, tr))
{
log_error(
"ERROR: %s_%s(%s) "
"mismatch for local id %d in sub group %d in "
"group %d %s\n",
func_name.c_str(), operation_names(operation),
TypeManager<Ty>::name(), i, j, k,
print_expected_obtained(tr, rr).c_str());
return TEST_FAIL;
}
tr = calculate<Ty>(tr, mx[ii + active_work_item],
operation);
}
}
}
x += nw;
y += nw;
m += 4 * nw;
}
return TEST_PASS;
}
};
// Test for scan inclusive non uniform functions
template <typename Ty, ArithmeticOp operation> struct SCIN_NU
{
static void log_test(const WorkGroupParams &test_params,
const char *extra_text)
{
std::string func_name = (test_params.all_work_item_masks.size() > 0
? "sub_group_non_uniform_scan_inclusive"
: "sub_group_scan_inclusive");
log_info(" %s_%s(%s)...%s\n", func_name.c_str(),
operation_names(operation), TypeManager<Ty>::name(),
extra_text);
}
static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
{
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
ng = ng / nw;
generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
}
static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
const WorkGroupParams &test_params)
{
int ii, i, j, k, n;
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
bs128 work_items_mask = test_params.work_items_mask;
int nj = (nw + ns - 1) / ns;
Ty tr, rr;
ng = ng / nw;
std::string func_name = (test_params.all_work_item_masks.size() > 0
? "sub_group_non_uniform_scan_inclusive"
: "sub_group_scan_inclusive");
// for uniform case take into consideration all workitems
if (!work_items_mask.any())
{
work_items_mask.set();
}
// std::bitset<32> mask32(use_work_items_mask);
// for (int k) mask32.count();
for (k = 0; k < ng; ++k)
{ // for each work_group
// Map to array indexed to array indexed by local ID and sub group
for (j = 0; j < nw; ++j)
{ // inside the work_group
mx[j] = x[j]; // read host inputs for work_group
my[j] = y[j]; // read device outputs for work_group
}
for (j = 0; j < nj; ++j)
{
ii = j * ns;
n = ii + ns > nw ? nw - ii : ns;
std::set<int> active_work_items;
int catch_frist_active = -1;
for (i = 0; i < n; ++i)
{
if (work_items_mask.test(i))
{
if (catch_frist_active == -1)
{
catch_frist_active = i;
}
active_work_items.insert(i);
}
}
if (active_work_items.empty())
{
continue;
}
else
{
tr = TypeManager<Ty>::identify_limits(operation);
for (const int &active_work_item : active_work_items)
{
rr = my[ii + active_work_item];
if (active_work_items.size() == 1)
{
tr = mx[ii + catch_frist_active];
}
else
{
tr = calculate<Ty>(tr, mx[ii + active_work_item],
operation);
}
if (!compare_ordered<Ty>(rr, tr))
{
log_error(
"ERROR: %s_%s(%s) "
"mismatch for local id %d in sub group %d "
"in "
"group %d %s\n",
func_name.c_str(), operation_names(operation),
TypeManager<Ty>::name(), active_work_item, j, k,
print_expected_obtained(tr, rr).c_str());
return TEST_FAIL;
}
}
}
}
x += nw;
y += nw;
m += 4 * nw;
}
return TEST_PASS;
}
};
// Test for reduce non uniform functions
template <typename Ty, ArithmeticOp operation> struct RED_NU
{
static void log_test(const WorkGroupParams &test_params,
const char *extra_text)
{
std::string func_name = (test_params.all_work_item_masks.size() > 0
? "sub_group_non_uniform_reduce"
: "sub_group_reduce");
log_info(" %s_%s(%s)...%s\n", func_name.c_str(),
operation_names(operation), TypeManager<Ty>::name(),
extra_text);
}
static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
{
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
ng = ng / nw;
generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
}
static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
const WorkGroupParams &test_params)
{
int ii, i, j, k, n;
int nw = test_params.local_workgroup_size;
int ns = test_params.subgroup_size;
int ng = test_params.global_workgroup_size;
bs128 work_items_mask = test_params.work_items_mask;
int nj = (nw + ns - 1) / ns;
ng = ng / nw;
Ty tr, rr;
std::string func_name = (test_params.all_work_item_masks.size() > 0
? "sub_group_non_uniform_reduce"
: "sub_group_reduce");
for (k = 0; k < ng; ++k)
{
// Map to array indexed to array indexed by local ID and sub
// group
for (j = 0; j < nw; ++j)
{
mx[j] = x[j];
my[j] = y[j];
}
if (!work_items_mask.any())
{
work_items_mask.set();
}
for (j = 0; j < nj; ++j)
{
ii = j * ns;
n = ii + ns > nw ? nw - ii : ns;
std::set<int> active_work_items;
int catch_frist_active = -1;
for (i = 0; i < n; ++i)
{
if (work_items_mask.test(i))
{
if (catch_frist_active == -1)
{
catch_frist_active = i;
tr = mx[ii + i];
active_work_items.insert(i);
continue;
}
active_work_items.insert(i);
tr = calculate<Ty>(tr, mx[ii + i], operation);
}
}
if (active_work_items.empty())
{
continue;
}
for (const int &active_work_item : active_work_items)
{
rr = my[ii + active_work_item];
if (!compare_ordered<Ty>(rr, tr))
{
log_error("ERROR: %s_%s(%s) "
"mismatch for local id %d in sub group %d in "
"group %d %s\n",
func_name.c_str(), operation_names(operation),
TypeManager<Ty>::name(), active_work_item, j,
k, print_expected_obtained(tr, rr).c_str());
return TEST_FAIL;
}
}
}
x += nw;
y += nw;
m += 4 * nw;
}
return TEST_PASS;
}
};
#endif