summaryrefslogtreecommitdiff
path: root/numpy/_build_utils/src/apple_sgemv_fix.c
blob: ffdfb81f705e42ab743f9e81a27841c3750f8d3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
/* This is a collection of ugly hacks to circumvent a bug in
 * Apple Accelerate framework's SGEMV subroutine.
 *
 * See: https://github.com/numpy/numpy/issues/4007
 *
 * SGEMV in Accelerate framework will segfault on MacOS X version 10.9
 * (aka Mavericks) if arrays are not aligned to 32 byte boundaries
 * and the CPU supports AVX instructions. This can produce segfaults
 * in np.dot.
 *
 * This patch overshadows the symbols cblas_sgemv, sgemv_ and sgemv
 * exported by Accelerate to produce the correct behavior. The MacOS X
 * version and CPU specs are checked on module import. If Mavericks and
 * AVX are detected the call to SGEMV is emulated with a call to SGEMM
 * if the arrays are not 32 byte aligned. If the exported symbols cannot
 * be overshadowed on module import, a fatal error is produced and the
 * process aborts. All the fixes are in a self-contained C file
 * and do not alter the multiarray C code. The patch is not applied
 * unless NumPy is configured to link with Apple's Accelerate
 * framework.
 *
 */

#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#include "Python.h"
#include "numpy/arrayobject.h"

#include <string.h>
#include <dlfcn.h>
#include <stdlib.h>
#include <stdio.h>

/* ----------------------------------------------------------------- */
/* Original cblas_sgemv */

/* Absolute path of the vecLib binary inside the Accelerate framework.
 * It is opened with dlopen() so the original BLAS routines can still be
 * reached after this file overrides their exported names. */
#define VECLIB_FILE "/System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/vecLib"

/* Local definitions of the CBLAS enums used by the signatures below. */
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};

/* CBLAS error reporter; declared here, defined outside this file. */
extern void cblas_xerbla(int info, const char *rout, const char *form, ...);

/* Function type of the C interface to SGEMV (matrix-vector product). */
typedef void cblas_sgemv_t(const enum CBLAS_ORDER order,
                 const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
                 const float alpha, const float  *A, const int lda,
                 const float  *X, const int incX,
                 const float beta, float  *Y, const int incY);

/* Function type of the C interface to SGEMM (matrix-matrix product).
 * The final argument is the leading dimension of C, not a vector stride. */
typedef void cblas_sgemm_t(const enum CBLAS_ORDER order,
                 const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB,
                 const int M, const int N, const int K,
                 const float alpha, const float  *A, const int lda,
                 const float  *B, const int ldb,
                 const float beta, float  *C, const int ldc);

/* Function type of the Fortran interface to SGEMV; every scalar is passed
 * by pointer. */
typedef void fortran_sgemv_t( const char* trans, const int* m, const int* n,
             const float* alpha, const float* A, const int* ldA,
             const float* X, const int* incX,
             const float* beta, float* Y, const int* incY );

/* Handle returned by dlopen(VECLIB_FILE); NULL while the library is not open. */
static void *veclib = NULL;
/* Original Accelerate entry points, looked up with dlsym() at load time. */
static cblas_sgemv_t *accelerate_cblas_sgemv = NULL;
static cblas_sgemm_t *accelerate_cblas_sgemm = NULL;
static fortran_sgemv_t *accelerate_sgemv = NULL;
/* Non-zero when both AVX and MacOS X 10.9 were detected, i.e. when
 * misaligned SGEMV calls must be rerouted through SGEMM. */
static int AVX_and_10_9 = 0;

/* Dynamic check for AVX support
 * __builtin_cpu_supports("avx") is available in gcc 4.8,
 * but clang and icc do not currently support it.
 *
 * The check runs a shell pipeline: sysctl prints the CPU feature list and
 * grep -q exits with status 0 when "AVX" occurs in it. */
#define cpu_supports_avx()\
(system("sysctl -n machdep.cpu.features | grep -q AVX") == 0)

