hexagon: add support for TRI op (#22822)

* Hexagon: TRI HVX Kernel addition to ggml hexagon HTP ops and context

* addressed PR review comments for TRI op

* hexagon: clang format

* hex-unary: remove merge conflict markers

* hex-ggml: remove duplicate op cases (merge conflict)

* hex-ggml: fix editor config errors

---------

Co-authored-by: Todor Boinovski <todorb@qti.qualcomm.com>
Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
This commit is contained in:
Pranav Dhinakar
2026-05-18 14:04:57 -07:00
committed by GitHub
parent b7340443d4
commit 9a532ae4ba
5 changed files with 137 additions and 1 deletions
+20
View File
@@ -2828,6 +2828,21 @@ static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session *
return true;
}
// Reports whether a GGML_OP_TRI node can be offloaded: both the input and the
// result must be F32, have identical shapes, and be contiguous in memory.
// The session argument is not consulted.
static bool ggml_hexagon_supported_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
    const struct ggml_tensor * input  = op->src[0];
    const struct ggml_tensor * output = op;

    // Checks are chained in the order: input type, output type, shape, input
    // contiguity, output contiguity; the first failing one rejects the op.
    const bool supported = input->type == GGML_TYPE_F32 &&
                           output->type == GGML_TYPE_F32 &&
                           ggml_are_same_shape(input, output) &&
                           ggml_is_contiguous(input) &&
                           ggml_is_contiguous(output);

    return supported;

    GGML_UNUSED(sess);
}
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
return sess->c_name();
@@ -2869,6 +2884,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
case GGML_OP_FILL: return HTP_OP_FILL;
case GGML_OP_DIAG: return HTP_OP_DIAG;
case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI;
case GGML_OP_TRI: return HTP_OP_TRI;
case GGML_OP_PAD: return HTP_OP_PAD;
case GGML_OP_UNARY:
@@ -3430,6 +3446,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_solve_tri(sess, op);
break;
case GGML_OP_TRI:
supp = ggml_hexagon_supported_tri(sess, op);
break;
case GGML_OP_PAD:
supp = ggml_hexagon_supported_pad(sess, op);
break;
+1
View File
@@ -107,6 +107,7 @@ int op_fill(struct htp_ops_context * octx);
int op_diag(struct htp_ops_context * octx);
int op_solve_tri(struct htp_ops_context * octx);
int op_gated_delta_net(struct htp_ops_context * octx);
int op_tri(struct htp_ops_context * octx);
int op_pad(struct htp_ops_context * octx);
#endif /* HTP_CTX_H */
+1
View File
@@ -86,6 +86,7 @@ enum htp_op_code {
HTP_OP_SOLVE_TRI,
HTP_OP_L2_NORM,
HTP_OP_GATED_DELTA_NET,
HTP_OP_TRI,
HTP_OP_PAD,
HTP_OP_INVALID
+3
View File
@@ -601,6 +601,9 @@ static int execute_op(struct htp_ops_context * octx) {
case HTP_OP_GATED_DELTA_NET:
return op_gated_delta_net(octx);
case HTP_OP_TRI:
return op_tri(octx);
case HTP_OP_INVALID:
break;
+112 -1
View File
@@ -17,7 +17,6 @@
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-ops.h"
#include "htp-ops.h"
struct htp_unary_context {
struct htp_ops_context * octx;
@@ -277,6 +276,95 @@ static void sigmoid_f32(const float * restrict src,
}
}
static void tri_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
const uint32_t ir,
const struct htp_unary_context * uctx) {
const int32_t ttype = op_params[0];
const HVX_Vector zero = hvx_vec_splat_f32(0.0f);
const uint32_t nvec = row_elems / VLEN_FP32;
const uint32_t nloe = row_elems % VLEN_FP32;
const uint32_t ne01 = uctx->octx->src[0]->ne[1];
for (uint32_t b = 0; b < num_rows; b++) {
const uint32_t abs_row = ir + b;
const uint32_t i01 = abs_row % ne01;
const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size);
HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size);
uint32_t boundary;
int keep_left;
switch (ttype) {
case 0: boundary = i01; keep_left = 0; break; // keep col >= row
case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row
case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row
case 3: boundary = i01; keep_left = 1; break; // keep col < row
default: boundary = 0; keep_left = 0; break;
}
if (boundary > row_elems) boundary = row_elems;
// Full HVX vectors — each starts at a 128-byte aligned offset
for (uint32_t i = 0; i < nvec; i++) {
const uint32_t vec_start = i * VLEN_FP32;
const uint32_t vec_end = vec_start + VLEN_FP32;
if (keep_left) {
if (vec_end <= boundary) {
v_dst[i] = v_src[i];
} else if (vec_start >= boundary) {
v_dst[i] = zero;
} else {
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero);
}
} else {
if (vec_end <= boundary) {
v_dst[i] = zero;
} else if (vec_start >= boundary) {
v_dst[i] = v_src[i];
} else {
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]);
}
}
}
// Tail elements (row_elems not a multiple of VLEN_FP32)
if (nloe > 0) {
const uint32_t vec_start = nvec * VLEN_FP32;
const uint32_t vec_end = vec_start + nloe;
HVX_Vector tail_val;
if (keep_left) {
if (vec_end <= boundary) {
tail_val = v_src[nvec];
} else if (vec_start >= boundary) {
tail_val = zero;
} else {
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero);
}
} else {
if (vec_end <= boundary) {
tail_val = zero;
} else if (vec_start >= boundary) {
tail_val = v_src[nvec];
} else {
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]);
}
}
hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val);
}
}
}
static void softplus_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
@@ -498,6 +586,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
case HTP_OP_L2_NORM:
l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
case HTP_OP_TRI:
tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx);
break;
default:
break;
}
@@ -571,6 +662,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
case HTP_OP_L2_NORM:
op_type = "l2norm-f32";
break;
case HTP_OP_TRI:
op_type = "tri-f32";
break;
default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
return HTP_STATUS_NO_SUPPORT;
@@ -640,6 +735,22 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
return err;
}
// Entry point for the TRI op: forwards F32 inputs to the shared unary F32
// executor and reports HTP_STATUS_NO_SUPPORT for any other source type.
int op_tri(struct htp_ops_context * octx) {
    if (octx->src[0]->type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    return execute_op_unary_f32(octx);
}
int op_unary(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;