#!/usr/bin/env python3

import os

# XXX: vectorize for more platforms
# XXX: can integrate xor loops with int{bits}_sort
# XXX: for int{bits}_sortdown, can reverse the output array as an alternative

def preamble():
  """Write the head of a generated sort.c: banner, AVX2 helpers, includes.

  When vec == 'avx2', this also writes a typedef and load/store/broadcast
  macros for the {intIvec} vector type. For float element types it adds a
  `{intIvec}_floatmask(y)` macro, which yields 0x7ff...f in each lane whose
  sign bit is set and 0 elsewhere.

  Reads module-level state set by the driver loop: f, vec, what, bits, fun
  and intIvec.
  """
  lines = ['/* WARNING: auto-generated (by autogen/useint); do not edit */\n\n']

  if vec == 'avx2':
    # The 64-bit broadcast intrinsic is _mm256_set1_epi64x; 32-bit has no suffix.
    setsuffix = 'x' if bits == 64 else ''
    lines += [
      '#include <immintrin.h>\n',
      f'typedef __m256i {intIvec};\n',
      f'#define {intIvec}_load(z) _mm256_loadu_si256((__m256i *) (z))\n',
      f'#define {intIvec}_store(z,i) _mm256_storeu_si256((__m256i *) (z),(i))\n',
      f'#define {intIvec}_broadcast _mm256_set1_epi{bits}{setsuffix}\n',
    ]
    if what == 'float':
      if bits == 64:
        # AVX2 lacks _mm256_srai_epi64, so the mask is built as
        # (y & topbit) - (y >> 63): topbit-1 = 0x7ff...f when negative, else 0.
        topbit = f'{hex(1 << (bits - 1))}ULL'
        body = (f'_mm256_sub_epi{bits}((y)&{intIvec}_broadcast({topbit}),'
                f'_mm256_srli_epi{bits}(y,{bits-1}))')
      else:
        # Arithmetic shift smears the sign bit; logical >>1 clears the top bit.
        body = f'_mm256_srli_epi{bits}(_mm256_srai_epi{bits}(y,{bits-1}),1)'
      lines.append(f'#define {intIvec}_floatmask(y) {body}\n')

  lines.append('#include "djbsort.h"\n')
  lines.append(f'#include "{fun}_sort.h"\n')
  if what == 'float':
    # floatxor() emits calls to crypto_int{bits}_negative_mask.
    lines.append(f'#include "crypto_int{bits}.h"\n')

  f.write(''.join(lines))

def intxor():
  """Write the body of an int/uint sort wrapper built on djbsort_int{bits}.

  Every element is XORed with the driver-chosen constant `xor`:
  -1 for intNdown, the top bit for uintN, and the top bit minus one for
  uintNdown. The array is then sorted with djbsort_int{bits}, and the same
  XOR is applied again. When vec == 'avx2', each XOR pass processes two
  vectors per iteration and finishes with a scalar tail loop.

  Reads module-level state set by the driver loop: f, vec, bits, xor, cast,
  and for avx2 also vectorlen and intIvec.
  """
  if vec == 'avx2':
    step = 2*vectorlen
    prologue = (
      '  long long j;\n'
      f'  {intIvec} vecxor = {intIvec}_broadcast({xor});\n'
    )
    # Unrolled by two vectors; the scalar loop continues from the final j.
    xorpass = (
      f'  for (j = 0;j+{step} <= n;j += {step}) {{\n'
      f'    {intIvec} x0 = {intIvec}_load(x+j);\n'
      f'    {intIvec} x1 = {intIvec}_load(x+j+{vectorlen});\n'
      '    x0 ^= vecxor;\n'
      '    x1 ^= vecxor;\n'
      f'    {intIvec}_store(x+j,x0);\n'
      f'    {intIvec}_store(x+j+{vectorlen},x1);\n'
      '  }\n'
      f'  for (;j < n;++j) x[j] ^= {xor};\n'
    )
  else:
    prologue = '  long long j;\n'
    xorpass = f'  for (j = 0;j < n;++j) x[j] ^= {xor};\n'

  # XOR with a fixed constant is its own inverse, so the identical pass
  # both maps the input before sorting and restores it afterwards.
  sortcall = f'  djbsort_int{bits}({cast}x,n);\n'
  f.write(prologue + xorpass + sortcall + xorpass)