/* Check if we are using MacOS X version 10.9
 *
 * The pattern is single-quoted so the shell passes the backslashes on to
 * grep unchanged; unquoted, the shell would strip them and the dots would
 * act as wildcards. It is anchored at the start and accepts either a plain
 * "10.9" or "10.9.<patch>", so a version string without a patch component
 * is still recognised while e.g. "10.90" is not. */
#define using_mavericks()\
(system("sw_vers -productVersion | grep -Eq '^10\\.9(\\.|$)'") == 0)

/* Release the vecLib handle, if one is open.
 *
 * Registered as a destructor and also called directly from the error paths
 * of loadlib(), so it must be safe to invoke more than once. */
__attribute__((destructor))
static void unloadlib(void)
{
    if (veclib != NULL) {
        dlclose(veclib);
        /* Forget the handle so a later call does not close it twice. */
        veclib = NULL;
    }
}

/* Look up `symbol` in the opened vecLib handle and return its address.
 * When the lookup fails the library is closed again and the process is
 * terminated through Py_FatalError() with `failure_msg`. */
static void *resolve_veclib_symbol(const char *symbol, const char *failure_msg)
{
    void *address = dlsym(veclib, symbol);
    if (address == NULL) {
        unloadlib();
        Py_FatalError(failure_msg);
    }
    return address;
}

/* Constructor: automatically executed on module import.
 *
 * Records whether the SGEMV workaround is needed (AVX capable CPU and
 * MacOS X 10.9), opens vecLib and stores the addresses of the original
 * sgemv_, cblas_sgemv and cblas_sgemm routines. Any failure is fatal. */
__attribute__((constructor))
static void loadlib()
{
    char errormsg[1024];
    int has_avx;
    int on_mavericks;

    memset(errormsg, 0, sizeof(errormsg));

    /* Both probes are always run; the workaround is only required when
     * the CPU supports AVX and the OS version is Mavericks. */
    has_avx = cpu_supports_avx();
    on_mavericks = using_mavericks();
    AVX_and_10_9 = has_avx && on_mavericks;

    /* Open vecLib. */
    veclib = dlopen(VECLIB_FILE, RTLD_LOCAL | RTLD_FIRST);
    if (veclib == NULL) {
        snprintf(errormsg, sizeof(errormsg),
                 "Failed to open vecLib from location '%s'.", VECLIB_FILE);
        Py_FatalError(errormsg); /* calls abort() and dumps core */
    }

    /* Fetch the original routines: Fortran SGEMV first, then the two
     * CBLAS entry points. */
    accelerate_sgemv = (fortran_sgemv_t*) resolve_veclib_symbol(
        "sgemv_", "Failed to resolve symbol 'sgemv_'.");
    accelerate_cblas_sgemv = (cblas_sgemv_t*) resolve_veclib_symbol(
        "cblas_sgemv", "Failed to resolve symbol 'cblas_sgemv'.");
    accelerate_cblas_sgemm = (cblas_sgemm_t*) resolve_veclib_symbol(
        "cblas_sgemm", "Failed to resolve symbol 'cblas_sgemm'.");
}

/* ----------------------------------------------------------------- */
/* Fortran SGEMV override */

