|
@@ -23,6 +23,7 @@
|
|
#include "llvm/IR/Type.h"
|
|
#include "llvm/IR/Type.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include "llvm/Support/Compiler.h"
|
|
#include "llvm/Support/Compiler.h"
|
|
|
|
+#include "llvm/Support/ScalableSize.h"
|
|
#include <cassert>
|
|
#include <cassert>
|
|
#include <cstdint>
|
|
#include <cstdint>
|
|
|
|
|
|
@@ -387,6 +388,8 @@ public:
|
|
SequentialType(const SequentialType &) = delete;
|
|
SequentialType(const SequentialType &) = delete;
|
|
SequentialType &operator=(const SequentialType &) = delete;
|
|
SequentialType &operator=(const SequentialType &) = delete;
|
|
|
|
|
|
|
|
+ /// For scalable vectors, this will return the minimum number of elements
|
|
|
|
+ /// in the vector.
|
|
uint64_t getNumElements() const { return NumElements; }
|
|
uint64_t getNumElements() const { return NumElements; }
|
|
Type *getElementType() const { return ContainedType; }
|
|
Type *getElementType() const { return ContainedType; }
|
|
|
|
|
|
@@ -422,14 +425,37 @@ uint64_t Type::getArrayNumElements() const {
|
|
|
|
|
|
/// Class to represent vector types.
|
|
/// Class to represent vector types.
|
|
class VectorType : public SequentialType {
|
|
class VectorType : public SequentialType {
|
|
- VectorType(Type *ElType, unsigned NumEl);
|
|
|
|
|
|
+ /// A fully specified VectorType is of the form <vscale x n x Ty>. 'n' is the
|
|
|
|
+ /// minimum number of elements of type Ty contained within the vector, and
|
|
|
|
+ /// 'scalable' indicates that the total element count is an integer multiple
|
|
|
|
+ /// of 'n', where the multiple is either guaranteed to be one, or is
|
|
|
|
+ /// statically unknown at compile time.
|
|
|
|
+ ///
|
|
|
|
+ /// If the multiple is known to be 1, then the extra term is discarded in
|
|
|
|
+ /// textual IR:
|
|
|
|
+ ///
|
|
|
|
+ /// <4 x i32> - a vector containing 4 i32s
|
|
|
|
+ /// <vscale x 4 x i32> - a vector containing an unknown integer multiple
|
|
|
|
+ /// of 4 i32s
|
|
|
|
+
|
|
|
|
+ VectorType(Type *ElType, unsigned NumEl, bool Scalable = false);
|
|
|
|
+ VectorType(Type *ElType, ElementCount EC);
|
|
|
|
+
|
|
|
|
+ // If true, the total number of elements is an unknown multiple of the
|
|
|
|
+ // minimum 'NumElements' from SequentialType. Otherwise the total number
|
|
|
|
+ // of elements is exactly equal to 'NumElements'.
|
|
|
|
+ bool Scalable;
|
|
|
|
|
|
public:
|
|
public:
|
|
VectorType(const VectorType &) = delete;
|
|
VectorType(const VectorType &) = delete;
|
|
VectorType &operator=(const VectorType &) = delete;
|
|
VectorType &operator=(const VectorType &) = delete;
|
|
|
|
|
|
/// This static method is the primary way to construct an VectorType.
|
|
/// This static method is the primary way to construct an VectorType.
|
|
- static VectorType *get(Type *ElementType, unsigned NumElements);
|
|
|
|
|
|
+ static VectorType *get(Type *ElementType, ElementCount EC);
|
|
|
|
+ static VectorType *get(Type *ElementType, unsigned NumElements,
|
|
|
|
+ bool Scalable = false) {
|
|
|
|
+ return VectorType::get(ElementType, {NumElements, Scalable});
|
|
|
|
+ }
|
|
|
|
|
|
/// This static method gets a VectorType with the same number of elements as
|
|
/// This static method gets a VectorType with the same number of elements as
|
|
/// the input type, and the element type is an integer type of the same width
|
|
/// the input type, and the element type is an integer type of the same width
|
|
@@ -438,7 +464,7 @@ public:
|
|
unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits();
|
|
unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits();
|
|
assert(EltBits && "Element size must be of a non-zero size");
|
|
assert(EltBits && "Element size must be of a non-zero size");
|
|
Type *EltTy = IntegerType::get(VTy->getContext(), EltBits);
|
|
Type *EltTy = IntegerType::get(VTy->getContext(), EltBits);
|
|
- return VectorType::get(EltTy, VTy->getNumElements());
|
|
|
|
|
|
+ return VectorType::get(EltTy, VTy->getElementCount());
|
|
}
|
|
}
|
|
|
|
|
|
/// This static method is like getInteger except that the element types are
|
|
/// This static method is like getInteger except that the element types are
|
|
@@ -446,7 +472,7 @@ public:
|
|
static VectorType *getExtendedElementVectorType(VectorType *VTy) {
|
|
static VectorType *getExtendedElementVectorType(VectorType *VTy) {
|
|
unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits();
|
|
unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits();
|
|
Type *EltTy = IntegerType::get(VTy->getContext(), EltBits * 2);
|
|
Type *EltTy = IntegerType::get(VTy->getContext(), EltBits * 2);
|
|
- return VectorType::get(EltTy, VTy->getNumElements());
|
|
|
|
|
|
+ return VectorType::get(EltTy, VTy->getElementCount());
|
|
}
|
|
}
|
|
|
|
|
|
/// This static method is like getInteger except that the element types are
|
|
/// This static method is like getInteger except that the element types are
|
|
@@ -456,29 +482,45 @@ public:
|
|
assert((EltBits & 1) == 0 &&
|
|
assert((EltBits & 1) == 0 &&
|
|
"Cannot truncate vector element with odd bit-width");
|
|
"Cannot truncate vector element with odd bit-width");
|
|
Type *EltTy = IntegerType::get(VTy->getContext(), EltBits / 2);
|
|
Type *EltTy = IntegerType::get(VTy->getContext(), EltBits / 2);
|
|
- return VectorType::get(EltTy, VTy->getNumElements());
|
|
|
|
|
|
+ return VectorType::get(EltTy, VTy->getElementCount());
|
|
}
|
|
}
|
|
|
|
|
|
/// This static method returns a VectorType with half as many elements as the
|
|
/// This static method returns a VectorType with half as many elements as the
|
|
/// input type and the same element type.
|
|
/// input type and the same element type.
|
|
static VectorType *getHalfElementsVectorType(VectorType *VTy) {
|
|
static VectorType *getHalfElementsVectorType(VectorType *VTy) {
|
|
- unsigned NumElts = VTy->getNumElements();
|
|
|
|
- assert ((NumElts & 1) == 0 &&
|
|
|
|
|
|
+ auto EltCnt = VTy->getElementCount();
|
|
|
|
+ assert ((EltCnt.Min & 1) == 0 &&
|
|
"Cannot halve vector with odd number of elements.");
|
|
"Cannot halve vector with odd number of elements.");
|
|
- return VectorType::get(VTy->getElementType(), NumElts/2);
|
|
|
|
|
|
+ return VectorType::get(VTy->getElementType(), EltCnt/2);
|
|
}
|
|
}
|
|
|
|
|
|
/// This static method returns a VectorType with twice as many elements as the
|
|
/// This static method returns a VectorType with twice as many elements as the
|
|
/// input type and the same element type.
|
|
/// input type and the same element type.
|
|
static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
|
|
static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
|
|
- unsigned NumElts = VTy->getNumElements();
|
|
|
|
- return VectorType::get(VTy->getElementType(), NumElts*2);
|
|
|
|
|
|
+ auto EltCnt = VTy->getElementCount();
|
|
|
|
+ assert((VTy->getNumElements() * 2ull) <= UINT_MAX &&
|
|
|
|
+ "Too many elements in vector");
|
|
|
|
+ return VectorType::get(VTy->getElementType(), EltCnt*2);
|
|
}
|
|
}
|
|
|
|
|
|
/// Return true if the specified type is valid as a element type.
|
|
/// Return true if the specified type is valid as a element type.
|
|
static bool isValidElementType(Type *ElemTy);
|
|
static bool isValidElementType(Type *ElemTy);
|
|
|
|
|
|
- /// Return the number of bits in the Vector type.
|
|
|
|
|
|
+ /// Return an ElementCount instance to represent the (possibly scalable)
|
|
|
|
+ /// number of elements in the vector.
|
|
|
|
+ ElementCount getElementCount() const {
|
|
|
|
+ uint64_t MinimumEltCnt = getNumElements();
|
|
|
|
+ assert(MinimumEltCnt <= UINT_MAX && "Too many elements in vector");
|
|
|
|
+ return { (unsigned)MinimumEltCnt, Scalable };
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /// Returns whether or not this is a scalable vector (meaning the total
|
|
|
|
+ /// element count is a multiple of the minimum).
|
|
|
|
+ bool isScalable() const {
|
|
|
|
+ return Scalable;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /// Return the minimum number of bits in the Vector type.
|
|
/// Returns zero when the vector is a vector of pointers.
|
|
/// Returns zero when the vector is a vector of pointers.
|
|
unsigned getBitWidth() const {
|
|
unsigned getBitWidth() const {
|
|
return getNumElements() * getElementType()->getPrimitiveSizeInBits();
|
|
return getNumElements() * getElementType()->getPrimitiveSizeInBits();
|
|
@@ -494,6 +536,10 @@ unsigned Type::getVectorNumElements() const {
|
|
return cast<VectorType>(this)->getNumElements();
|
|
return cast<VectorType>(this)->getNumElements();
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+bool Type::getVectorIsScalable() const {
|
|
|
|
+ return cast<VectorType>(this)->isScalable();
|
|
|
|
+}
|
|
|
|
+
|
|
/// Class to represent pointers.
|
|
/// Class to represent pointers.
|
|
class PointerType : public Type {
|
|
class PointerType : public Type {
|
|
explicit PointerType(Type *ElType, unsigned AddrSpace);
|
|
explicit PointerType(Type *ElType, unsigned AddrSpace);
|