hexagon: add support for TRI op (#22822)
* Hexagon: TRI HVX Kernel addition to ggml hexagon HTP ops and context
* addressed PR review comments for TRI op
* hexagon: clang format
* hex-unary: remove merge conflict markers
* hex-ggml: remove duplicate op cases (merge conflict)
* hex-ggml: fix editor config errors

---------

Co-authored-by: Todor Boinovski <todorb@qti.qualcomm.com>
Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
This commit is contained in:
@@ -2828,6 +2828,21 @@ static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session *
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) { return false; }
|
||||
if (dst->type != GGML_TYPE_F32) { return false; }
|
||||
if (!ggml_are_same_shape(src0, dst)) { return false; }
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { return false; }
|
||||
|
||||
return true;
|
||||
|
||||
GGML_UNUSED(sess);
|
||||
}
|
||||
|
||||
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
|
||||
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
|
||||
return sess->c_name();
|
||||
@@ -2869,6 +2884,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
|
||||
case GGML_OP_FILL: return HTP_OP_FILL;
|
||||
case GGML_OP_DIAG: return HTP_OP_DIAG;
|
||||
case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI;
|
||||
case GGML_OP_TRI: return HTP_OP_TRI;
|
||||
case GGML_OP_PAD: return HTP_OP_PAD;
|
||||
|
||||
case GGML_OP_UNARY:
|
||||
@@ -3430,6 +3446,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_solve_tri(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_TRI:
|
||||
supp = ggml_hexagon_supported_tri(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_PAD:
|
||||
supp = ggml_hexagon_supported_pad(sess, op);
|
||||
break;
|
||||
|
||||
@@ -107,6 +107,7 @@ int op_fill(struct htp_ops_context * octx);
|
||||
int op_diag(struct htp_ops_context * octx);
|
||||
int op_solve_tri(struct htp_ops_context * octx);
|
||||
int op_gated_delta_net(struct htp_ops_context * octx);
|
||||
int op_tri(struct htp_ops_context * octx);
|
||||
int op_pad(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_CTX_H */
|
||||
|
||||
@@ -86,6 +86,7 @@ enum htp_op_code {
|
||||
HTP_OP_SOLVE_TRI,
|
||||
HTP_OP_L2_NORM,
|
||||
HTP_OP_GATED_DELTA_NET,
|
||||
HTP_OP_TRI,
|
||||
HTP_OP_PAD,
|
||||
|
||||
HTP_OP_INVALID
|
||||
|
||||
@@ -601,6 +601,9 @@ static int execute_op(struct htp_ops_context * octx) {
|
||||
case HTP_OP_GATED_DELTA_NET:
|
||||
return op_gated_delta_net(octx);
|
||||
|
||||
case HTP_OP_TRI:
|
||||
return op_tri(octx);
|
||||
|
||||
case HTP_OP_INVALID:
|
||||
break;
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
struct htp_unary_context {
|
||||
struct htp_ops_context * octx;
|
||||
@@ -277,6 +276,95 @@ static void sigmoid_f32(const float * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void tri_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
const uint32_t ir,
|
||||
const struct htp_unary_context * uctx) {
|
||||
|
||||
const int32_t ttype = op_params[0];
|
||||
const HVX_Vector zero = hvx_vec_splat_f32(0.0f);
|
||||
const uint32_t nvec = row_elems / VLEN_FP32;
|
||||
const uint32_t nloe = row_elems % VLEN_FP32;
|
||||
|
||||
const uint32_t ne01 = uctx->octx->src[0]->ne[1];
|
||||
|
||||
for (uint32_t b = 0; b < num_rows; b++) {
|
||||
const uint32_t abs_row = ir + b;
|
||||
const uint32_t i01 = abs_row % ne01;
|
||||
|
||||
const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size);
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size);
|
||||
|
||||
uint32_t boundary;
|
||||
int keep_left;
|
||||
switch (ttype) {
|
||||
case 0: boundary = i01; keep_left = 0; break; // keep col >= row
|
||||
case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row
|
||||
case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row
|
||||
case 3: boundary = i01; keep_left = 1; break; // keep col < row
|
||||
default: boundary = 0; keep_left = 0; break;
|
||||
}
|
||||
if (boundary > row_elems) boundary = row_elems;
|
||||
|
||||
// Full HVX vectors — each starts at a 128-byte aligned offset
|
||||
for (uint32_t i = 0; i < nvec; i++) {
|
||||
const uint32_t vec_start = i * VLEN_FP32;
|
||||
const uint32_t vec_end = vec_start + VLEN_FP32;
|
||||
if (keep_left) {
|
||||
if (vec_end <= boundary) {
|
||||
v_dst[i] = v_src[i];
|
||||
} else if (vec_start >= boundary) {
|
||||
v_dst[i] = zero;
|
||||
} else {
|
||||
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
||||
v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero);
|
||||
}
|
||||
} else {
|
||||
if (vec_end <= boundary) {
|
||||
v_dst[i] = zero;
|
||||
} else if (vec_start >= boundary) {
|
||||
v_dst[i] = v_src[i];
|
||||
} else {
|
||||
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
||||
v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tail elements (row_elems not a multiple of VLEN_FP32)
|
||||
if (nloe > 0) {
|
||||
const uint32_t vec_start = nvec * VLEN_FP32;
|
||||
const uint32_t vec_end = vec_start + nloe;
|
||||
HVX_Vector tail_val;
|
||||
if (keep_left) {
|
||||
if (vec_end <= boundary) {
|
||||
tail_val = v_src[nvec];
|
||||
} else if (vec_start >= boundary) {
|
||||
tail_val = zero;
|
||||
} else {
|
||||
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
||||
tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero);
|
||||
}
|
||||
} else {
|
||||
if (vec_end <= boundary) {
|
||||
tail_val = zero;
|
||||
} else if (vec_start >= boundary) {
|
||||
tail_val = v_src[nvec];
|
||||
} else {
|
||||
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
|
||||
tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]);
|
||||
}
|
||||
}
|
||||
hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void softplus_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
@@ -498,6 +586,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
case HTP_OP_L2_NORM:
|
||||
l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_TRI:
|
||||
tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -571,6 +662,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
case HTP_OP_L2_NORM:
|
||||
op_type = "l2norm-f32";
|
||||
break;
|
||||
case HTP_OP_TRI:
|
||||
op_type = "tri-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
@@ -640,6 +735,22 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
return err;
|
||||
}
|
||||
|
||||
int op_tri(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src[0]->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_unary_f32(octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
err = HTP_STATUS_NO_SUPPORT;
|
||||
break;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
int op_unary(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user