diff options
-rw-r--r-- | buildtools/wafsamba/samba_third_party.py | 2 | ||||
-rw-r--r-- | third_party/socket_wrapper/socket_wrapper.c | 1304 | ||||
-rw-r--r-- | third_party/socket_wrapper/wscript | 4 |
3 files changed, 962 insertions, 348 deletions
diff --git a/buildtools/wafsamba/samba_third_party.py b/buildtools/wafsamba/samba_third_party.py index 9c894e4aed4..1144f813ab6 100644 --- a/buildtools/wafsamba/samba_third_party.py +++ b/buildtools/wafsamba/samba_third_party.py @@ -42,7 +42,7 @@ Build.BuildContext.CHECK_CMOCKA = CHECK_CMOCKA @conf def CHECK_SOCKET_WRAPPER(conf): - return conf.CHECK_BUNDLED_SYSTEM_PKG('socket_wrapper', minversion='1.1.7') + return conf.CHECK_BUNDLED_SYSTEM_PKG('socket_wrapper', minversion='1.1.9') Build.BuildContext.CHECK_SOCKET_WRAPPER = CHECK_SOCKET_WRAPPER @conf diff --git a/third_party/socket_wrapper/socket_wrapper.c b/third_party/socket_wrapper/socket_wrapper.c index 43b92f76f1e..539d27dbc9d 100644 --- a/third_party/socket_wrapper/socket_wrapper.c +++ b/third_party/socket_wrapper/socket_wrapper.c @@ -79,6 +79,7 @@ #ifdef HAVE_RPC_RPC_H #include <rpc/rpc.h> #endif +#include <pthread.h> enum swrap_dbglvl_e { SWRAP_LOG_ERROR = 0, @@ -94,12 +95,26 @@ enum swrap_dbglvl_e { #define PRINTF_ATTRIBUTE(a,b) #endif /* HAVE_FUNCTION_ATTRIBUTE_FORMAT */ +#ifdef HAVE_CONSTRUCTOR_ATTRIBUTE +#define CONSTRUCTOR_ATTRIBUTE __attribute__ ((constructor)) +#else +#define CONSTRUCTOR_ATTRIBUTE +#endif /* HAVE_CONSTRUCTOR_ATTRIBUTE */ + #ifdef HAVE_DESTRUCTOR_ATTRIBUTE #define DESTRUCTOR_ATTRIBUTE __attribute__ ((destructor)) #else #define DESTRUCTOR_ATTRIBUTE #endif +#ifndef FALL_THROUGH +# ifdef HAVE_FALLTHROUGH_ATTRIBUTE +# define FALL_THROUGH __attribute__ ((fallthrough)) +# else /* HAVE_FALLTHROUGH_ATTRIBUTE */ +# define FALL_THROUGH +# endif /* HAVE_FALLTHROUGH_ATTRIBUTE */ +#endif /* FALL_THROUGH */ + #ifdef HAVE_ADDRESS_SANITIZER_ATTRIBUTE #define DO_NOT_SANITIZE_ADDRESS_ATTRIBUTE __attribute__((no_sanitize_address)) #else @@ -135,6 +150,8 @@ enum swrap_dbglvl_e { #define discard_const_p(type, ptr) ((type *)discard_const(ptr)) #endif +#define UNUSED(x) (void)(x) + #ifdef IPV6_PKTINFO # ifndef IPV6_RECVPKTINFO # define IPV6_RECVPKTINFO IPV6_PKTINFO @@ -152,6 +169,22 @@ enum swrap_dbglvl_e { # endif #endif +/* Macros for accessing mutexes */ +# define SWRAP_LOCK(m) do { \ + pthread_mutex_lock(&(m ## _mutex)); \ +} while(0) + +# define SWRAP_UNLOCK(m) do { \ + pthread_mutex_unlock(&(m ## _mutex)); \ +} while(0) + +/* Add new global locks here please */ +# define SWRAP_LOCK_ALL \ + SWRAP_LOCK(libc_symbol_binding); \ + +# define SWRAP_UNLOCK_ALL \ + SWRAP_UNLOCK(libc_symbol_binding); \ + #define SWRAP_DLIST_ADD(list,item) do { \ if (!(list)) { \ @@ -184,6 +217,20 @@ enum swrap_dbglvl_e { (item)->next = NULL; \ } while (0) +#define SWRAP_DLIST_ADD_AFTER(list, item, el) \ +do { \ + if ((list) == NULL || (el) == NULL) { \ + SWRAP_DLIST_ADD(list, item); \ + } else { \ + (item)->prev = (el); \ + (item)->next = (el)->next; \ + (el)->next = (item); \ + if ((item)->next != NULL) { \ + (item)->next->prev = (item); \ + } \ + } \ +} while (0) + #if defined(HAVE_GETTIMEOFDAY_TZ) || defined(HAVE_GETTIMEOFDAY_TZ_VOID) #define swrapGetTimeOfDay(tval) gettimeofday(tval,NULL) #else @@ -212,10 +259,20 @@ enum swrap_dbglvl_e { #define SOCKET_MAX_SOCKETS 1024 + +/* + * Maximum number of socket_info structures that can + * be used. Can be overriden by the environment variable + * SOCKET_WRAPPER_MAX_SOCKETS. + */ +#define SOCKET_WRAPPER_MAX_SOCKETS_DEFAULT 65535 + +#define SOCKET_WRAPPER_MAX_SOCKETS_LIMIT 256000 + /* This limit is to avoid broadcast sendto() needing to stat too many * files. It may be raised (with a performance cost) to up to 254 * without changing the format above */ -#define MAX_WRAPPED_INTERFACES 40 +#define MAX_WRAPPED_INTERFACES 64 struct swrap_address { socklen_t sa_socklen; @@ -233,11 +290,21 @@ struct swrap_address { struct socket_info_fd { struct socket_info_fd *prev, *next; int fd; + + /* + * Points to corresponding index in array of + * socket_info structures + */ + int si_index; }; +int first_free; + struct socket_info { - struct socket_info_fd *fds; + unsigned int refcount; + + int next_free; int family; int type; @@ -261,24 +328,27 @@ struct socket_info unsigned long pck_snd; unsigned long pck_rcv; } io; - - struct socket_info *prev, *next; }; +static struct socket_info *sockets; +static size_t max_sockets = 0; + /* - * File descriptors are shared between threads so we should share socket - * information too. + * While socket file descriptors are passed among different processes, the + * numerical value gets changed. So its better to store it locally to each + * process rather than including it within socket_info which will be shared. */ -struct socket_info *sockets; +static struct socket_info_fd *socket_fds; + +/* The mutex for accessing the global libc.symbols */ +static pthread_mutex_t libc_symbol_binding_mutex = PTHREAD_MUTEX_INITIALIZER; /* Function prototypes */ bool socket_wrapper_enabled(void); -void swrap_destructor(void) DESTRUCTOR_ATTRIBUTE; -#ifdef NDEBUG -# define SWRAP_LOG(...) -#else +void swrap_constructor(void) CONSTRUCTOR_ATTRIBUTE; +void swrap_destructor(void) DESTRUCTOR_ATTRIBUTE; static void swrap_log(enum swrap_dbglvl_e dbglvl, const char *func, const char *format, ...) PRINTF_ATTRIBUTE(3, 4); # define SWRAP_LOG(dbglvl, ...) swrap_log((dbglvl), __func__, __VA_ARGS__) @@ -291,42 +361,40 @@ static void swrap_log(enum swrap_dbglvl_e dbglvl, va_list va; const char *d; unsigned int lvl = 0; + const char *prefix = "SWRAP"; d = getenv("SOCKET_WRAPPER_DEBUGLEVEL"); if (d != NULL) { lvl = atoi(d); } + if (lvl < dbglvl) { + return; + } + va_start(va, format); vsnprintf(buffer, sizeof(buffer), format, va); va_end(va); - if (lvl >= dbglvl) { - switch (dbglvl) { - case SWRAP_LOG_ERROR: - fprintf(stderr, - "SWRAP_ERROR(%d) - %s: %s\n", - (int)getpid(), func, buffer); - break; - case SWRAP_LOG_WARN: - fprintf(stderr, - "SWRAP_WARN(%d) - %s: %s\n", - (int)getpid(), func, buffer); - break; - case SWRAP_LOG_DEBUG: - fprintf(stderr, - "SWRAP_DEBUG(%d) - %s: %s\n", - (int)getpid(), func, buffer); - break; - case SWRAP_LOG_TRACE: - fprintf(stderr, - "SWRAP_TRACE(%d) - %s: %s\n", - (int)getpid(), func, buffer); - break; - } + switch (dbglvl) { + case SWRAP_LOG_ERROR: + prefix = "SWRAP_ERROR"; + break; + case SWRAP_LOG_WARN: + prefix = "SWRAP_WARN"; + break; + case SWRAP_LOG_DEBUG: + prefix = "SWRAP_DEBUG"; + break; + case SWRAP_LOG_TRACE: + prefix = "SWRAP_TRACE"; + break; } + + fprintf(stderr, + "%s(%d) - %s: %s\n", + prefix, (int)getpid(), func, buffer); } -#endif /********************************************************* * SWRAP LOADING LIBC FUNCTIONS @@ -334,91 +402,149 @@ static void swrap_log(enum swrap_dbglvl_e dbglvl, #include <dlfcn.h> -struct swrap_libc_fns { #ifdef HAVE_ACCEPT4 - int (*libc_accept4)(int sockfd, - struct sockaddr *addr, - socklen_t *addrlen, - int flags); +typedef int (*__libc_accept4)(int sockfd, + struct sockaddr *addr, + socklen_t *addrlen, + int flags); #else - int (*libc_accept)(int sockfd, - struct sockaddr *addr, - socklen_t *addrlen); +typedef int (*__libc_accept)(int sockfd, + struct sockaddr *addr, + socklen_t *addrlen); +#endif +typedef int (*__libc_bind)(int sockfd, + const struct sockaddr *addr, + socklen_t addrlen); +typedef int (*__libc_close)(int fd); +typedef int (*__libc_connect)(int sockfd, + const struct sockaddr *addr, + socklen_t addrlen); +typedef int (*__libc_dup)(int fd); +typedef int (*__libc_dup2)(int oldfd, int newfd); +typedef int (*__libc_fcntl)(int fd, int cmd, ...); +typedef FILE *(*__libc_fopen)(const char *name, const char *mode); +#ifdef HAVE_FOPEN64 +typedef FILE *(*__libc_fopen64)(const char *name, const char *mode); #endif - int (*libc_bind)(int sockfd, - const struct sockaddr *addr, - socklen_t addrlen); - int (*libc_close)(int fd); - int (*libc_connect)(int sockfd, - const struct sockaddr *addr, - socklen_t addrlen); - int (*libc_dup)(int fd); - int (*libc_dup2)(int oldfd, int newfd); - int (*libc_fcntl)(int fd, int cmd, ...); - FILE *(*libc_fopen)(const char *name, const char *mode); #ifdef HAVE_EVENTFD - int (*libc_eventfd)(int count, int flags); +typedef int (*__libc_eventfd)(int count, int flags); #endif - int (*libc_getpeername)(int sockfd, - struct sockaddr *addr, - socklen_t *addrlen); - int (*libc_getsockname)(int sockfd, - struct sockaddr *addr, - socklen_t *addrlen); - int (*libc_getsockopt)(int sockfd, +typedef int (*__libc_getpeername)(int sockfd, + struct sockaddr *addr, + socklen_t *addrlen); +typedef int (*__libc_getsockname)(int sockfd, + struct sockaddr *addr, + socklen_t *addrlen); +typedef int (*__libc_getsockopt)(int sockfd, int level, int optname, void *optval, socklen_t *optlen); - int (*libc_ioctl)(int d, unsigned long int request, ...); - int (*libc_listen)(int sockfd, int backlog); - int (*libc_open)(const char *pathname, int flags, mode_t mode); - int (*libc_pipe)(int pipefd[2]); - int (*libc_read)(int fd, void *buf, size_t count); - ssize_t (*libc_readv)(int fd, const struct iovec *iov, int iovcnt); - int (*libc_recv)(int sockfd, void *buf, size_t len, int flags); - int (*libc_recvfrom)(int sockfd, +typedef int (*__libc_ioctl)(int d, unsigned long int request, ...); +typedef int (*__libc_listen)(int sockfd, int backlog); +typedef int (*__libc_open)(const char *pathname, int flags, ...); +#ifdef HAVE_OPEN64 +typedef int (*__libc_open64)(const char *pathname, int flags, ...); +#endif /* HAVE_OPEN64 */ +typedef int (*__libc_openat)(int dirfd, const char *path, int flags, ...); +typedef int (*__libc_pipe)(int pipefd[2]); +typedef int (*__libc_read)(int fd, void *buf, size_t count); +typedef ssize_t (*__libc_readv)(int fd, const struct iovec *iov, int iovcnt); +typedef int (*__libc_recv)(int sockfd, void *buf, size_t len, int flags); +typedef int (*__libc_recvfrom)(int sockfd, void *buf, size_t len, int flags, struct sockaddr *src_addr, socklen_t *addrlen); - int (*libc_recvmsg)(int sockfd, const struct msghdr *msg, int flags); - int (*libc_send)(int sockfd, const void *buf, size_t len, int flags); - int (*libc_sendmsg)(int sockfd, const struct msghdr *msg, int flags); - int (*libc_sendto)(int sockfd, +typedef int (*__libc_recvmsg)(int sockfd, const struct msghdr *msg, int flags); +typedef int (*__libc_send)(int sockfd, const void *buf, size_t len, int flags); +typedef int (*__libc_sendmsg)(int sockfd, const struct msghdr *msg, int flags); +typedef int (*__libc_sendto)(int sockfd, const void *buf, size_t len, int flags, const struct sockaddr *dst_addr, socklen_t addrlen); - int (*libc_setsockopt)(int sockfd, +typedef int (*__libc_setsockopt)(int sockfd, int level, int optname, const void *optval, socklen_t optlen); #ifdef HAVE_SIGNALFD - int (*libc_signalfd)(int fd, const sigset_t *mask, int flags); +typedef int (*__libc_signalfd)(int fd, const sigset_t *mask, int flags); #endif - int (*libc_socket)(int domain, int type, int protocol); - int (*libc_socketpair)(int domain, int type, int protocol, int sv[2]); +typedef int (*__libc_socket)(int domain, int type, int protocol); +typedef int (*__libc_socketpair)(int domain, int type, int protocol, int sv[2]); #ifdef HAVE_TIMERFD_CREATE - int (*libc_timerfd_create)(int clockid, int flags); +typedef int (*__libc_timerfd_create)(int clockid, int flags); #endif - ssize_t (*libc_write)(int fd, const void *buf, size_t count); - ssize_t (*libc_writev)(int fd, const struct iovec *iov, int iovcnt); -}; +typedef ssize_t (*__libc_write)(int fd, const void *buf, size_t count); +typedef ssize_t (*__libc_writev)(int fd, const struct iovec *iov, int iovcnt); -struct swrap { - void *libc_handle; - void *libsocket_handle; +#define SWRAP_SYMBOL_ENTRY(i) \ + union { \ + __libc_##i f; \ + void *obj; \ + } _libc_##i - bool initialised; - bool enabled; - - char *socket_dir; +struct swrap_libc_symbols { +#ifdef HAVE_ACCEPT4 + SWRAP_SYMBOL_ENTRY(accept4); +#else + SWRAP_SYMBOL_ENTRY(accept); +#endif + SWRAP_SYMBOL_ENTRY(bind); + SWRAP_SYMBOL_ENTRY(close); + SWRAP_SYMBOL_ENTRY(connect); + SWRAP_SYMBOL_ENTRY(dup); + SWRAP_SYMBOL_ENTRY(dup2); + SWRAP_SYMBOL_ENTRY(fcntl); + SWRAP_SYMBOL_ENTRY(fopen); +#ifdef HAVE_FOPEN64 + SWRAP_SYMBOL_ENTRY(fopen64); +#endif +#ifdef HAVE_EVENTFD + SWRAP_SYMBOL_ENTRY(eventfd); +#endif + SWRAP_SYMBOL_ENTRY(getpeername); + SWRAP_SYMBOL_ENTRY(getsockname); + SWRAP_SYMBOL_ENTRY(getsockopt); + SWRAP_SYMBOL_ENTRY(ioctl); + SWRAP_SYMBOL_ENTRY(listen); + SWRAP_SYMBOL_ENTRY(open); +#ifdef HAVE_OPEN64 + SWRAP_SYMBOL_ENTRY(open64); +#endif + SWRAP_SYMBOL_ENTRY(openat); + SWRAP_SYMBOL_ENTRY(pipe); + SWRAP_SYMBOL_ENTRY(read); + SWRAP_SYMBOL_ENTRY(readv); + SWRAP_SYMBOL_ENTRY(recv); + SWRAP_SYMBOL_ENTRY(recvfrom); + SWRAP_SYMBOL_ENTRY(recvmsg); + SWRAP_SYMBOL_ENTRY(send); + SWRAP_SYMBOL_ENTRY(sendmsg); + SWRAP_SYMBOL_ENTRY(sendto); + SWRAP_SYMBOL_ENTRY(setsockopt); +#ifdef HAVE_SIGNALFD + SWRAP_SYMBOL_ENTRY(signalfd); +#endif + SWRAP_SYMBOL_ENTRY(socket); + SWRAP_SYMBOL_ENTRY(socketpair); +#ifdef HAVE_TIMERFD_CREATE + SWRAP_SYMBOL_ENTRY(timerfd_create); +#endif + SWRAP_SYMBOL_ENTRY(write); + SWRAP_SYMBOL_ENTRY(writev); +}; - struct swrap_libc_fns fns; +struct swrap { + struct { + void *handle; + void *socket_handle; + struct swrap_libc_symbols symbols; + } libc; }; static struct swrap swrap; @@ -434,7 +560,6 @@ enum swrap_lib { SWRAP_LIBSOCKET, }; -#ifndef NDEBUG static const char *swrap_str_lib(enum swrap_lib lib) { switch (lib) { @@ -449,7 +574,6 @@ static const char *swrap_str_lib(enum swrap_lib lib) /* Compiler would warn us about unhandled enum value if we get here */ return "unknown"; } -#endif static void *swrap_load_lib_handle(enum swrap_lib lib) { @@ -463,10 +587,10 @@ static void *swrap_load_lib_handle(enum swrap_lib lib) switch (lib) { case SWRAP_LIBNSL: - /* FALL TROUGH */ + FALL_THROUGH; case SWRAP_LIBSOCKET: #ifdef HAVE_LIBSOCKET - handle = swrap.libsocket_handle; + handle = swrap.libc.socket_handle; if (handle == NULL) { for (i = 10; i >= 0; i--) { char soname[256] = {0}; @@ -478,18 +602,18 @@ static void *swrap_load_lib_handle(enum swrap_lib lib) } } - swrap.libsocket_handle = handle; + swrap.libc.socket_handle = handle; } break; #endif - /* FALL TROUGH */ + FALL_THROUGH; case SWRAP_LIBC: - handle = swrap.libc_handle; + handle = swrap.libc.handle; #ifdef LIBC_SO if (handle == NULL) { handle = dlopen(LIBC_SO, flags); - swrap.libc_handle = handle; + swrap.libc.handle = handle; } #endif if (handle == NULL) { @@ -503,14 +627,14 @@ static void *swrap_load_lib_handle(enum swrap_lib lib) } } - swrap.libc_handle = handle; + swrap.libc.handle = handle; } break; } if (handle == NULL) { #ifdef RTLD_NEXT - handle = swrap.libc_handle = swrap.libsocket_handle = RTLD_NEXT; + handle = swrap.libc.handle = swrap.libc.socket_handle = RTLD_NEXT; #else SWRAP_LOG(SWRAP_LOG_ERROR, "Failed to dlopen library: %s\n", @@ -522,7 +646,7 @@ static void *swrap_load_lib_handle(enum swrap_lib lib) return handle; } -static void *_swrap_load_lib_function(enum swrap_lib lib, const char *fn_name) +static void *_swrap_bind_symbol(enum swrap_lib lib, const char *fn_name) { void *handle; void *func; @@ -532,51 +656,79 @@ static void *_swrap_load_lib_function(enum swrap_lib lib, const char *fn_name) func = dlsym(handle, fn_name); if (func == NULL) { SWRAP_LOG(SWRAP_LOG_ERROR, - "Failed to find %s: %s\n", - fn_name, dlerror()); + "Failed to find %s: %s\n", + fn_name, + dlerror()); exit(-1); } SWRAP_LOG(SWRAP_LOG_TRACE, - "Loaded %s from %s", - fn_name, swrap_str_lib(lib)); + "Loaded %s from %s", + fn_name, + swrap_str_lib(lib)); + return func; } -#define swrap_load_lib_function(lib, fn_name) \ - if (swrap.fns.libc_##fn_name == NULL) { \ - void *swrap_cast_ptr = _swrap_load_lib_function(lib, #fn_name); \ - *(void **) (&swrap.fns.libc_##fn_name) = \ - swrap_cast_ptr; \ +#define swrap_bind_symbol_libc(sym_name) \ + if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ + SWRAP_LOCK(libc_symbol_binding); \ + if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ + swrap.libc.symbols._libc_##sym_name.obj = \ + _swrap_bind_symbol(SWRAP_LIBC, #sym_name); \ + } \ + SWRAP_UNLOCK(libc_symbol_binding); \ } +#define swrap_bind_symbol_libsocket(sym_name) \ + if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ + SWRAP_LOCK(libc_symbol_binding); \ + if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ + swrap.libc.symbols._libc_##sym_name.obj = \ + _swrap_bind_symbol(SWRAP_LIBSOCKET, #sym_name); \ + } \ + SWRAP_UNLOCK(libc_symbol_binding); \ + } -/* - * IMPORTANT +#define swrap_bind_symbol_libnsl(sym_name) \ + if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ + SWRAP_LOCK(libc_symbol_binding); \ + if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ + swrap.libc.symbols._libc_##sym_name.obj = \ + _swrap_bind_symbol(SWRAP_LIBNSL, #sym_name); \ + } \ + SWRAP_UNLOCK(libc_symbol_binding); \ + } + +/**************************************************************************** + * IMPORTANT + **************************************************************************** * - * Functions especially from libc need to be loaded individually, you can't load - * all at once or gdb will segfault at startup. The same applies to valgrind and - * has probably something todo with with the linker. - * So we need load each function at the point it is called the first time. - */ + * Functions especially from libc need to be loaded individually, you can't + * load all at once or gdb will segfault at startup. The same applies to + * valgrind and has probably something todo with with the linker. So we need + * load each function at the point it is called the first time. + * + ****************************************************************************/ + #ifdef HAVE_ACCEPT4 static int libc_accept4(int sockfd, struct sockaddr *addr, socklen_t *addrlen, int flags) { - swrap_load_lib_function(SWRAP_LIBSOCKET, accept4); + swrap_bind_symbol_libsocket(accept4); - return swrap.fns.libc_accept4(sockfd, addr, addrlen, flags); + return swrap.libc.symbols._libc_accept4.f(sockfd, addr, addrlen, flags); } #else /* HAVE_ACCEPT4 */ static int libc_accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { - swrap_load_lib_function(SWRAP_LIBSOCKET, accept); + swrap_bind_symbol_libsocket(accept); - return swrap.fns.libc_accept(sockfd, addr, addrlen); + return swrap.libc.symbols._libc_accept.f(sockfd, addr, addrlen); } #endif /* HAVE_ACCEPT4 */ @@ -584,69 +736,61 @@ static int libc_bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { - swrap_load_lib_function(SWRAP_LIBSOCKET, bind); + swrap_bind_symbol_libsocket(bind); - return swrap.fns.libc_bind(sockfd, addr, addrlen); + return swrap.libc.symbols._libc_bind.f(sockfd, addr, addrlen); } static int libc_close(int fd) { - swrap_load_lib_function(SWRAP_LIBC, close); + swrap_bind_symbol_libc(close); - return swrap.fns.libc_close(fd); + return swrap.libc.symbols._libc_close.f(fd); } static int libc_connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { - swrap_load_lib_function(SWRAP_LIBSOCKET, connect); + swrap_bind_symbol_libsocket(connect); - return swrap.fns.libc_connect(sockfd, addr, addrlen); + return swrap.libc.symbols._libc_connect.f(sockfd, addr, addrlen); } static int libc_dup(int fd) { - swrap_load_lib_function(SWRAP_LIBC, dup); + swrap_bind_symbol_libc(dup); - return swrap.fns.libc_dup(fd); + return swrap.libc.symbols._libc_dup.f(fd); } static int libc_dup2(int oldfd, int newfd) { - swrap_load_lib_function(SWRAP_LIBC, dup2); + swrap_bind_symbol_libc(dup2); - return swrap.fns.libc_dup2(oldfd, newfd); + return swrap.libc.symbols._libc_dup2.f(oldfd, newfd); } #ifdef HAVE_EVENTFD static int libc_eventfd(int count, int flags) { - swrap_load_lib_function(SWRAP_LIBC, eventfd); + swrap_bind_symbol_libc(eventfd); - return swrap.fns.libc_eventfd(count, flags); + return swrap.libc.symbols._libc_eventfd.f(count, flags); } #endif DO_NOT_SANITIZE_ADDRESS_ATTRIBUTE static int libc_vfcntl(int fd, int cmd, va_list ap) { - long int args[4]; + void *arg; int rc; - int i; - swrap_load_lib_function(SWRAP_LIBC, fcntl); + swrap_bind_symbol_libc(fcntl); - for (i = 0; i < 4; i++) { - args[i] = va_arg(ap, long int); - } + arg = va_arg(ap, void *); - rc = swrap.fns.libc_fcntl(fd, - cmd, - args[0], - args[1], - args[2], - args[3]); + rc = swrap.libc.symbols._libc_fcntl.f(fd, cmd, arg); return rc; } @@ -655,18 +799,18 @@ static int libc_getpeername(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { - swrap_load_lib_function(SWRAP_LIBSOCKET, getpeername); + swrap_bind_symbol_libsocket(getpeername); - return swrap.fns.libc_getpeername(sockfd, addr, addrlen); + return swrap.libc.symbols._libc_getpeername.f(sockfd, addr, addrlen); } static int libc_getsockname(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { - swrap_load_lib_function(SWRAP_LIBSOCKET, getsockname); + swrap_bind_symbol_libsocket(getsockname); - return swrap.fns.libc_getsockname(sockfd, addr, addrlen); + return swrap.libc.symbols._libc_getsockname.f(sockfd, addr, addrlen); } static int libc_getsockopt(int sockfd, @@ -675,58 +819,64 @@ static int libc_getsockopt(int sockfd, void *optval, socklen_t *optlen) { - swrap_load_lib_function(SWRAP_LIBSOCKET, getsockopt); + swrap_bind_symbol_libsocket(getsockopt); - return swrap.fns.libc_getsockopt(sockfd, level, optname, optval, optlen); + return swrap.libc.symbols._libc_getsockopt.f(sockfd, + level, + optname, + optval, + optlen); } DO_NOT_SANITIZE_ADDRESS_ATTRIBUTE static int libc_vioctl(int d, unsigned long int request, va_list ap) { - long int args[4]; + void *arg; int rc; - int i; - swrap_load_lib_function(SWRAP_LIBC, ioctl); + swrap_bind_symbol_libc(ioctl); - for (i = 0; i < 4; i++) { - args[i] = va_arg(ap, long int); - } + arg = va_arg(ap, void *); - rc = swrap.fns.libc_ioctl(d, - request, - args[0], - args[1], - args[2], - args[3]); + rc = swrap.libc.symbols._libc_ioctl.f(d, request, arg); return rc; } static int libc_listen(int sockfd, int backlog) { - swrap_load_lib_function(SWRAP_LIBSOCKET, listen); + swrap_bind_symbol_libsocket(listen); - return swrap.fns.libc_listen(sockfd, backlog); + return swrap.libc.symbols._libc_listen.f(sockfd, backlog); } static FILE *libc_fopen(const char *name, const char *mode) { - swrap_load_lib_function(SWRAP_LIBC, fopen); + swrap_bind_symbol_libc(fopen); - return swrap.fns.libc_fopen(name, mode); + return swrap.libc.symbols._libc_fopen.f(name, mode); } +#ifdef HAVE_FOPEN64 +static FILE *libc_fopen64(const char *name, const char *mode) +{ + swrap_bind_symbol_libc(fopen64); + + return swrap.libc.symbols._libc_fopen64.f(name, mode); +} +#endif /* HAVE_FOPEN64 */ + static int libc_vopen(const char *pathname, int flags, va_list ap) { - long int mode = 0; + int mode = 0; int fd; - swrap_load_lib_function(SWRAP_LIBC, open); - - mode = va_arg(ap, long int); + swrap_bind_symbol_libc(open); - fd = swrap.fns.libc_open(pathname, flags, (mode_t)mode); + if (flags & O_CREAT) { + mode = va_arg(ap, int); + } + fd = swrap.libc.symbols._libc_open.f(pathname, flags, (mode_t)mode); return fd; } @@ -743,32 +893,81 @@ static int libc_open(const char *pathname, int flags, ...) return fd; } +#ifdef HAVE_OPEN64 +static int libc_vopen64(const char *pathname, int flags, va_list ap) +{ + int mode = 0; + int fd; + + swrap_bind_symbol_libc(open64); + + if (flags & O_CREAT) { + mode = va_arg(ap, int); + } + fd = swrap.libc.symbols._libc_open64.f(pathname, flags, (mode_t)mode); + + return fd; +} +#endif /* HAVE_OPEN64 */ + +static int libc_vopenat(int dirfd, const char *path, int flags, va_list ap) +{ + int mode = 0; + int fd; + + swrap_bind_symbol_libc(openat); + + if (flags & O_CREAT) { + mode = va_arg(ap, int); + } + fd = swrap.libc.symbols._libc_openat.f(dirfd, + path, + flags, + (mode_t)mode); + + return fd; +} + +#if 0 +static int libc_openat(int dirfd, const char *path, int flags, ...) +{ + va_list ap; + int fd; + + va_start(ap, flags); + fd = libc_vopenat(dirfd, path, flags, ap); + va_end(ap); + + return fd; +} +#endif + static int libc_pipe(int pipefd[2]) { - swrap_load_lib_function(SWRAP_LIBSOCKET, pipe); + swrap_bind_symbol_libsocket(pipe); - return swrap.fns.libc_pipe(pipefd); + return swrap.libc.symbols._libc_pipe.f(pipefd); } static int libc_read(int fd, void *buf, size_t count) { - swrap_load_lib_function(SWRAP_LIBC, read); + swrap_bind_symbol_libc(read); - return swrap.fns.libc_read(fd, buf, count); + return swrap.libc.symbols._libc_read.f(fd, buf, count); } static ssize_t libc_readv(int fd, const struct iovec *iov, int iovcnt) { - swrap_load_lib_function(SWRAP_LIBSOCKET, readv); + swrap_bind_symbol_libsocket(readv); - return swrap.fns.libc_readv(fd, iov, iovcnt); + return swrap.libc.symbols._libc_readv.f(fd, iov, iovcnt); } static int libc_recv(int sockfd, void *buf, size_t len, int flags) { - swrap_load_lib_function(SWRAP_LIBSOCKET, recv); + swrap_bind_symbol_libsocket(recv); - return swrap.fns.libc_recv(sockfd, buf, len, flags); + return swrap.libc.symbols._libc_recv.f(sockfd, buf, len, flags); } static int libc_recvfrom(int sockfd, @@ -778,30 +977,35 @@ static int libc_recvfrom(int sockfd, struct sockaddr *src_addr, socklen_t *addrlen) { - swrap_load_lib_function(SWRAP_LIBSOCKET, recvfrom); + swrap_bind_symbol_libsocket(recvfrom); - return swrap.fns.libc_recvfrom(sockfd, buf, len, flags, src_addr, addrlen); + return swrap.libc.symbols._libc_recvfrom.f(sockfd, + buf, + len, + flags, + src_addr, + addrlen); } static int libc_recvmsg(int sockfd, struct msghdr *msg, int flags) { - swrap_load_lib_function(SWRAP_LIBSOCKET, recvmsg); + swrap_bind_symbol_libsocket(recvmsg); - return swrap.fns.libc_recvmsg(sockfd, msg, flags); + return swrap.libc.symbols._libc_recvmsg.f(sockfd, msg, flags); } static int libc_send(int sockfd, const void *buf, size_t len, int flags) { - swrap_load_lib_function(SWRAP_LIBSOCKET, send); + swrap_bind_symbol_libsocket(send); - return swrap.fns.libc_send(sockfd, buf, len, flags); + return swrap.libc.symbols._libc_send.f(sockfd, buf, len, flags); } static int libc_sendmsg(int sockfd, const struct msghdr *msg, int flags) { - swrap_load_lib_function(SWRAP_LIBSOCKET, sendmsg); + swrap_bind_symbol_libsocket(sendmsg); - return swrap.fns.libc_sendmsg(sockfd, msg, flags); + return swrap.libc.symbols._libc_sendmsg.f(sockfd, msg, flags); } static int libc_sendto(int sockfd, @@ -811,9 +1015,14 @@ static int libc_sendto(int sockfd, const struct sockaddr *dst_addr, socklen_t addrlen) { - swrap_load_lib_function(SWRAP_LIBSOCKET, sendto); + swrap_bind_symbol_libsocket(sendto); - return swrap.fns.libc_sendto(sockfd, buf, len, flags, dst_addr, addrlen); + return swrap.libc.symbols._libc_sendto.f(sockfd, + buf, + len, + flags, + dst_addr, + addrlen); } static int libc_setsockopt(int sockfd, @@ -822,55 +1031,112 @@ static int libc_setsockopt(int sockfd, const void *optval, socklen_t optlen) { - swrap_load_lib_function(SWRAP_LIBSOCKET, setsockopt); + swrap_bind_symbol_libsocket(setsockopt); - return swrap.fns.libc_setsockopt(sockfd, level, optname, optval, optlen); + return swrap.libc.symbols._libc_setsockopt.f(sockfd, + level, + optname, + optval, + optlen); } #ifdef HAVE_SIGNALFD static int libc_signalfd(int fd, const sigset_t *mask, int flags) { - swrap_load_lib_function(SWRAP_LIBSOCKET, signalfd); + swrap_bind_symbol_libsocket(signalfd); - return swrap.fns.libc_signalfd(fd, mask, flags); + return swrap.libc.symbols._libc_signalfd.f(fd, mask, flags); } #endif static int libc_socket(int domain, int type, int protocol) { - swrap_load_lib_function(SWRAP_LIBSOCKET, socket); + swrap_bind_symbol_libsocket(socket); - return swrap.fns.libc_socket(domain, type, protocol); + return swrap.libc.symbols._libc_socket.f(domain, type, protocol); } static int libc_socketpair(int domain, int type, int protocol, int sv[2]) { - swrap_load_lib_function(SWRAP_LIBSOCKET, socketpair); + swrap_bind_symbol_libsocket(socketpair); - return swrap.fns.libc_socketpair(domain, type, protocol, sv); + return swrap.libc.symbols._libc_socketpair.f(domain, type, protocol, sv); } #ifdef HAVE_TIMERFD_CREATE static int libc_timerfd_create(int clockid, int flags) { - swrap_load_lib_function(SWRAP_LIBC, timerfd_create); + swrap_bind_symbol_libc(timerfd_create); - return swrap.fns.libc_timerfd_create(clockid, flags); + return swrap.libc.symbols._libc_timerfd_create.f(clockid, flags); } #endif static ssize_t libc_write(int fd, const void *buf, size_t count) { - swrap_load_lib_function(SWRAP_LIBC, write); + swrap_bind_symbol_libc(write); - return swrap.fns.libc_write(fd, buf, count); + return swrap.libc.symbols._libc_write.f(fd, buf, count); } static ssize_t libc_writev(int fd, const struct iovec *iov, int iovcnt) { - swrap_load_lib_function(SWRAP_LIBSOCKET, writev); + swrap_bind_symbol_libsocket(writev); + + return swrap.libc.symbols._libc_writev.f(fd, iov, iovcnt); +} - return swrap.fns.libc_writev(fd, iov, iovcnt); +/* DO NOT call this function during library initialization! */ +static void swrap_bind_symbol_all(void) +{ +#ifdef HAVE_ACCEPT4 + swrap_bind_symbol_libsocket(accept4); +#else + swrap_bind_symbol_libsocket(accept); +#endif + swrap_bind_symbol_libsocket(bind); + swrap_bind_symbol_libc(close); + swrap_bind_symbol_libsocket(connect); + swrap_bind_symbol_libc(dup); + swrap_bind_symbol_libc(dup2); + swrap_bind_symbol_libc(fcntl); + swrap_bind_symbol_libc(fopen); +#ifdef HAVE_FOPEN64 + swrap_bind_symbol_libc(fopen64); +#endif +#ifdef HAVE_EVENTFD + swrap_bind_symbol_libc(eventfd); +#endif + swrap_bind_symbol_libsocket(getpeername); + swrap_bind_symbol_libsocket(getsockname); + swrap_bind_symbol_libsocket(getsockopt); + swrap_bind_symbol_libc(ioctl); + swrap_bind_symbol_libsocket(listen); + swrap_bind_symbol_libc(open); +#ifdef HAVE_OPEN64 + swrap_bind_symbol_libc(open64); +#endif + swrap_bind_symbol_libc(openat); + swrap_bind_symbol_libsocket(pipe); + swrap_bind_symbol_libc(read); + swrap_bind_symbol_libsocket(readv); + swrap_bind_symbol_libsocket(recv); + swrap_bind_symbol_libsocket(recvfrom); + swrap_bind_symbol_libsocket(recvmsg); + swrap_bind_symbol_libsocket(send); + swrap_bind_symbol_libsocket(sendmsg); + swrap_bind_symbol_libsocket(sendto); + swrap_bind_symbol_libsocket(setsockopt); +#ifdef HAVE_SIGNALFD + swrap_bind_symbol_libsocket(signalfd); +#endif + swrap_bind_symbol_libsocket(socket); + swrap_bind_symbol_libsocket(socketpair); +#ifdef HAVE_TIMERFD_CREATE + swrap_bind_symbol_libc(timerfd_create); +#endif + swrap_bind_symbol_libc(write); + swrap_bind_symbol_libsocket(writev); } /********************************************************* @@ -975,11 +1241,78 @@ done: return max_mtu; } +static size_t socket_wrapper_max_sockets(void) +{ + const char *s; + unsigned long tmp; + char *endp; + + if (max_sockets != 0) { + return max_sockets; + } + + max_sockets = SOCKET_WRAPPER_MAX_SOCKETS_DEFAULT; + + s = getenv("SOCKET_WRAPPER_MAX_SOCKETS"); + if (s == NULL || s[0] == '\0') { + goto done; + } + + tmp = strtoul(s, &endp, 10); + if (s == endp) { + goto done; + } + if (tmp == 0 || tmp > SOCKET_WRAPPER_MAX_SOCKETS_LIMIT) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "Invalid number of sockets specified, using default."); + goto done; + } + + max_sockets = tmp; + +done: + return max_sockets; +} + +static void socket_wrapper_init_sockets(void) +{ + size_t i; + + if (sockets != NULL) { + return; + } + + max_sockets = socket_wrapper_max_sockets(); + + sockets = (struct socket_info *)calloc(max_sockets, + sizeof(struct socket_info)); + + if (sockets == NULL) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "Failed to allocate sockets array.\n"); + exit(-1); + } + + first_free = 0; + + for (i = 0; i < max_sockets; i++) { + sockets[i].next_free = i+1; + } + + sockets[max_sockets-1].next_free = -1; +} + bool socket_wrapper_enabled(void) { const char *s = socket_wrapper_dir(); - return s != NULL ? true : false; + if (s == NULL) { + return false; + } + + socket_wrapper_init_sockets(); + + return true; } static unsigned int socket_wrapper_default_iface(void) @@ -997,6 +1330,25 @@ static unsigned int socket_wrapper_default_iface(void) return 1;/* 127.0.0.1 */ } +/* + * Return the first free entry (if any) and make + * it re-usable again (by nulling it out) + */ +static int socket_wrapper_first_free_index(void) +{ + int next_free; + + if (first_free == -1) { + return -1; + } + + next_free = sockets[first_free].next_free; + ZERO_STRUCT(sockets[first_free]); + sockets[first_free].next_free = next_free; + + return first_free; +} + static int convert_un_in(const struct sockaddr_un *un, struct sockaddr *in, socklen_t *len) { unsigned int iface; @@ -1364,26 +1716,46 @@ static int convert_in_un_alloc(struct socket_info *si, const struct sockaddr *in return 0; } -static struct socket_info *find_socket_info(int fd) +static struct socket_info_fd *find_socket_info_fd(int fd) { - struct socket_info *i; + struct socket_info_fd *f; - for (i = sockets; i; i = i->next) { - struct socket_info_fd *f; - for (f = i->fds; f; f = f->next) { - if (f->fd == fd) { - return i; - } + for (f = socket_fds; f; f = f->next) { + if (f->fd == fd) { + return f; } } return NULL; } +static int find_socket_info_index(int fd) +{ + struct socket_info_fd *fi = find_socket_info_fd(fd); + + if (fi == NULL) { + return -1; + } + + return fi->si_index; +} + +static struct socket_info *find_socket_info(int fd) +{ + int idx = find_socket_info_index(fd); + + if (idx == -1) { + return NULL; + } + + return &sockets[idx]; +} + #if 0 /* FIXME */ static bool check_addr_port_in_use(const struct sockaddr *sa, socklen_t len) { - struct socket_info *s; + struct socket_info_fd *f; + const struct socket_info *last_s = NULL; /* first catch invalid input */ switch (sa->sa_family) { @@ -1404,7 +1776,14 @@ static bool check_addr_port_in_use(const struct sockaddr *sa, socklen_t len) break; } - for (s = sockets; s != NULL; s = s->next) { + for (f = socket_fds; f; f = f->next) { + struct socket_info *s = &sockets[f->si_index]; + + if (s == last_s) { + continue; + } + last_s = s; + if (s->myname == NULL) { continue; } @@ -1466,27 +1845,33 @@ static bool check_addr_port_in_use(const struct sockaddr *sa, socklen_t len) static void swrap_remove_stale(int fd) { - struct socket_info *si = find_socket_info(fd); - struct socket_info_fd *fi; + struct socket_info_fd *fi = find_socket_info_fd(fd); + struct socket_info *si; + int si_index; - if (si != NULL) { - for (fi = si->fds; fi; fi = fi->next) { - if (fi->fd == fd) { - SWRAP_LOG(SWRAP_LOG_TRACE, "remove stale wrapper for %d", fd); - SWRAP_DLIST_REMOVE(si->fds, fi); - free(fi); - break; - } - } + if (fi == NULL) { + return; + } - if (si->fds == NULL) { - SWRAP_DLIST_REMOVE(sockets, si); - if (si->un_addr.sun_path[0] != '\0') { - unlink(si->un_addr.sun_path); - } - free(si); - } + si_index = fi->si_index; + + SWRAP_LOG(SWRAP_LOG_TRACE, "remove stale wrapper for %d", fd); + SWRAP_DLIST_REMOVE(socket_fds, fi); + free(fi); + + si = &sockets[si_index]; + si->refcount--; + + if (si->refcount > 0) { + return; + } + + if (si->un_addr.sun_path[0] != '\0') { + unlink(si->un_addr.sun_path); } + + si->next_free = first_free; + first_free = si_index; } static int sockaddr_convert_to_un(struct socket_info *si, @@ -1528,7 +1913,7 @@ static int sockaddr_convert_to_un(struct socket_info *si, * AF_UNSPEC is mapped to AF_INET and must be treated here. */ - /* FALL THROUGH */ + FALL_THROUGH; } case AF_INET: #ifdef HAVE_IPV6 @@ -2013,7 +2398,9 @@ static int swrap_pcap_get_fd(const char *fname) { static int fd = -1; - if (fd != -1) return fd; + if (fd != -1) { + return fd; + } fd = libc_open(fname, O_WRONLY|O_CREAT|O_EXCL|O_APPEND, 0644); if (fd != -1) { @@ -2066,7 +2453,9 @@ static uint8_t *swrap_pcap_marshall_packet(struct socket_info *si, switch (type) { case SWRAP_CONNECT_SEND: - if (si->type != SOCK_STREAM) return NULL; + if (si->type != SOCK_STREAM) { + return NULL; + } src_addr = &si->myname.sa.s; dest_addr = addr; @@ -2080,7 +2469,9 @@ static uint8_t *swrap_pcap_marshall_packet(struct socket_info *si, break; case SWRAP_CONNECT_RECV: - if (si->type != SOCK_STREAM) return NULL; + if (si->type != SOCK_STREAM) { + return NULL; + } dest_addr = &si->myname.sa.s; src_addr = addr; @@ -2094,7 +2485,9 @@ static uint8_t *swrap_pcap_marshall_packet(struct socket_info *si, break; case SWRAP_CONNECT_UNREACH: - if (si->type != SOCK_STREAM) return NULL; + if (si->type != SOCK_STREAM) { + return NULL; + } dest_addr = &si->myname.sa.s; src_addr = addr; @@ -2108,7 +2501,9 @@ static uint8_t *swrap_pcap_marshall_packet(struct socket_info *si, break; case SWRAP_CONNECT_ACK: - if (si->type != SOCK_STREAM) return NULL; + if (si->type != SOCK_STREAM) { + return NULL; + } src_addr = &si->myname.sa.s; dest_addr = addr; @@ -2120,7 +2515,9 @@ static uint8_t *swrap_pcap_marshall_packet(struct socket_info *si, break; case SWRAP_ACCEPT_SEND: - if (si->type != SOCK_STREAM) return NULL; + if (si->type != SOCK_STREAM) { + return NULL; + } dest_addr = &si->myname.sa.s; src_addr = addr; @@ -2134,7 +2531,9 @@ static uint8_t *swrap_pcap_marshall_packet(struct socket_info *si, break; case SWRAP_ACCEPT_RECV: - if (si->type != SOCK_STREAM) return NULL; + if (si->type != SOCK_STREAM) { + return NULL; + } src_addr = &si->myname.sa.s; dest_addr = addr; @@ -2148,7 +2547,9 @@ static uint8_t *swrap_pcap_marshall_packet(struct socket_info *si, break; case SWRAP_ACCEPT_ACK: - if (si->type != SOCK_STREAM) return NULL; + if (si->type != SOCK_STREAM) { + return NULL; + } dest_addr = &si->myname.sa.s; src_addr = addr; @@ -2255,7 +2656,9 @@ static uint8_t *swrap_pcap_marshall_packet(struct socket_info *si, break; case SWRAP_CLOSE_SEND: - if (si->type != SOCK_STREAM) return NULL; + if (si->type != SOCK_STREAM) { + return NULL; + } src_addr = &si->myname.sa.s; dest_addr = &si->peername.sa.s; @@ -2269,7 +2672,9 @@ static uint8_t *swrap_pcap_marshall_packet(struct socket_info *si, break; case SWRAP_CLOSE_RECV: - if (si->type != SOCK_STREAM) return NULL; + if (si->type != SOCK_STREAM) { + return NULL; + } dest_addr = &si->myname.sa.s; src_addr = &si->peername.sa.s; @@ -2283,7 +2688,9 @@ static uint8_t *swrap_pcap_marshall_packet(struct socket_info *si, break; case SWRAP_CLOSE_ACK: - if (si->type != SOCK_STREAM) return NULL; + if (si->type != SOCK_STREAM) { + return NULL; + } src_addr = &si->myname.sa.s; dest_addr = &si->peername.sa.s; @@ -2380,6 +2787,7 @@ static int swrap_socket(int family, int type, int protocol) struct socket_info *si; struct socket_info_fd *fi; int fd; + int idx; int real_type = type; /* @@ -2434,12 +2842,12 @@ static int swrap_socket(int family, int type, int protocol) if (real_type == SOCK_STREAM) { break; } - /*fall through*/ + FALL_THROUGH; case 17: if (real_type == SOCK_DGRAM) { break; } - /*fall through*/ + FALL_THROUGH; default: errno = EPROTONOSUPPORT; return -1; @@ -2456,17 +2864,16 @@ static int swrap_socket(int family, int type, int protocol) } /* Check if we have a stale fd and remove it */ - si = find_socket_info(fd); - if (si != NULL) { - swrap_remove_stale(fd); - } + swrap_remove_stale(fd); - si = (struct socket_info *)calloc(1, sizeof(struct socket_info)); - if (si == NULL) { + idx = socket_wrapper_first_free_index(); + if (idx == -1) { errno = ENOMEM; return -1; } + si = &sockets[idx]; + si->family = family; /* however, the rest of the socket_wrapper code expects just @@ -2498,22 +2905,24 @@ static int swrap_socket(int family, int type, int protocol) break; } default: - free(si); errno = EINVAL; return -1; } fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd)); if (fi == NULL) { - free(si); errno = ENOMEM; return -1; } + si->refcount = 1; + first_free = si->next_free; + si->next_free = 0; + fi->fd = fd; + fi->si_index = idx; - SWRAP_DLIST_ADD(si->fds, fi); - SWRAP_DLIST_ADD(sockets, si); + SWRAP_DLIST_ADD(socket_fds, fi); SWRAP_LOG(SWRAP_LOG_TRACE, "Created %s socket for protocol %s", @@ -2607,6 +3016,7 @@ static int swrap_accept(int s, struct socket_info *parent_si, *child_si; struct socket_info_fd *child_fi; int fd; + int idx; struct swrap_address un_addr = { .sa_socklen = sizeof(struct sockaddr_un), }; @@ -2626,6 +3036,7 @@ static int swrap_accept(int s, #ifdef HAVE_ACCEPT4 return libc_accept4(s, addr, addrlen, flags); #else + UNUSED(flags); return libc_accept(s, addr, addrlen); #endif } @@ -2643,6 +3054,7 @@ static int swrap_accept(int s, #ifdef HAVE_ACCEPT4 ret = libc_accept4(s, &un_addr.sa.s, &un_addr.sa_socklen, flags); #else + UNUSED(flags); ret = libc_accept(s, &un_addr.sa.s, &un_addr.sa_socklen); #endif if (ret == -1) { @@ -2666,16 +3078,16 @@ static int swrap_accept(int s, return ret; } - child_si = (struct socket_info *)calloc(1, sizeof(struct socket_info)); - if (child_si == NULL) { - close(fd); + idx = socket_wrapper_first_free_index(); + if (idx == -1) { errno = ENOMEM; return -1; } + child_si = &sockets[idx]; + child_fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd)); if (child_fi == NULL) { - free(child_si); close(fd); errno = ENOMEM; return -1; @@ -2683,8 +3095,6 @@ static int swrap_accept(int s, child_fi->fd = fd; - SWRAP_DLIST_ADD(child_si->fds, child_fi); - child_si->family = parent_si->family; child_si->type = parent_si->type; child_si->protocol = parent_si->protocol; @@ -2710,7 +3120,6 @@ static int swrap_accept(int s, &un_my_addr.sa_socklen); if (ret == -1) { free(child_fi); - free(child_si); close(fd); return ret; } @@ -2723,7 +3132,6 @@ static int swrap_accept(int s, &in_my_addr.sa_socklen); if (ret == -1) { free(child_fi); - free(child_si); close(fd); return ret; } @@ -2737,7 +3145,13 @@ static int swrap_accept(int s, }; memcpy(&child_si->myname.sa.ss, &in_my_addr.sa.ss, in_my_addr.sa_socklen); - SWRAP_DLIST_ADD(sockets, child_si); + child_si->refcount = 1; + first_free = child_si->next_free; + child_si->next_free = 0; + + child_fi->si_index = idx; + + SWRAP_DLIST_ADD(socket_fds, child_fi); if (addr != NULL) { swrap_pcap_dump_packet(child_si, addr, SWRAP_ACCEPT_SEND, NULL, 0); @@ -2805,8 +3219,9 @@ static int swrap_auto_bind(int fd, struct socket_info *si, int family) type = SOCKET_TYPE_CHAR_UDP; break; default: - errno = ESOCKTNOSUPPORT; - return -1; + errno = ESOCKTNOSUPPORT; + ret = -1; + goto done; } memset(&in, 0, sizeof(in)); @@ -2826,7 +3241,8 @@ static int swrap_auto_bind(int fd, struct socket_info *si, int family) if (si->family != family) { errno = ENETUNREACH; - return -1; + ret = -1; + goto done; } switch (si->type) { @@ -2838,7 +3254,8 @@ static int swrap_auto_bind(int fd, struct socket_info *si, int family) break; default: errno = ESOCKTNOSUPPORT; - return -1; + ret = -1; + goto done; } memset(&in6, 0, sizeof(in6)); @@ -2855,7 +3272,8 @@ static int swrap_auto_bind(int fd, struct socket_info *si, int family) #endif default: errno = ESOCKTNOSUPPORT; - return -1; + ret = -1; + goto done; } if (autobind_start > 60000) { @@ -2870,7 +3288,9 @@ static int swrap_auto_bind(int fd, struct socket_info *si, int family) if (stat(un_addr.sa.un.sun_path, &st) == 0) continue; ret = libc_bind(fd, &un_addr.sa.s, un_addr.sa_socklen); - if (ret == -1) return ret; + if (ret == -1) { + goto done; + } si->un_addr = un_addr.sa.un; @@ -2886,13 +3306,17 @@ static int swrap_auto_bind(int fd, struct socket_info *si, int family) socket_wrapper_default_iface(), 0); errno = ENFILE; - return -1; + ret = -1; + goto done; } si->family = family; set_port(si->family, port, &si->myname); - return 0; + ret = 0; + +done: + return ret; } /**************************************************************************** @@ -2915,21 +3339,27 @@ static int swrap_connect(int s, const struct sockaddr *serv_addr, if (si->bound == 0) { ret = swrap_auto_bind(s, si, serv_addr->sa_family); - if (ret == -1) return -1; + if (ret == -1) { + goto done; + } } if (si->family != serv_addr->sa_family) { errno = EINVAL; - return -1; + ret = -1; + goto done; } ret = sockaddr_convert_to_un(si, serv_addr, addrlen, &un_addr.sa.un, 0, &bcast); - if (ret == -1) return -1; + if (ret == -1) { + goto done; + } if (bcast) { errno = ENETUNREACH; - return -1; + ret = -1; + goto done; } if (si->type == SOCK_DGRAM) { @@ -2989,6 +3419,7 @@ static int swrap_connect(int s, const struct sockaddr *serv_addr, swrap_pcap_dump_packet(si, serv_addr, SWRAP_CONNECT_UNREACH, NULL, 0); } +done: return ret; } @@ -3084,7 +3515,9 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) &un_addr.sa.un, 1, &si->bcast); - if (ret == -1) return -1; + if (ret == -1) { + return -1; + } unlink(un_addr.sa.un.sun_path); @@ -3243,6 +3676,31 @@ FILE *fopen(const char *name, const char *mode) } /**************************************************************************** + * FOPEN64 + ***************************************************************************/ + +#ifdef HAVE_FOPEN64 +static FILE *swrap_fopen64(const char *name, const char *mode) +{ + FILE *fp; + + fp = libc_fopen64(name, mode); + if (fp != NULL) { + int fd = fileno(fp); + + swrap_remove_stale(fd); + } + + return fp; +} + +FILE *fopen64(const char *name, const char *mode) +{ + return swrap_fopen64(name, mode); +} +#endif /* HAVE_FOPEN64 */ + +/**************************************************************************** * OPEN ***************************************************************************/ @@ -3276,6 +3734,75 @@ int open(const char *pathname, int flags, ...) } /**************************************************************************** + * OPEN64 + ***************************************************************************/ + +#ifdef HAVE_OPEN64 +static int swrap_vopen64(const char *pathname, int flags, va_list ap) +{ + int ret; + + ret = libc_vopen64(pathname, flags, ap); + if (ret != -1) { + /* + * There are methods for closing descriptors (libc-internal code + * paths, direct syscalls) which close descriptors in ways that + * we can't intercept, so try to recover when we notice that + * that's happened + */ + swrap_remove_stale(ret); + } + return ret; +} + +int open64(const char *pathname, int flags, ...) +{ + va_list ap; + int fd; + + va_start(ap, flags); + fd = swrap_vopen64(pathname, flags, ap); + va_end(ap); + + return fd; +} +#endif /* HAVE_OPEN64 */ + +/**************************************************************************** + * OPENAT + ***************************************************************************/ + +static int swrap_vopenat(int dirfd, const char *path, int flags, va_list ap) +{ + int ret; + + ret = libc_vopenat(dirfd, path, flags, ap); + if (ret != -1) { + /* + * There are methods for closing descriptors (libc-internal code + * paths, direct syscalls) which close descriptors in ways that + * we can't intercept, so try to recover when we notice that + * that's happened + */ + swrap_remove_stale(ret); + } + + return ret; +} + +int openat(int dirfd, const char *path, int flags, ...) +{ + va_list ap; + int fd; + + va_start(ap, flags); + fd = swrap_vopenat(dirfd, path, flags, ap); + va_end(ap); + + return fd; +} + +/**************************************************************************** * GETPEERNAME ***************************************************************************/ @@ -3361,6 +3888,7 @@ static int swrap_getsockopt(int s, int level, int optname, void *optval, socklen_t *optlen) { struct socket_info *si = find_socket_info(s); + int ret; if (!si) { return libc_getsockopt(s, @@ -3377,12 +3905,14 @@ static int swrap_getsockopt(int s, int level, int optname, if (optval == NULL || optlen == NULL || *optlen < (socklen_t)sizeof(int)) { errno = EINVAL; - return -1; + ret = -1; + goto done; } *optlen = sizeof(int); *(int *)optval = si->family; - return 0; + ret = 0; + goto done; #endif /* SO_DOMAIN */ #ifdef SO_PROTOCOL @@ -3390,29 +3920,34 @@ static int swrap_getsockopt(int s, int level, int optname, if (optval == NULL || optlen == NULL || *optlen < (socklen_t)sizeof(int)) { errno = EINVAL; - return -1; + ret = -1; + goto done; } *optlen = sizeof(int); *(int *)optval = si->protocol; - return 0; + ret = 0; + goto done; #endif /* SO_PROTOCOL */ case SO_TYPE: if (optval == NULL || optlen == NULL || *optlen < (socklen_t)sizeof(int)) { errno = EINVAL; - return -1; + ret = -1; + goto done; } *optlen = sizeof(int); *(int *)optval = si->type; - return 0; + ret = 0; + goto done; default: - return libc_getsockopt(s, - level, - optname, - optval, - optlen); + ret = libc_getsockopt(s, + level, + optname, + optval, + optlen); + goto done; } } else if (level == IPPROTO_TCP) { switch (optname) { @@ -3426,13 +3961,15 @@ static int swrap_getsockopt(int s, int level, int optname, if (optval == NULL || optlen == NULL || *optlen < (socklen_t)sizeof(int)) { errno = EINVAL; - return -1; + ret = -1; + goto done; } *optlen = sizeof(int); *(int *)optval = si->tcp_nodelay; - return 0; + ret = 0; + goto done; #endif /* TCP_NODELAY */ default: break; @@ -3440,7 +3977,10 @@ static int swrap_getsockopt(int s, int level, int optname, } errno = ENOPROTOOPT; - return -1; + ret = -1; + +done: + return ret; } #ifdef HAVE_ACCEPT_PSOCKLEN_T @@ -3460,6 +4000,7 @@ static int swrap_setsockopt(int s, int level, int optname, const void *optval, socklen_t optlen) { struct socket_info *si = find_socket_info(s); + int ret; if (!si) { return libc_setsockopt(s, @@ -3488,17 +4029,20 @@ static int swrap_setsockopt(int s, int level, int optname, if (optval == NULL || optlen == 0 || optlen < (socklen_t)sizeof(int)) { errno = EINVAL; - return -1; + ret = -1; + goto done; } i = *discard_const_p(int, optval); if (i != 0 && i != 1) { errno = EINVAL; - return -1; + ret = -1; + goto done; } si->tcp_nodelay = i; - return 0; + ret = 0; + goto done; } #endif /* TCP_NODELAY */ default: @@ -3515,7 +4059,8 @@ static int swrap_setsockopt(int s, int level, int optname, } #endif /* IP_PKTINFO */ } - return 0; + ret = 0; + goto done; #ifdef HAVE_IPV6 case AF_INET6: if (level == IPPROTO_IPV6) { @@ -3525,12 +4070,17 @@ static int swrap_setsockopt(int s, int level, int optname, } #endif /* IPV6_PKTINFO */ } - return 0; + ret = 0; + goto done; #endif default: errno = ENOPROTOOPT; - return -1; + ret = -1; + goto done; } + +done: + return ret; } int setsockopt(int s, int level, int optname, @@ -3944,7 +4494,9 @@ static ssize_t swrap_sendmsg_before(int fd, ret = sockaddr_convert_to_un(si, msg_name, msg->msg_namelen, tmp_un, 0, bcast); - if (ret == -1) return -1; + if (ret == -1) { + return -1; + } if (to_un) { *to_un = tmp_un; @@ -3979,7 +4531,9 @@ static ssize_t swrap_sendmsg_before(int fd, tmp_un, 0, NULL); - if (ret == -1) return -1; + if (ret == -1) { + return -1; + } ret = libc_connect(fd, (struct sockaddr *)(void *)tmp_un, @@ -5085,35 +5639,34 @@ ssize_t writev(int s, const struct iovec *vector, int count) static int swrap_close(int fd) { - struct socket_info *si = find_socket_info(fd); - struct socket_info_fd *fi; + struct socket_info_fd *fi = find_socket_info_fd(fd); + struct socket_info *si = NULL; + int si_index; int ret; - if (!si) { + if (fi == NULL) { return libc_close(fd); } - for (fi = si->fds; fi; fi = fi->next) { - if (fi->fd == fd) { - SWRAP_DLIST_REMOVE(si->fds, fi); - free(fi); - break; - } - } + si_index = fi->si_index; + + SWRAP_DLIST_REMOVE(socket_fds, fi); + free(fi); + + ret = libc_close(fd); - if (si->fds) { + si = &sockets[si_index]; + si->refcount--; + + if (si->refcount > 0) { /* there are still references left */ - return libc_close(fd); + return ret; } - SWRAP_DLIST_REMOVE(sockets, si); - if (si->myname.sa_socklen > 0 && si->peername.sa_socklen > 0) { swrap_pcap_dump_packet(si, NULL, SWRAP_CLOSE_SEND, NULL, 0); } - ret = libc_close(fd); - if (si->myname.sa_socklen > 0 && si->peername.sa_socklen > 0) { swrap_pcap_dump_packet(si, NULL, SWRAP_CLOSE_RECV, NULL, 0); swrap_pcap_dump_packet(si, NULL, SWRAP_CLOSE_ACK, NULL, 0); @@ -5122,7 +5675,9 @@ static int swrap_close(int fd) if (si->un_addr.sun_path[0] != '\0') { unlink(si->un_addr.sun_path); } - free(si); + + si->next_free = first_free; + first_free = si_index; return ret; } @@ -5139,14 +5694,15 @@ int close(int fd) static int swrap_dup(int fd) { struct socket_info *si; - struct socket_info_fd *fi; - - si = find_socket_info(fd); + struct socket_info_fd *src_fi, *fi; - if (!si) { + src_fi = find_socket_info_fd(fd); + if (src_fi == NULL) { return libc_dup(fd); } + si = &sockets[src_fi->si_index]; + fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd)); if (fi == NULL) { errno = ENOMEM; @@ -5161,10 +5717,13 @@ static int swrap_dup(int fd) return -1; } + si->refcount++; + fi->si_index = src_fi->si_index; + /* Make sure we don't have an entry for the fd */ swrap_remove_stale(fi->fd); - SWRAP_DLIST_ADD(si->fds, fi); + SWRAP_DLIST_ADD_AFTER(socket_fds, fi, src_fi); return fi->fd; } @@ -5180,14 +5739,25 @@ int dup(int fd) static int swrap_dup2(int fd, int newfd) { struct socket_info *si; - struct socket_info_fd *fi; - - si = find_socket_info(fd); + struct socket_info_fd *src_fi, *fi; - if (!si) { + src_fi = find_socket_info_fd(fd); + if (src_fi == NULL) { return libc_dup2(fd, newfd); } + si = &sockets[src_fi->si_index]; + + if (fd == newfd) { + /* + * According to the manpage: + * + * "If oldfd is a valid file descriptor, and newfd has the same + * value as oldfd, then dup2() does nothing, and returns newfd." + */ + return newfd; + } + if (find_socket_info(newfd)) { /* dup2() does an implicit close of newfd, which we * need to emulate */ @@ -5208,10 +5778,13 @@ static int swrap_dup2(int fd, int newfd) return -1; } + si->refcount++; + fi->si_index = src_fi->si_index; + /* Make sure we don't have an entry for the fd */ swrap_remove_stale(fi->fd); - SWRAP_DLIST_ADD(si->fds, fi); + SWRAP_DLIST_ADD_AFTER(socket_fds, fi, src_fi); return fi->fd; } @@ -5226,17 +5799,17 @@ int dup2(int fd, int newfd) static int swrap_vfcntl(int fd, int cmd, va_list va) { - struct socket_info_fd *fi; + struct socket_info_fd *src_fi, *fi; struct socket_info *si; int rc; - si = find_socket_info(fd); - if (si == NULL) { - rc = libc_vfcntl(fd, cmd, va); - - return rc; + src_fi = find_socket_info_fd(fd); + if (src_fi == NULL) { + return libc_vfcntl(fd, cmd, va); } + si = &sockets[src_fi->si_index]; + switch (cmd) { case F_DUPFD: fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd)); @@ -5253,10 +5826,13 @@ static int swrap_vfcntl(int fd, int cmd, va_list va) return -1; } + si->refcount++; + fi->si_index = src_fi->si_index; + /* Make sure we don't have an entry for the fd */ swrap_remove_stale(fi->fd); - SWRAP_DLIST_ADD(si->fds, fi); + SWRAP_DLIST_ADD_AFTER(socket_fds, fi, src_fi); rc = fi->fd; break; @@ -5319,6 +5895,45 @@ int pledge(const char *promises, const char *paths[]) } #endif /* HAVE_PLEDGE */ +static void swrap_thread_prepare(void) +{ + /* + * This function should only be called here!! + * + * We bind all symobls to avoid deadlocks of the fork is + * interrupted by a signal handler using a symbol of this + * library. + */ + swrap_bind_symbol_all(); + + SWRAP_LOCK_ALL; +} + +static void swrap_thread_parent(void) +{ + SWRAP_UNLOCK_ALL; +} + +static void swrap_thread_child(void) +{ + SWRAP_UNLOCK_ALL; +} + +/**************************** + * CONSTRUCTOR + ***************************/ +void swrap_constructor(void) +{ + /* + * If we hold a lock and the application forks, then the child + * is not able to unlock the mutex and we are in a deadlock. + * This should prevent such deadlocks. + */ + pthread_atfork(&swrap_thread_prepare, + &swrap_thread_parent, + &swrap_thread_child); +} + /**************************** * DESTRUCTOR ***************************/ @@ -5329,20 +5944,19 @@ int pledge(const char *promises, const char *paths[]) */ void swrap_destructor(void) { - struct socket_info *s = sockets; + struct socket_info_fd *s = socket_fds; while (s != NULL) { - struct socket_info_fd *f = s->fds; - if (f != NULL) { - swrap_close(f->fd); - } - s = sockets; + swrap_close(s->fd); + s = socket_fds; } - if (swrap.libc_handle != NULL) { - dlclose(swrap.libc_handle); + free(sockets); + + if (swrap.libc.handle != NULL) { + dlclose(swrap.libc.handle); } - if (swrap.libsocket_handle) { - dlclose(swrap.libsocket_handle); + if (swrap.libc.socket_handle) { + dlclose(swrap.libc.socket_handle); } } diff --git a/third_party/socket_wrapper/wscript b/third_party/socket_wrapper/wscript index 514265b92b6..1693b44eece 100644 --- a/third_party/socket_wrapper/wscript +++ b/third_party/socket_wrapper/wscript @@ -2,7 +2,7 @@ import os -VERSION="1.1.7" +VERSION="1.1.9" def configure(conf): if conf.CHECK_SOCKET_WRAPPER(): @@ -109,6 +109,6 @@ def build(bld): # breaks preloading! bld.SAMBA_LIBRARY('socket_wrapper', source='socket_wrapper.c', - deps='dl', + deps='dl pthread', install=False, realname='libsocket-wrapper.so') |