Ver código fonte

tcg/tcti: implement vector immediate shifts

This now seems to be required as a result of the introduction of
gen_gvec_rev{16,32,64} in 38f9950c8e0315d7b26803018a3f73d5f42e6703.
osy 1 mês atrás
pai
commit
594fe2f680

+ 1 - 1
tcg/aarch64-tcti/tcg-target-has.h

@@ -84,7 +84,7 @@
 #define TCG_TARGET_HAS_roti_vec         0
 #define TCG_TARGET_HAS_rots_vec         0
 #define TCG_TARGET_HAS_rotv_vec         0
-#define TCG_TARGET_HAS_shi_vec          0
+#define TCG_TARGET_HAS_shi_vec          1
 #define TCG_TARGET_HAS_shs_vec          0
 #define TCG_TARGET_HAS_shv_vec          1
 #define TCG_TARGET_HAS_mul_vec          1

+ 1 - 0
tcg/aarch64-tcti/tcg-target-opc.h.inc

@@ -12,3 +12,4 @@
  */
 
 DEF(aa64_sshl_vec, 1, 2, 0, TCG_OPF_VECTOR)
+DEF(aa64_sli_vec, 1, 2, 1, TCG_OPF_VECTOR)

+ 42 - 3
tcg/aarch64-tcti/tcg-target.c.inc

@@ -217,6 +217,8 @@ tcg_target_op_def(TCGOpcode op, TCGType type, unsigned flags)
         return C_O1_I2(w, w, w);
     case INDEX_op_bitsel_vec:
         return C_O1_I3(w, w, w, w);
+    case INDEX_op_aa64_sli_vec:
+        return C_O1_I2(w, 0, w);
 
     default:
         return C_NotImplemented;