void sgemv_( const char* trans, const int* m, const int* n,
             const float* alpha, const float* A, const int* ldA,
             const float* X, const int* incX,
             const float* beta, float* Y, const int* incY )
{
    /* It is safe to use the original SGEMV if we are not using AVX on Mavericks
     * or the input arrays A, X and Y are all aligned on 32 byte boundaries. */
    #define BADARRAY(x) (((npy_intp)(void*)x) % 32)
    const int use_sgemm = AVX_and_10_9 && (BADARRAY(A) || BADARRAY(X) || BADARRAY(Y));
    if (!use_sgemm) {
        accelerate_sgemv(trans,m,n,alpha,A,ldA,X,incX,beta,Y,incY);
        return;
    }

    /* Arrays are misaligned, the CPU supports AVX, and we are running
     * Mavericks.
     *
     * Emulation of SGEMV with SGEMM:
     *
     * SGEMV allows vectors to be strided. SGEMM requires all arrays to be
     * contiguous along the leading dimension. To emulate striding in SGEMV
     * with the leading dimension arguments in SGEMM we compute
     *
     *    Y = alpha * op(A) @ X + beta * Y
     *
     * as
     *
     *    Y.T = alpha * X.T @ op(A).T + beta * Y.T
     *
     * Because Fortran uses column major order and X.T and Y.T are row vectors,
     * the leading dimensions of X.T and Y.T in SGEMM become equal to the
     * strides of the the column vectors X and Y in SGEMV. */

    switch (*trans) {
        case 'T':
        case 't':
        case 'C':
        case 'c':
            accelerate_cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
                1, *n, *m, *alpha, X, *incX, A, *ldA, *beta, Y, *incY );
            break;
        case 'N':
        case 'n':
            accelerate_cblas_sgemm( CblasColMajor, CblasNoTrans, CblasTrans,
                1, *m, *n, *alpha, X, *incX, A, *ldA, *beta, Y, *incY );
            break;
        default:
            cblas_xerbla(1, "SGEMV", "Illegal transpose setting: %c\n", *trans);
    }
}

/* ----------------------------------------------------------------- */
/* Override for an alias symbol for sgemv_ in Accelerate */

/* Same routine under the name without the trailing underscore: every
 * argument is handed to sgemv_ unchanged, so both names share one
 * implementation. */
void sgemv (char *trans,
            const int *m, const int *n,
            const float *alpha,
            const float *A, const int *lda,
            const float *B, const int *incB,
            const float *beta,
            float *C, const int *incC)
{
    sgemv_(trans, m, n,
           alpha,
           A, lda,
           B, incB,
           beta,
           C, incC);
}

/* ----------------------------------------------------------------- */
/* cblas_sgemv override, based on Netlib CBLAS code */

/* Replacement for the cblas_sgemv symbol.
 *
 * Converts the CBLAS order/transpose enums into the one-character `trans`
 * flag of the Fortran interface and forwards to sgemv_.
 *
 * Column-major input maps directly. Row-major input is passed as the
 * column-major transpose: M and N are swapped and the transpose flag is
 * flipped (CblasConjTrans maps to 'N' as well).
 *
 * An illegal `order` or `TransA` is reported through cblas_xerbla() and
 * the function returns without calling sgemv_. */
void cblas_sgemv(const enum CBLAS_ORDER order,
                 const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
                 const float alpha, const float  *A, const int lda,
                 const float  *X, const int incX, const float beta,
                 float  *Y, const int incY)
{
   char TA;
   if (order == CblasColMajor)
   {
      if (TransA == CblasNoTrans) TA = 'N';
      else if (TransA == CblasTrans) TA = 'T';
      else if (TransA == CblasConjTrans) TA = 'C';
      else
      {
         cblas_xerbla(2, "cblas_sgemv","Illegal TransA setting, %d\n", TransA);
         /* TA was never assigned; stop here instead of passing an
          * indeterminate flag on to sgemv_. */
         return;
      }
      sgemv_(&TA, &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
   }
   else if (order == CblasRowMajor)
   {
      if (TransA == CblasNoTrans) TA = 'T';
      else if (TransA == CblasTrans) TA = 'N';
      else if (TransA == CblasConjTrans) TA = 'N';
      else
      {
         cblas_xerbla(2, "cblas_sgemv", "Illegal TransA setting, %d\n", TransA);
         return;
      }
      /* Dimensions are swapped because the row-major matrix is read as
       * its column-major transpose. */
      sgemv_(&TA, &N, &M, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
   }
   else
   {
      cblas_xerbla(1, "cblas_sgemv", "Illegal Order setting, %d\n", order);
   }
}