Documentation/x86: Add an example program for AMX usage

Add an example program that demonstrates use of the arch_prctl based API
for dynamically requesting the use of AMX instructions, and then
performing a simple matrix dot product.

Cc: Dave Hansen <dave.hansen@linux.intel.com>
Cc: Chang Seok Bae <chang.seok.bae@intel.com>
Signed-off-by: Vishal Verma <vishal.l.verma@intel.com>
diff --git a/Documentation/x86/amx-example.rst b/Documentation/x86/amx-example.rst
new file mode 100644
index 0000000..4b7477e
--- /dev/null
+++ b/Documentation/x86/amx-example.rst
@@ -0,0 +1,376 @@
+Using arch_prctl to request AMX capabilities
+============================================
+
+Intel AMX (Advanced Matrix Extensions) is a dynamically enabled feature
+that requires a process to request and obtain prior permission from the
+kernel before it can be used.
+
+The following is an example program that obtains the necessary permission,
+sets up an alternate signal stack (sigaltstack), and then performs a simple
+matrix dot product.
+
+Toolchain notes
+---------------
+
+Building the example program requires at least gcc-11.2, glibc-2.34 and
+binutils-2.37.
+
+Example Program
+---------------
+
+.. code-block:: C
+
+ 
+        /*
+         * AMX usage example
+         *
+         * This performs the following high level steps:
+         *
+         *   1. Detect AMX tile architecture support - CPUID.0x7.0.EDX.AMX_TILE[bit 24] = 1
+         *   2. Setup sigaltstack with proper size (see setup_sigaltstack())
+         *   3. Request permission to use AMX tile data (see request_perm_xtile_data())
+         *   4. Load data and compute dot product (see load_rand_tiledata() and mult_abc())
+         */
+        
+        #define _GNU_SOURCE
+        #include <err.h>
+        #include <errno.h>
+        #include <stdio.h>
+        #include <string.h>
+        #include <stdbool.h>
+        #include <unistd.h>
+        #include <x86intrin.h>
+        #include <immintrin.h>
+        
+        #include <sys/auxv.h>
+        #include <sys/mman.h>
+        #include <sys/syscall.h>
+        #include <sys/signal.h>
+        
+        /*
+         * Report a fatal condition: prints "[FAIL]" followed by the
+         * printf-style message through err(3) and exits with status 1.
+         */
+        #define fatal_error(msg, ...)	err(1, "[FAIL]\t" msg, ##__VA_ARGS__)
+        
+        /* Fallback for libc headers that do not provide the key passed to getauxval(). */
+        #ifndef AT_MINSIGSTKSZ
+        #  define AT_MINSIGSTKSZ	51
+        #endif
+        
+        /* Feature numbers of the AMX tile config and tile data states, and their bit masks. */
+        #define XFEATURE_XTILECFG	17
+        #define XFEATURE_XTILEDATA	18
+        #define XFEATURE_MASK_XTILECFG	(1 << XFEATURE_XTILECFG)
+        #define XFEATURE_MASK_XTILEDATA	(1 << XFEATURE_XTILEDATA)
+        #define XFEATURE_MASK_XTILE	(XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA)
+        
+        /* arch_prctl() commands used to read and to request feature permissions. */
+        #define ARCH_GET_XCOMP_PERM	0x1022
+        #define ARCH_REQ_XCOMP_PERM	0x1023
+        
+        /*
+         * Matrix shapes used by the example: A is TILE_M x TILE_K, B is
+         * TILE_K x TILE_N and C is TILE_M x TILE_N.  Every element occupies
+         * BYTES_PER_ELEMENT bytes; MAX_ELEMENTS is the element count of the
+         * three matrices together.
+         */
+        #define TILE_M			8
+        #define TILE_K			8
+        #define TILE_N			8
+        #define MAX_ELEMENTS		((TILE_M * TILE_K) + (TILE_K * TILE_N) + (TILE_M * TILE_N))
+        #define BYTES_PER_ELEMENT	4
+        
+        /*
+         * Backing memory for the three matrices.  'a', 'b' and 'c' are laid
+         * out back to back in that order.  'bytes' overlays the same storage
+         * as one flat array; note that, despite its name, it is indexed in
+         * 32-bit elements, not in bytes.
+         */
+        struct tile_buffer {
+        	union {
+        		struct {
+        			uint32_t a[TILE_M * TILE_K];
+        			uint32_t b[TILE_K * TILE_N];
+        			uint32_t c[TILE_M * TILE_N];
+        		};
+        		uint32_t bytes[0];
+        	};
+        };
+        
+        /*
+         * Execute the CPUID instruction.
+         *
+         * On entry *eax and *ecx supply the input register values (leaf and
+         * subleaf); on return *eax, *ebx, *ecx and *edx hold the four output
+         * registers.  *ebx and *edx are outputs only.
+         */
+        static inline void cpuid(uint32_t *eax, uint32_t *ebx, uint32_t *ecx, uint32_t *edx)
+        {
+        	asm volatile("cpuid;"
+        		     : "=a" (*eax), "=b" (*ebx), "=c" (*ecx), "=d" (*edx)
+        		     : "0" (*eax), "2" (*ecx));
+        }
+        
+        /* CPUID leaves queried by amx_check_cpuid(). */
+        #define CPUID_LEAF_XFEATURE_ENUM		0x07
+        #define CPUID_LEAF_TILE_INFO			0x1d
+        #define CPUID_LEAF_TMUL_INFO			0x1e
+        
+        /* Bit positions tested in EDX of CPUID_LEAF_XFEATURE_ENUM, subleaf 0. */
+        #define CPUID_LEAF_XFEATURE_AMX_TILE_SHIFT	24
+        #define CPUID_LEAF_XFEATURE_AMX_INT8_SHIFT	25
+        
+        /* Subleaf (ECX input) of CPUID_LEAF_TILE_INFO that is read for palette 1. */
+        #define CPUID_SUBLEAF_TILE_ECX_PALETTE_1	1
+        
+        /*
+         * Field extraction for the palette subleaf: EAX holds total_tile_bytes
+         * (low 16 bits) and bytes_per_tile (high 16 bits), EBX holds
+         * bytes_per_row (low) and max_tiles (high), ECX holds max_rows (low).
+         */
+        #define CPUID_TILE_BYTES_MASK			0xffff
+        #define CPUID_TILE_BYTES_PER_TILE_SHIFT		16
+        #define CPUID_TILES_MAX_SHIFT			16
+        #define CPUID_TILE_BYTES_PER_ROW_MASK		0xffff
+        #define CPUID_TILE_ROWS_MAX_MASK		0xffff
+        
+        /* Field extraction for EBX of CPUID_LEAF_TMUL_INFO: maxk in bits 7:0, maxn above it. */
+        #define CPUID_TMUL_MAXK_MASK			0xff
+        #define CPUID_TMUL_MAXN_MASK			0xffff
+        #define CPUID_TMUL_MAXN_SHIFT			0x8
+        
+        /*
+         * Tile register numbers, passed to the _tile_*() intrinsics and used
+         * as indices into the tilecfg arrays.
+         */
+        #define TMM0	0
+        #define TMM1	1
+        #define TMM2	2
+        #define TMM3	3
+        #define TMM4	4
+        #define TMM5	5
+        #define TMM6	6
+        #define TMM7	7
+        
+        /* Values decoded from CPUID by amx_check_cpuid(). */
+        static uint32_t max_palette;
+        static uint32_t total_tile_bytes, bytes_per_tile, max_tiles;
+        static uint32_t bytes_per_row, max_rows;
+        static uint32_t tmul_maxk, tmul_maxn;
+        
+        /*
+         * Verify through CPUID that the AMX tile architecture and the
+         * AMX-INT8 operations are available, then decode the tile-info and
+         * TMUL-info leaves into the file-scope variables and print them.
+         *
+         * Terminates the program through fatal_error() when a required
+         * capability is missing.
+         */
+        static void amx_check_cpuid(void)
+        {
+        	const uint32_t amx_tile_bit = 1U << CPUID_LEAF_XFEATURE_AMX_TILE_SHIFT;
+        	const uint32_t amx_int8_bit = 1U << CPUID_LEAF_XFEATURE_AMX_INT8_SHIFT;
+        	uint32_t eax, ebx, ecx, edx;
+        
+        	/* Feature enumeration leaf: EDX carries the two AMX feature bits. */
+        	eax = CPUID_LEAF_XFEATURE_ENUM;
+        	ecx = 0;
+        	cpuid(&eax, &ebx, &ecx, &edx);
+        	if ((edx & amx_tile_bit) == 0)
+        		fatal_error("CPUID: AMX Tile architecture not supported");
+        	if ((edx & amx_int8_bit) == 0)
+        		fatal_error("CPUID: AMX-INT8 operations not supported");
+        
+        	/* Tile info leaf, subleaf 0: EAX is the highest supported palette. */
+        	eax = CPUID_LEAF_TILE_INFO;
+        	ecx = 0;
+        	cpuid(&eax, &ebx, &ecx, &edx);
+        	max_palette = eax;
+        
+        	printf("CPUID Tile Info leaf:\n");
+        	printf("  max_palette: %u\n", max_palette);
+        	if (max_palette == 0)
+        		fatal_error("AMX support missing (max_palette = 0)");
+        
+        	/* Tile info leaf, palette 1 subleaf: decode and print each field. */
+        	eax = CPUID_LEAF_TILE_INFO;
+        	ecx = CPUID_SUBLEAF_TILE_ECX_PALETTE_1;
+        	cpuid(&eax, &ebx, &ecx, &edx);
+        
+        	total_tile_bytes = eax & CPUID_TILE_BYTES_MASK;
+        	printf("  total_tile_bytes: %u\n", total_tile_bytes);
+        
+        	bytes_per_tile = eax >> CPUID_TILE_BYTES_PER_TILE_SHIFT;
+        	printf("  bytes_per_tile: %u\n", bytes_per_tile);
+        
+        	bytes_per_row = ebx & CPUID_TILE_BYTES_PER_ROW_MASK;
+        	printf("  bytes_per_row: %u\n", bytes_per_row);
+        
+        	max_tiles = ebx >> CPUID_TILES_MAX_SHIFT;
+        	printf("  max_tiles: %u\n", max_tiles);
+        
+        	max_rows = ecx & CPUID_TILE_ROWS_MAX_MASK;
+        	printf("  max_rows: %u\n", max_rows);
+        
+        	/* TMUL info leaf: both limits are packed into EBX. */
+        	eax = CPUID_LEAF_TMUL_INFO;
+        	ecx = 0;
+        	cpuid(&eax, &ebx, &ecx, &edx);
+        
+        	printf("CPUID TMUL Info leaf:\n");
+        
+        	tmul_maxk = ebx & CPUID_TMUL_MAXK_MASK;
+        	printf("  tmul_maxk: %u\n", tmul_maxk);
+        
+        	tmul_maxn = (ebx >> CPUID_TMUL_MAXN_SHIFT) & CPUID_TMUL_MAXN_MASK;
+        	printf("  tmul_maxn: %u\n", tmul_maxn);
+        }
+        
+        /*
+         * Tile configuration image handed to _tile_loadconfig() and filled in
+         * by _tile_storeconfig().  The struct is packed so that every member
+         * sits at the byte offset noted beside it (64 bytes in total).  A
+         * single file-scope instance, 'tilecfg', is defined here.
+         */
+        static struct tilecfg {
+        	uint8_t palette;	/* byte 0 */
+        	uint8_t start_row;	/* byte 1 */
+        	char rsvd1[14];		/* bytes 2-15 */
+        	uint16_t tile_colsb[8];	/* bytes 16-31 */
+        	char rsvd2[16];		/* bytes 32-47 */
+        	uint8_t tile_rows[8];	/* bytes 48-55 */
+        	char rsvd3[8];		/* bytes 56-63 */
+        } __attribute__((packed)) tilecfg;
+        
+        /*
+         * Print the palette, the start row and, for each of the eight tile
+         * registers, its configured shape as [ rows x bytes-per-row ].
+         */
+        static void print_tilecfg(struct tilecfg *t)
+        {
+        	int tmm = 0;
+        
+        	printf("TILECFG:\n");
+        	printf("  palette: %d\n", t->palette);
+        	printf("  start_row: %d\n", t->start_row);
+        
+        	while (tmm < 8) {
+        		int rows = t->tile_rows[tmm];
+        		int colsb = t->tile_colsb[tmm];
+        
+        		printf("  tmm%d: [ %d x %d ]\n", tmm, rows, colsb);
+        		tmm++;
+        	}
+        }
+        
+        /*
+         * Describe the three tiles used by the example in *t (palette 1,
+         * start row 0) and load that configuration:
+         *
+         *   tmm0 -> A: src1 matrix, MxK
+         *   tmm1 -> B: src2 matrix, KxN
+         *   tmm2 -> C: dst matrix,  MxN
+         *
+         * Only the fields listed above are written; everything else in *t is
+         * passed to _tile_loadconfig() as the caller left it.
+         */
+        static void load_tile_config(struct tilecfg *t)
+        {
+        	const uint16_t a_row_bytes = TILE_K * BYTES_PER_ELEMENT;
+        	const uint16_t bc_row_bytes = TILE_N * BYTES_PER_ELEMENT;
+        
+        	t->start_row = 0;
+        	t->palette = 1;
+        
+        	t->tile_colsb[TMM0] = a_row_bytes;
+        	t->tile_rows[TMM0] = TILE_M;
+        
+        	t->tile_colsb[TMM1] = bc_row_bytes;
+        	t->tile_rows[TMM1] = TILE_K;
+        
+        	t->tile_colsb[TMM2] = bc_row_bytes;
+        	t->tile_rows[TMM2] = TILE_M;
+        
+        	_tile_loadconfig(t);
+        }
+        
+        /* Read the currently active tile configuration back into *t. */
+        static void get_stored_tilecfg(struct tilecfg *t)
+        {
+        	_tile_storeconfig(t);
+        }
+        
+        /*
+         * Initialise the matrix buffer: fill A and B with random non-zero
+         * values and clear C.
+         *
+         * C is later loaded into tmm2 and used as the accumulator of the dot
+         * product, and it is printed before the multiplication.  The buffer is
+         * obtained from aligned_alloc(), which does not zero memory, so C has
+         * to be cleared here for the result to be exactly A . B.
+         */
+        static void set_rand_tiledata(struct tile_buffer *tbuf)
+        {
+        	/* A and B are the first two matrices in the flat 'bytes' view. */
+        	const int nr_src_elements = MAX_ELEMENTS - (TILE_M * TILE_N);
+        	int i;
+        
+        	/*
+        	 * Ensure that the source data is never 0.  This ensures that
+        	 * the registers are never in their initial configuration
+        	 * and thus never tracked as being in the init state.
+        	 */
+        	for (i = 0; i < nr_src_elements; i++)
+        		tbuf->bytes[i] = (rand() % 0xff) | 1;
+        
+        	memset(tbuf->c, 0, sizeof(tbuf->c));
+        }
+        
+        /*
+         * Prepare the tile registers for the multiplication.
+         *
+         * The steps run in a fixed order: fill the memory buffer, release any
+         * existing tile state, load the tile configuration, read it back and
+         * print it, and finally load matrices A, B and C from memory into
+         * tmm0, tmm1 and tmm2.  The last argument of each _tile_loadd() is the
+         * size in bytes of one matrix row in memory.
+         */
+        static void load_rand_tiledata(struct tile_buffer *tbuf)
+        {
+        	set_rand_tiledata(tbuf);
+        
+        	_tile_release();
+        	printf("TILERELEASE Done\n");
+        
+        	load_tile_config(&tilecfg);
+        	printf("LDTILECFG Done\n");
+        
+        	/* Overwrites the static tilecfg with what the CPU reports back. */
+        	get_stored_tilecfg(&tilecfg);
+        	print_tilecfg(&tilecfg);
+        
+        	_tile_loadd(TMM0, &tbuf->a[0], TILE_K * BYTES_PER_ELEMENT);
+        	printf("TILELOADD tmm0 Done\n");
+        	_tile_loadd(TMM1, &tbuf->b[0], TILE_N * BYTES_PER_ELEMENT);
+        	printf("TILELOADD tmm1 Done\n");
+        	_tile_loadd(TMM2, &tbuf->c[0], TILE_N * BYTES_PER_ELEMENT);
+        	printf("TILELOADD tmm2 Done\n");
+        }
+        
+        /*
+         * Accumulate the dot product of tmm0 (A) and tmm1 (B) into tmm2 (C),
+         * then store tmm2 back to the 'c' matrix of the memory buffer.
+         * The tiles must already have been loaded by load_rand_tiledata().
+         */
+        static void mult_abc(struct tile_buffer *tbuf)
+        {
+        	_tile_dpbuud(TMM2, TMM0, TMM1);
+        	printf("TDPBUUD (tmm2 += tmm0 . tmm1) Done\n");
+        
+        	_tile_stored(TMM2, &tbuf->c[0], TILE_N * BYTES_PER_ELEMENT);
+        	printf("TILESTORED (tmm2-> 'C') Done\n");
+        }
+        
+        /*
+         * Ask the kernel for permission to use the AMX tile data state and
+         * confirm that the permission was actually granted.
+         *
+         * Terminates the program if either arch_prctl() call fails or if the
+         * XTILEDATA bit is absent from the permission bitmask read back.
+         */
+        static void request_perm_xtile_data(void)
+        {
+        	unsigned long bitmask = 0;
+        	long rc;
+        
+        	rc = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
+        	if (rc)
+        		fatal_error("XTILE_DATA request failed: %ld", rc);
+        
+        	rc = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask);
+        	if (rc)
+        		fatal_error("prctl(ARCH_GET_XCOMP_PERM) error: %ld", rc);
+        
+        	/*
+        	 * Test the XTILEDATA bit specifically: it is the feature that was
+        	 * requested above, and a match on the XTILECFG bit alone would not
+        	 * show that the request took effect.  errx() is used because errno
+        	 * carries no information when both system calls succeeded.
+        	 */
+        	if (!(bitmask & XFEATURE_MASK_XTILEDATA))
+        		errx(1, "[FAIL]\tXTILE_DATA permission missing from bitmask 0x%lx", bitmask);
+        
+        	printf("ARCH_REQ_XCOMP_PERM XTILE_DATA successful.\n");
+        }
+        
+        /*
+         * Install an alternate signal stack that is twice as large as the
+         * minimum signal stack size reported through AT_MINSIGSTKSZ.
+         *
+         * Terminates the program if AT_MINSIGSTKSZ is unavailable, or if
+         * mapping or registering the stack fails.
+         */
+        static void setup_sigaltstack()
+        {
+        	unsigned long minsigstksz = getauxval(AT_MINSIGSTKSZ);
+        	unsigned long stack_size;
+        	void *stack_mem;
+        	stack_t ss;
+        	int rc;
+        
+        	printf("AT_MINSIGSTKSZ = %lu\n", minsigstksz);
+        
+        	/*
+        	 * getauxval() returns 0 both when the key is missing and when
+        	 * the stored value is 0.  A usable minimum stack size is never
+        	 * 0, so a 0 here is treated as "not implemented".
+        	 */
+        	if (!minsigstksz)
+        		fatal_error("no support for AT_MINSIGSTKSZ");
+        
+        	stack_size = 2 * minsigstksz;
+        	stack_mem = mmap(NULL, stack_size, PROT_READ | PROT_WRITE,
+        			 MAP_PRIVATE | MAP_ANONYMOUS | MAP_STACK, -1, 0);
+        	if (stack_mem == MAP_FAILED)
+        		fatal_error("mmap() for altstack");
+        
+        	memset(&ss, 0, sizeof(ss));
+        	ss.ss_sp = stack_mem;
+        	ss.ss_size = stack_size;
+        
+        	rc = sigaltstack(&ss, NULL);
+        	if (rc != 0)
+        		fatal_error("sigaltstack failed: %d", rc);
+        }
+        
+        /*
+         * Print matrices A, B and C from the memory buffer, one bracketed row
+         * per line.  A and B elements are zero-padded to three digits, C
+         * elements to six.  Each matrix is walked as a flat array; a row is
+         * opened at the first column and closed at the last one.
+         */
+        static void print_abc(struct tile_buffer *tbuf)
+        {
+        	int idx;
+        
+        	printf("Matrix A:\n");
+        	for (idx = 0; idx < TILE_M * TILE_K; idx++) {
+        		if (idx % TILE_K == 0)
+        			printf(" [");
+        		printf(" %03u", tbuf->a[idx]);
+        		if (idx % TILE_K == TILE_K - 1)
+        			printf(" ]\n");
+        	}
+        	printf("\n");
+        
+        	printf("Matrix B:\n");
+        	for (idx = 0; idx < TILE_K * TILE_N; idx++) {
+        		if (idx % TILE_N == 0)
+        			printf(" [");
+        		printf(" %03u", tbuf->b[idx]);
+        		if (idx % TILE_N == TILE_N - 1)
+        			printf(" ]\n");
+        	}
+        	printf("\n");
+        
+        	printf("Matrix C:\n");
+        	for (idx = 0; idx < TILE_M * TILE_N; idx++) {
+        		if (idx % TILE_N == 0)
+        			printf(" [");
+        		printf(" %06u", tbuf->c[idx]);
+        		if (idx % TILE_N == TILE_N - 1)
+        			printf(" ]\n");
+        	}
+        	printf("\n");
+        }
+        
+        /*
+         * Program entry point: check CPUID, allocate the 64-byte aligned
+         * matrix buffer, set up the alternate signal stack, obtain permission
+         * for the AMX tile data, then load, multiply and print the matrices.
+         */
+        int main(void)
+        {
+        	struct tile_buffer *tile = NULL;
+        
+        	amx_check_cpuid();
+        
+        	tile = aligned_alloc(64, MAX_ELEMENTS * BYTES_PER_ELEMENT);
+        	if (tile == NULL)
+        		fatal_error("failed to allocate tile");
+        
+        	setup_sigaltstack();
+        
+        	/* Permission must be granted before any tile instruction runs. */
+        	request_perm_xtile_data();
+        
+        	/* Load tile configuration and tile data for the matrices. */
+        	load_rand_tiledata(tile);
+        
+        	printf("\nA, B, C matrices before dot product:\n");
+        	print_abc(tile);
+        
+        	/* Compute the dot product; the result lands in the memory for 'C'. */
+        	mult_abc(tile);
+        
+        	printf("\nA, B, C matrices after dot product (C = A . B):\n");
+        	print_abc(tile);
+        
+        	free(tile);
+        	printf("All done\n");
+        
+        	return 0;
+        }