@@ -490,6 +492,13 @@ static void tcg_out_ternary_gadget(TCGContext *s, const void *gadget_base[TCG_TA
     tcg_out_gadget(s, gadget_base[reg0][reg1][reg2]);
 }
 
+
+/* Write gadget pointer (three registers, last is immediate value). */
+static void tcg_out_ternary_immediate_gadget(TCGContext *s, const void *gadget_base[TCG_TARGET_GP_REGS][TCG_TARGET_GP_REGS][TCTI_GADGET_IMMEDIATE_ARRAY_LEN], unsigned reg0, unsigned reg1, unsigned reg2)
+{
+    tcg_out_gadget(s, gadget_base[reg0][reg1][reg2]);
+}
+
 /***************************
  *  TCG Scalar Operations  *
  ***************************/
@@ -1558,13 +1567,18 @@ static void tcg_out_nop_fill(tcg_insn_unit *p, int count)
     tcg_out_sized_vector_gadget_no64(s, name, ternary, vece, a, b, c)
 
 
-#define tcg_out_ternary_vector_gadget_with_scalar(s, name, is_scalar, vece, a, b, c) \
+#define tcg_out_sized_gadget_with_scalar(s, name, arity, is_scalar, vece, args...) \
     if (is_scalar) { \
-        tcg_out_ternary_gadget(s, gadget_ ## name ## _scalar, w0, w1, w2); \
+        tcg_out_ ## arity ## _gadget(s, gadget_ ## name ## _scalar, args); \
     } else { \
-        tcg_out_ternary_vector_gadget(s, name, vece, w0, w1, w2); \
+        tcg_out_sized_vector_gadget(s, name, arity, vece, args); \
     }
 
+#define tcg_out_ternary_vector_gadget_with_scalar(s, name, is_scalar, vece, a, b, c) \
+    tcg_out_sized_gadget_with_scalar(s, name, ternary, is_scalar, vece, a, b, c)
+
+#define tcg_out_ternary_immediate_vector_gadget_with_scalar(s, name, is_scalar, vece, a, b, c) \
+    tcg_out_sized_gadget_with_scalar(s, name, ternary_immediate, is_scalar, vece, a, b, c)
 
 /* Return true if v16 is a valid 16-bit shifted immediate.  */
 static bool is_shimm16(uint16_t v16, int *cmode, int *imm8)
@@ -1765,6 +1779,20 @@ static void tcg_out_vec_op(TCGContext *s, TCGOpcode opc, unsigned vecl, unsigned
         break;
     }
 
+    /* inhibit compiler warning because we use imm as a register */
+    case INDEX_op_shli_vec:
+        tcg_out_ternary_immediate_vector_gadget_with_scalar(s, shl, is_scalar, vece, w0, w1, r2);
+        break;
+    case INDEX_op_shri_vec:
+        tcg_out_ternary_immediate_vector_gadget_with_scalar(s, ushr, is_scalar, vece, w0, w1, r2 - 1);
+        break;
+    case INDEX_op_sari_vec:
+        tcg_out_ternary_immediate_vector_gadget_with_scalar(s, sshr, is_scalar, vece, w0, w1, r2 - 1);
+        break;
+    case INDEX_op_aa64_sli_vec:
+        tcg_out_ternary_immediate_vector_gadget_with_scalar(s, sli, is_scalar, vece, w0, w2, r3);
+        break;
+
     case INDEX_op_mov_vec:  /* Always emitted via tcg_out_mov.  */
     case INDEX_op_dup_vec:  /* Always emitted via tcg_out_dup_vec.  */
     default:
@@ -1787,6 +1815,9 @@ int tcg_can_emit_vec_op(TCGOpcode opc, TCGType type, unsigned vece)
     case INDEX_op_abs_vec:
     case INDEX_op_not_vec:
     case INDEX_op_cmp_vec:
+    case INDEX_op_shli_vec:
+    case INDEX_op_shri_vec:
+    case INDEX_op_sari_vec:
     case INDEX_op_ssadd_vec:
     case INDEX_op_sssub_vec:
     case INDEX_op_usadd_vec:
@@ -1827,6 +1858,14 @@ void tcg_expand_vec_op(TCGOpcode opc, TCGType type, unsigned vece,
     va_end(va);
 
     switch (opc) {
+    case INDEX_op_rotli_vec:
+        t1 = tcg_temp_new_vec(type);
+        tcg_gen_shri_vec(vece, t1, v1, -a2 & ((8 << vece) - 1));
+        vec_gen_4(INDEX_op_aa64_sli_vec, type, vece,
+                  tcgv_vec_arg(v0), tcgv_vec_arg(t1), tcgv_vec_arg(v1), a2);
+        tcg_temp_free_vec(t1);
+        break;
+
     case INDEX_op_shrv_vec:
     case INDEX_op_sarv_vec:
         /* Right shifts are negative left shifts for AArch64.  */

+ 59 - 4
tcg/aarch64-tcti/tcti-gadget-gen.py

@@ -113,7 +113,7 @@ def simple(name, *lines, export=True):
 
 
 
-def with_register_substitutions(name, substitutions, *lines, immediate_range=range(0)):
+def with_register_substitutions(name, substitutions, *lines, immediate_range=range(0), filter=lambda p: False):
     """ Generates a collection of gadgtes with register substitutions. """
 
     def _expand_op1_immediate(num):
@@ -166,6 +166,10 @@ def substitutions_for_letter(letter, number, line):
 
     #  For each permutation...
     for permutation in permutations:
+        # Filter any invalid combination
+        if filter(permutation): 
+            continue
+
         new_lines = lines
 
         # Replace each placeholder element with its proper value...
@@ -212,9 +216,9 @@ def with_dnm(name, *lines):
     print("};", file=c_file)
 
 
-def with_dn_immediate(name, *lines, immediate_range):
+def with_dn_immediate(name, *lines, immediate_range, filter=lambda m: False):
     """ Generates a collection of gadgets with substitutions for Xd, Xn, and Xm, and equivalents. """
-    with_register_substitutions(name, ["d", "n"], *lines, immediate_range=immediate_range)
+    with_register_substitutions(name, ["d", "n"], *lines, immediate_range=immediate_range, filter=lambda p: filter(p[-1]))
 
     # Fetch the files we'll be using for output.
     c_file, h_file = _get_output_files()
@@ -236,7 +240,10 @@ def with_dn_immediate(name, *lines, immediate_range):
 
             # M array
             for i in immediate_range:
-                print(f"gadget_{name}_arg{d}_arg{n}_arg{i}", end=", ", file=c_file)
+                if filter(i):
+                    print(f"(void *)0", end=", ", file=c_file)
+                else:
+                    print(f"gadget_{name}_arg{d}_arg{n}_arg{i}", end=", ", file=c_file)
 
             print("},", file=c_file)
         print("\t},", file=c_file)
@@ -625,6 +632,24 @@ def do_size_replacement(line, size):
             sized_lines = (scalar,)
         with_dnm(f"{name}_scalar", *sized_lines)
 
+def vector_dn_immediate(name, *lines, scalar=None, immediate_range, omit_sizes=(), filter=lambda s, m: False):
+    """ Creates a set of gadgets for every size of a given vector op. Accepts 'S' as a size placeholder. """
+
+    def do_size_replacement(line, size):
+        return line.replace(".S", f".{size}")
+        
+    # Create a variant for each size, replacing any placeholders.
+    for size in VECTOR_SIZES:
+        if size in omit_sizes:
+            continue
+
+        sized_lines = (do_size_replacement(line, size) for line in lines)
+        with_dn_immediate(f"{name}_{size}", *sized_lines, immediate_range=immediate_range, filter=lambda m: filter(size, m))
+
+    if scalar:
+        if isinstance(scalar, str):
+            sized_lines = (scalar,)
+        with_dn_immediate(f"{name}_scalar", *sized_lines, immediate_range=immediate_range, filter=lambda m: filter(None, m))
 
 def vector_math_dnm(name, operation):
     """ Generates a collection of gadgets for vector math instructions. """
@@ -647,6 +672,9 @@ def vector_logic_dnm(name, operation):
     with_dnm(f"{name}_d", f"{operation} Vd.8b, Vn.8b, Vm.8b")
     with_dnm(f"{name}_q", f"{operation} Vd.16b, Vn.16b, Vm.16b")
 
+def vector_math_dn_immediate(name, operation, immediate_range, filter=lambda x: False):
+    """ Generates a collection of gadgets for vector math instructions. """
+    vector_dn_immediate(name, f"{operation} Vd.S, Vn.S, #Ii", scalar=f"{operation} Dd, Dn, #Ii", immediate_range=immediate_range, filter=filter)
 
 #
 # Gadget definitions.
@@ -1088,6 +1116,33 @@ def vector_logic_dnm(name, operation):
 vector_math_dnm("shlv", "ushl")
 vector_math_dnm("sshl", "sshl")
 
+def filter_shl(size, imm):
+    match size:
+        case '16b': return imm >= 8
+        case '8b': return imm >= 8
+        case '4h': return imm >= 16
+        case '8h': return imm >= 16
+        case '2s': return imm >= 32
+        case '4s': return imm >= 32
+    return False
+
+def filter_shr(size, imm):
+    if imm == 0:
+        return True
+    match size:
+        case '16b': return imm > 8
+        case '8b': return imm > 8
+        case '4h': return imm > 16
+        case '8h': return imm > 16
+        case '2s': return imm > 32
+        case '4s': return imm > 32
+    return False
+
+vector_math_dn_immediate("shl", "shl", immediate_range=range(64), filter=filter_shl)
+vector_math_dn_immediate("ushr", "ushr", immediate_range=range(1,65), filter=filter_shr)
+vector_math_dn_immediate("sshr", "sshr", immediate_range=range(1,65), filter=filter_shr)
+vector_math_dn_immediate("sli", "sli", immediate_range=range(64), filter=filter_shl)
+
 vector_dnm("cmeq", "cmeq Vd.S, Vn.S, Vm.S", scalar="cmeq Dd, Dn, Dm")
 vector_dnm("cmgt", "cmgt Vd.S, Vn.S, Vm.S", scalar="cmgt Dd, Dn, Dm")
 vector_dnm("cmge", "cmge Vd.S, Vn.S, Vm.S", scalar="cmge Dd, Dn, Dm")