def floatxor():
  """Write the body of a float/double sort wrapper built on djbsort_int{bits}.

  The array is reinterpreted as int{bits}_t. In each element whose sign bit
  is set, every bit except the sign bit is flipped. For the *down variants,
  all bits are then complemented as well. The integers are sorted with
  djbsort_int{bits}, and the inverse mapping is applied in reverse order:
  first the complement, then the sign-bit fix.

  The scalar fix relies on crypto_int{bits}_negative_mask; it presumably
  returns all-ones for negative inputs and 0 otherwise (confirm in
  crypto_int{bits}.h). The AVX2 path uses the {intIvec}_floatmask macro
  emitted by preamble().

  Reads module-level state set by the driver loop: f, vec, bits, xordown,
  and for avx2 also down, vectorlen and intIvec.
  """
  simd = vec == 'avx2'
  # After a SIMD loop the scalar tail resumes at the current j; otherwise
  # the scalar loop covers the whole array from 0.
  tailhead = '  for (;j < n;++j) {\n' if simd else '  for (j = 0;j < n;++j) {\n'
  signfix = f'    yj ^= ((uint{bits}_t) crypto_int{bits}_negative_mask(yj)) >> 1;\n'

  out = [
    f'  int{bits}_t *y = (int{bits}_t *) x;\n',
    '  long long j;\n',
    '\n',
  ]

  if simd:
    flip = f' ^ {intIvec}_broadcast(-1)' if down == 'down' else ''
    step = 2*vectorlen
    simdhead = f'  for (j = 0;j+{step} <= n;j += {step}) {{\n'
    # Forward map: sign fix, then optional complement on store.
    out += [
      simdhead,
      f'    {intIvec} y0 = {intIvec}_load(y+j);\n',
      f'    {intIvec} y1 = {intIvec}_load(y+j+{vectorlen});\n',
      f'    y0 ^= {intIvec}_floatmask(y0);\n',
      f'    y1 ^= {intIvec}_floatmask(y1);\n',
      f'    {intIvec}_store(y+j,y0{flip});\n',
      f'    {intIvec}_store(y+j+{vectorlen},y1{flip});\n',
      '  }\n',
    ]
  out += [
    tailhead,
    f'    int{bits}_t yj = y[j];\n',
    signfix,
    f'    y[j] = yj{xordown};\n',
    '  }\n',
    f'  djbsort_int{bits}(y,n);\n',
  ]

  if simd:
    # Inverse map: undo the complement on load, then undo the sign fix.
    out += [
      simdhead,
      f'    {intIvec} y0 = {intIvec}_load(y+j){flip};\n',
      f'    {intIvec} y1 = {intIvec}_load(y+j+{vectorlen}){flip};\n',
      f'    y0 ^= {intIvec}_floatmask(y0);\n',
      f'    y1 ^= {intIvec}_floatmask(y1);\n',
      f'    {intIvec}_store(y+j,y0);\n',
      f'    {intIvec}_store(y+j+{vectorlen},y1);\n',
      '  }\n',
    ]
  out += [
    tailhead,
    f'    int{bits}_t yj = y[j]{xordown};\n',
    signfix,
    '    y[j] = yj;\n',
    '  }\n',
  ]

  f.write(''.join(out))

# Driver: for each (word size, direction, element type, vectorization)
# combination, write {fun}/{vec}useint{bits}/sort.c. For AVX2 builds, also
# write an `architectures` file next to it. The helper functions above read
# the module-level names assigned here (f, vec, what, bits, down, fun,
# intIvec, vectorlen, xor, cast, xordown), so these must stay globals.
for bits in 32,64:
  # Lane count and vector type name depend only on the word size.
  vectorlen = 256//bits
  intIvec = f'int{bits}x{vectorlen}'
  highbit = 1<<(bits-1)
  ullsuffix = 'ULL' if bits == 64 else ''

  for down in '','down':
    for what in 'int','uint','float':
      for vec in '','avx2':
        # Plain ascending intN gets no wrapper here (presumably it is
        # djbsort_intN itself -- confirm).
        if what == 'int' and not down:
          continue

        fun = f'{what}{bits}{down}'
        outdir = f'{fun}/{vec}useint{bits}'
        if what == 'float':
          ctype = 'float' if bits == 32 else 'double'
        else:
          ctype = f'{what}{bits}_t'

        os.makedirs(outdir,exist_ok=True)

        if vec == 'avx2':
          with open(f'{outdir}/architectures','w') as f:
            f.write('amd64 avx2\nx86 avx2\n')

        with open(f'{outdir}/sort.c','w') as f:
          preamble()
          f.write(f'\nvoid {fun}_sort({ctype} *x,long long n)\n{{\n')

          if what == 'int':
            # ~x reverses signed order (only intNdown reaches this point).
            xor = -1
            cast = ''
            intxor()
          elif what == 'uint':
            # Flipping the top bit maps unsigned order onto signed order;
            # flipping all the other bits instead reverses it for *down.
            xor = hex(highbit-1 if down == 'down' else highbit) + ullsuffix
            cast = f'(int{bits}_t *) '
            intxor()
          else:
            xordown = ' ^ -1' if down == 'down' else ''
            floatxor()

          f.write('}\n')
