summaryrefslogtreecommitdiff
path: root/src/script_lua.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/script_lua.c')
-rw-r--r--src/script_lua.c442
1 files changed, 295 insertions, 147 deletions
diff --git a/src/script_lua.c b/src/script_lua.c
index d7332cf86..9a08a7e47 100644
--- a/src/script_lua.c
+++ b/src/script_lua.c
@@ -238,9 +238,12 @@ static void redisProtocolToLuaType_Error(void *ctx, const char *str, size_t len,
* to push elements to the stack. On failure, exit with panic. */
serverPanic("lua stack limit reach when parsing redis.call reply");
}
- lua_newtable(lua);
- lua_pushstring(lua,"err");
- lua_pushlstring(lua,str,len);
+ sds err_msg = sdscatlen(sdsnew("-"), str, len);
+ luaPushErrorBuff(lua,err_msg);
+ /* push a field indicate to ignore updating the stats on this error
+ * because it was already updated when executing the command. */
+ lua_pushstring(lua,"ignore_error_stats_update");
+ lua_pushboolean(lua, true);
lua_settable(lua,-3);
}
@@ -428,40 +431,66 @@ static void redisProtocolToLuaType_Double(void *ctx, double d, const char *proto
/* This function is used in order to push an error on the Lua stack in the
* format used by redis.pcall to return errors, which is a lua table
- * with a single "err" field set to the error string. Note that this
- * table is never a valid reply by proper commands, since the returned
- * tables are otherwise always indexed by integers, never by strings. */
-void luaPushError(lua_State *lua, char *error) {
- lua_Debug dbg;
+ * with an "err" field set to the error string including the error code.
+ * Note that this table is never a valid reply by proper commands,
+ * since the returned tables are otherwise always indexed by integers, never by strings.
+ *
+ * The function takes ownership on the given err_buffer. */
+void luaPushErrorBuff(lua_State *lua, sds err_buffer) {
+ sds msg;
+ sds error_code;
/* If debugging is active and in step mode, log errors resulting from
* Redis commands. */
if (ldbIsEnabled()) {
- ldbLog(sdscatprintf(sdsempty(),"<error> %s",error));
+ ldbLog(sdscatprintf(sdsempty(),"<error> %s",err_buffer));
+ }
+
+ /* There are two possible formats for the received `error` string:
+ * 1) "-CODE msg": in this case we remove the leading '-' since we don't store it as part of the lua error format.
+ * 2) "msg": in this case we prepend a generic 'ERR' code since all error statuses need some error code.
+ * We support format (1) so this function can reuse the error messages used in other places in redis.
+ * We support format (2) so it'll be easy to pass descriptive errors to this function without worrying about format.
+ */
+ if (err_buffer[0] == '-') {
+ /* derive error code from the message */
+ char *err_msg = strstr(err_buffer, " ");
+ if (!err_msg) {
+ msg = sdsnew(err_buffer+1);
+ error_code = sdsnew("ERR");
+ } else {
+ *err_msg = '\0';
+ msg = sdsnew(err_msg+1);
+ error_code = sdsnew(err_buffer + 1);
+ }
+ sdsfree(err_buffer);
+ } else {
+ msg = err_buffer;
+ error_code = sdsnew("ERR");
}
+ /* Trim newline at end of string. If we reuse the ready-made Redis error objects (case 1 above) then we might
+ * have a newline that needs to be trimmed. In any case the lua Redis error table shouldn't end with a newline. */
+ msg = sdstrim(msg, "\r\n");
+ sds final_msg = sdscatfmt(error_code, " %s", msg);
lua_newtable(lua);
lua_pushstring(lua,"err");
-
- /* Attempt to figure out where this function was called, if possible */
- if(lua_getstack(lua, 1, &dbg) && lua_getinfo(lua, "nSl", &dbg)) {
- sds msg = sdscatprintf(sdsempty(), "%s: %d: %s",
- dbg.source, dbg.currentline, error);
- lua_pushstring(lua, msg);
- sdsfree(msg);
- } else {
- lua_pushstring(lua, error);
- }
+ lua_pushstring(lua, final_msg);
lua_settable(lua,-3);
+
+ sdsfree(msg);
+ sdsfree(final_msg);
+}
+
+void luaPushError(lua_State *lua, const char *error) {
+ luaPushErrorBuff(lua, sdsnew(error));
}
/* In case the error set into the Lua stack by luaPushError() was generated
* by the non-error-trapping version of redis.pcall(), which is redis.call(),
* this function will raise the Lua error so that the execution of the
* script will be halted. */
-int luaRaiseError(lua_State *lua) {
- lua_pushstring(lua,"err");
- lua_gettable(lua,-2);
+int luaError(lua_State *lua) {
return lua_error(lua);
}
@@ -511,8 +540,15 @@ static void luaReplyToRedisReply(client *c, client* script_client, lua_State *lu
lua_gettable(lua,-2);
t = lua_type(lua,-1);
if (t == LUA_TSTRING) {
- addReplyErrorFormat(c,"-%s",lua_tostring(lua,-1));
- lua_pop(lua,2);
+ lua_pop(lua, 1); /* pop the error message, we will use luaExtractErrorInformation to get error information */
+ errorInfo err_info = {0};
+ luaExtractErrorInformation(lua, &err_info);
+ addReplyErrorFormatEx(c,
+ err_info.ignore_err_stats_update? ERR_REPLY_FLAG_NO_STATS_UPDATE: 0,
+ "-%s",
+ err_info.msg);
+ luaErrorInformationDiscard(&err_info);
+ lua_pop(lua,1); /* pop the result table */
return;
}
lua_pop(lua,1); /* Discard field name pushed before. */
@@ -655,55 +691,19 @@ static void luaReplyToRedisReply(client *c, client* script_client, lua_State *lu
* Lua redis.* functions implementations.
* ------------------------------------------------------------------------- */
-#define LUA_CMD_OBJCACHE_SIZE 32
-#define LUA_CMD_OBJCACHE_MAX_LEN 64
-static int luaRedisGenericCommand(lua_State *lua, int raise_error) {
- int j, argc = lua_gettop(lua);
- scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME);
- if (!rctx) {
- luaPushError(lua, "redis.call/pcall can only be called inside a script invocation");
- return luaRaiseError(lua);
- }
- sds err = NULL;
- client* c = rctx->c;
- sds reply;
-
- /* Cached across calls. */
- static robj **argv = NULL;
- static int argv_size = 0;
- static robj *cached_objects[LUA_CMD_OBJCACHE_SIZE];
- static size_t cached_objects_len[LUA_CMD_OBJCACHE_SIZE];
- static int inuse = 0; /* Recursive calls detection. */
-
- /* By using Lua debug hooks it is possible to trigger a recursive call
- * to luaRedisGenericCommand(), which normally should never happen.
- * To make this function reentrant is futile and makes it slower, but
- * we should at least detect such a misuse, and abort. */
- if (inuse) {
- char *recursion_warning =
- "luaRedisGenericCommand() recursive call detected. "
- "Are you doing funny stuff with Lua debug hooks?";
- serverLog(LL_WARNING,"%s",recursion_warning);
- luaPushError(lua,recursion_warning);
- return 1;
- }
- inuse++;
-
+static robj **luaArgsToRedisArgv(lua_State *lua, int *argc) {
+ int j;
/* Require at least one argument */
- if (argc == 0) {
- luaPushError(lua,
- "Please specify at least one argument for redis.call()");
- inuse--;
- return raise_error ? luaRaiseError(lua) : 1;
+ *argc = lua_gettop(lua);
+ if (*argc == 0) {
+ luaPushError(lua, "Please specify at least one argument for this redis lib call");
+ return NULL;
}
/* Build the arguments vector */
- if (argv_size < argc) {
- argv = zrealloc(argv,sizeof(robj*)*argc);
- argv_size = argc;
- }
+ robj **argv = zcalloc(sizeof(robj*) * *argc);
- for (j = 0; j < argc; j++) {
+ for (j = 0; j < *argc; j++) {
char *obj_s;
size_t obj_len;
char dbuf[64];
@@ -720,38 +720,62 @@ static int luaRedisGenericCommand(lua_State *lua, int raise_error) {
if (obj_s == NULL) break; /* Not a string. */
}
- /* Try to use a cached object. */
- if (j < LUA_CMD_OBJCACHE_SIZE && cached_objects[j] &&
- cached_objects_len[j] >= obj_len)
- {
- sds s = cached_objects[j]->ptr;
- argv[j] = cached_objects[j];
- cached_objects[j] = NULL;
- memcpy(s,obj_s,obj_len+1);
- sdssetlen(s, obj_len);
- } else {
- argv[j] = createStringObject(obj_s, obj_len);
- }
+ argv[j] = createStringObject(obj_s, obj_len);
}
+ /* Pop all arguments from the stack, we do not need them anymore
+ * and this way we guaranty we will have room on the stack for the result. */
+ lua_pop(lua, *argc);
+
/* Check if one of the arguments passed by the Lua script
* is not a string or an integer (lua_isstring() return true for
* integers as well). */
- if (j != argc) {
+ if (j != *argc) {
j--;
while (j >= 0) {
decrRefCount(argv[j]);
j--;
}
- luaPushError(lua,
- "Lua redis() command arguments must be strings or integers");
- inuse--;
- return raise_error ? luaRaiseError(lua) : 1;
+ zfree(argv);
+ luaPushError(lua, "Lua redis lib command arguments must be strings or integers");
+ return NULL;
}
- /* Pop all arguments from the stack, we do not need them anymore
- * and this way we guaranty we will have room on the stack for the result. */
- lua_pop(lua, argc);
+ return argv;
+}
+
+static int luaRedisGenericCommand(lua_State *lua, int raise_error) {
+ int j;
+ scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME);
+ if (!rctx) {
+ luaPushError(lua, "redis.call/pcall can only be called inside a script invocation");
+ return luaError(lua);
+ }
+ sds err = NULL;
+ client* c = rctx->c;
+ sds reply;
+
+ int argc;
+ robj **argv = luaArgsToRedisArgv(lua, &argc);
+ if (argv == NULL) {
+ return raise_error ? luaError(lua) : 1;
+ }
+
+ static int inuse = 0; /* Recursive calls detection. */
+
+ /* By using Lua debug hooks it is possible to trigger a recursive call
+ * to luaRedisGenericCommand(), which normally should never happen.
+ * To make this function reentrant is futile and makes it slower, but
+ * we should at least detect such a misuse, and abort. */
+ if (inuse) {
+ char *recursion_warning =
+ "luaRedisGenericCommand() recursive call detected. "
+ "Are you doing funny stuff with Lua debug hooks?";
+ serverLog(LL_WARNING,"%s",recursion_warning);
+ luaPushError(lua,recursion_warning);
+ return 1;
+ }
+ inuse++;
/* Log the command if debugging is active. */
if (ldbIsEnabled()) {
@@ -769,11 +793,15 @@ static int luaRedisGenericCommand(lua_State *lua, int raise_error) {
ldbLog(cmdlog);
}
-
scriptCall(rctx, argv, argc, &err);
if (err) {
luaPushError(lua, err);
sdsfree(err);
+ /* push a field indicate to ignore updating the stats on this error
+ * because it was already updated when executing the command. */
+ lua_pushstring(lua,"ignore_error_stats_update");
+ lua_pushboolean(lua, true);
+ lua_settable(lua,-3);
goto cleanup;
}
@@ -810,48 +838,48 @@ static int luaRedisGenericCommand(lua_State *lua, int raise_error) {
cleanup:
/* Clean up. Command code may have changed argv/argc so we use the
* argv/argc of the client instead of the local variables. */
- for (j = 0; j < c->argc; j++) {
- robj *o = c->argv[j];
-
- /* Try to cache the object in the cached_objects array.
- * The object must be small, SDS-encoded, and with refcount = 1
- * (we must be the only owner) for us to cache it. */
- if (j < LUA_CMD_OBJCACHE_SIZE &&
- o->refcount == 1 &&
- (o->encoding == OBJ_ENCODING_RAW ||
- o->encoding == OBJ_ENCODING_EMBSTR) &&
- sdslen(o->ptr) <= LUA_CMD_OBJCACHE_MAX_LEN)
- {
- sds s = o->ptr;
- if (cached_objects[j]) decrRefCount(cached_objects[j]);
- cached_objects[j] = o;
- cached_objects_len[j] = sdsalloc(s);
- } else {
- decrRefCount(o);
- }
- }
-
- if (c->argv != argv) {
- zfree(c->argv);
- argv = NULL;
- argv_size = 0;
- }
-
+ freeClientArgv(c);
c->user = NULL;
- c->argv = NULL;
- c->argc = 0;
+ inuse--;
if (raise_error) {
/* If we are here we should have an error in the stack, in the
* form of a table with an "err" field. Extract the string to
* return the plain error. */
- inuse--;
- return luaRaiseError(lua);
+ return luaError(lua);
}
- inuse--;
return 1;
}
+/* Our implementation to lua pcall.
+ * We need this implementation for backward
+ * comparability with older Redis versions.
+ *
+ * On Redis 7, the error object is a table,
+ * compare to older version where the error
+ * object is a string. To keep backward
+ * comparability we catch the table object
+ * and just return the error message. */
+static int luaRedisPcall(lua_State *lua) {
+ int argc = lua_gettop(lua);
+ lua_pushboolean(lua, 1); /* result place holder */
+ lua_insert(lua, 1);
+ if (lua_pcall(lua, argc - 1, LUA_MULTRET, 0)) {
+ /* Error */
+ lua_remove(lua, 1); /* remove the result place holder, now we have room for at least one element */
+ if (lua_istable(lua, -1)) {
+ lua_getfield(lua, -1, "err");
+ if (lua_isstring(lua, -1)) {
+ lua_replace(lua, -2); /* replace the error message with the table */
+ }
+ }
+ lua_pushboolean(lua, 0); /* push result */
+ lua_insert(lua, 1);
+ }
+ return lua_gettop(lua);
+
+}
+
/* redis.call() */
static int luaRedisCallCommand(lua_State *lua) {
return luaRedisGenericCommand(lua,1);
@@ -871,8 +899,8 @@ static int luaRedisSha1hexCommand(lua_State *lua) {
char *s;
if (argc != 1) {
- lua_pushstring(lua, "wrong number of arguments");
- return lua_error(lua);
+ luaPushError(lua, "wrong number of arguments");
+ return luaError(lua);
}
s = (char*)lua_tolstring(lua,1,&len);
@@ -903,7 +931,21 @@ static int luaRedisReturnSingleFieldTable(lua_State *lua, char *field) {
/* redis.error_reply() */
static int luaRedisErrorReplyCommand(lua_State *lua) {
- return luaRedisReturnSingleFieldTable(lua,"err");
+ if (lua_gettop(lua) != 1 || lua_type(lua,-1) != LUA_TSTRING) {
+ luaPushError(lua, "wrong number or type of arguments");
+ return 1;
+ }
+
+ /* add '-' if not exists */
+ const char *err = lua_tostring(lua, -1);
+ sds err_buff = NULL;
+ if (err[0] != '-') {
+ err_buff = sdscatfmt(sdsempty(), "-%s", err);
+ } else {
+ err_buff = sdsnew(err);
+ }
+ luaPushErrorBuff(lua, err_buff);
+ return 1;
}
/* redis.status_reply() */
@@ -920,25 +962,65 @@ static int luaRedisSetReplCommand(lua_State *lua) {
scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME);
if (!rctx) {
- lua_pushstring(lua, "redis.set_repl can only be called inside a script invocation");
- return lua_error(lua);
+ luaPushError(lua, "redis.set_repl can only be called inside a script invocation");
+ return luaError(lua);
}
if (argc != 1) {
- lua_pushstring(lua, "redis.set_repl() requires two arguments.");
- return lua_error(lua);
+ luaPushError(lua, "redis.set_repl() requires two arguments.");
+ return luaError(lua);
}
flags = lua_tonumber(lua,-1);
if ((flags & ~(PROPAGATE_AOF|PROPAGATE_REPL)) != 0) {
- lua_pushstring(lua, "Invalid replication flags. Use REPL_AOF, REPL_REPLICA, REPL_ALL or REPL_NONE.");
- return lua_error(lua);
+ luaPushError(lua, "Invalid replication flags. Use REPL_AOF, REPL_REPLICA, REPL_ALL or REPL_NONE.");
+ return luaError(lua);
}
scriptSetRepl(rctx, flags);
return 0;
}
+/* redis.acl_check_cmd()
+ *
+ * Checks ACL permissions for given command for the current user. */
+static int luaRedisAclCheckCmdPermissionsCommand(lua_State *lua) {
+ scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME);
+ if (!rctx) {
+ luaPushError(lua, "redis.acl_check_cmd can only be called inside a script invocation");
+ return luaError(lua);
+ }
+ int raise_error = 0;
+
+ int argc;
+ robj **argv = luaArgsToRedisArgv(lua, &argc);
+
+ /* Require at least one argument */
+ if (argv == NULL) return luaError(lua);
+
+ /* Find command */
+ struct redisCommand *cmd;
+ if ((cmd = lookupCommand(argv, argc)) == NULL) {
+ luaPushError(lua, "Invalid command passed to redis.acl_check_cmd()");
+ raise_error = 1;
+ } else {
+ int keyidxptr;
+ if (ACLCheckAllUserCommandPerm(rctx->original_client->user, cmd, argv, argc, &keyidxptr) != ACL_OK) {
+ lua_pushboolean(lua, 0);
+ } else {
+ lua_pushboolean(lua, 1);
+ }
+ }
+
+ while (argc--) decrRefCount(argv[argc]);
+ zfree(argv);
+ if (raise_error)
+ return luaError(lua);
+ else
+ return 1;
+}
+
+
/* redis.log() */
static int luaLogCommand(lua_State *lua) {
int j, argc = lua_gettop(lua);
@@ -946,16 +1028,16 @@ static int luaLogCommand(lua_State *lua) {
sds log;
if (argc < 2) {
- lua_pushstring(lua, "redis.log() requires two arguments or more.");
- return lua_error(lua);
+ luaPushError(lua, "redis.log() requires two arguments or more.");
+ return luaError(lua);
} else if (!lua_isnumber(lua,-argc)) {
- lua_pushstring(lua, "First argument must be a number (log level).");
- return lua_error(lua);
+ luaPushError(lua, "First argument must be a number (log level).");
+ return luaError(lua);
}
level = lua_tonumber(lua,-argc);
if (level < LL_DEBUG || level > LL_WARNING) {
- lua_pushstring(lua, "Invalid debug level.");
- return lua_error(lua);
+ luaPushError(lua, "Invalid debug level.");
+ return luaError(lua);
}
if (level < server.verbosity) return 0;
@@ -980,20 +1062,20 @@ static int luaLogCommand(lua_State *lua) {
static int luaSetResp(lua_State *lua) {
scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME);
if (!rctx) {
- lua_pushstring(lua, "redis.setresp can only be called inside a script invocation");
- return lua_error(lua);
+ luaPushError(lua, "redis.setresp can only be called inside a script invocation");
+ return luaError(lua);
}
int argc = lua_gettop(lua);
if (argc != 1) {
- lua_pushstring(lua, "redis.setresp() requires one argument.");
- return lua_error(lua);
+ luaPushError(lua, "redis.setresp() requires one argument.");
+ return luaError(lua);
}
int resp = lua_tonumber(lua,-argc);
if (resp != 2 && resp != 3) {
- lua_pushstring(lua, "RESP version must be 2 or 3.");
- return lua_error(lua);
+ luaPushError(lua, "RESP version must be 2 or 3.");
+ return luaError(lua);
}
scriptSetResp(rctx, resp);
return 0;
@@ -1193,6 +1275,9 @@ void luaRegisterRedisAPI(lua_State* lua) {
luaLoadLibraries(lua);
luaRemoveUnsupportedFunctions(lua);
+ lua_pushcfunction(lua,luaRedisPcall);
+ lua_setglobal(lua, "pcall");
+
/* Register the redis commands table and fields */
lua_newtable(lua);
@@ -1251,8 +1336,13 @@ void luaRegisterRedisAPI(lua_State* lua) {
lua_pushstring(lua,"REPL_ALL");
lua_pushnumber(lua,PROPAGATE_AOF|PROPAGATE_REPL);
+ lua_settable(lua,-3);
+ /* redis.acl_check_cmd */
+ lua_pushstring(lua,"acl_check_cmd");
+ lua_pushcfunction(lua,luaRedisAclCheckCmdPermissionsCommand);
lua_settable(lua,-3);
+
/* Finally set the table as 'redis' global var. */
lua_setglobal(lua,REDIS_API_NAME);
@@ -1348,11 +1438,50 @@ static void luaMaskCountHook(lua_State *lua, lua_Debug *ar) {
*/
lua_sethook(lua, luaMaskCountHook, LUA_MASKLINE, 0);
- lua_pushstring(lua,"Script killed by user with SCRIPT KILL...");
- lua_error(lua);
+ luaPushError(lua,"Script killed by user with SCRIPT KILL...");
+ luaError(lua);
}
}
+void luaErrorInformationDiscard(errorInfo *err_info) {
+ if (err_info->msg) sdsfree(err_info->msg);
+ if (err_info->source) sdsfree(err_info->source);
+ if (err_info->line) sdsfree(err_info->line);
+}
+
+void luaExtractErrorInformation(lua_State *lua, errorInfo *err_info) {
+ if (lua_isstring(lua, -1)) {
+ err_info->msg = sdscatfmt(sdsempty(), "ERR %s", lua_tostring(lua, -1));
+ err_info->line = NULL;
+ err_info->source = NULL;
+ err_info->ignore_err_stats_update = 0;
+ }
+
+ lua_getfield(lua, -1, "err");
+ if (lua_isstring(lua, -1)) {
+ err_info->msg = sdsnew(lua_tostring(lua, -1));
+ }
+ lua_pop(lua, 1);
+
+ lua_getfield(lua, -1, "source");
+ if (lua_isstring(lua, -1)) {
+ err_info->source = sdsnew(lua_tostring(lua, -1));
+ }
+ lua_pop(lua, 1);
+
+ lua_getfield(lua, -1, "line");
+ if (lua_isstring(lua, -1)) {
+ err_info->line = sdsnew(lua_tostring(lua, -1));
+ }
+ lua_pop(lua, 1);
+
+ lua_getfield(lua, -1, "ignore_error_stats_update");
+ if (lua_isboolean(lua, -1)) {
+ err_info->ignore_err_stats_update = lua_toboolean(lua, -1);
+ }
+ lua_pop(lua, 1);
+}
+
void luaCallFunction(scriptRunCtx* run_ctx, lua_State *lua, robj** keys, size_t nkeys, robj** args, size_t nargs, int debug_enabled) {
client* c = run_ctx->original_client;
int delhook = 0;
@@ -1410,9 +1539,28 @@ void luaCallFunction(scriptRunCtx* run_ctx, lua_State *lua, robj** keys, size_t
}
if (err) {
- addReplyErrorFormat(c,"Error running script (call to %s): %s\n",
- run_ctx->funcname, lua_tostring(lua,-1));
- lua_pop(lua,1); /* Consume the Lua reply and remove error handler. */
+ /* Error object is a table of the following format:
+ * {err='<error msg>', source='<source file>', line=<line>}
+ * We can construct the error message from this information */
+ if (!lua_istable(lua, -1)) {
+ /* Should not happened, and we should considered assert it */
+ addReplyErrorFormat(c,"Error running script (call to %s)\n", run_ctx->funcname);
+ } else {
+ errorInfo err_info = {0};
+ sds final_msg = sdsempty();
+ luaExtractErrorInformation(lua, &err_info);
+ final_msg = sdscatfmt(final_msg, "-%s",
+ err_info.msg);
+ if (err_info.line && err_info.source) {
+ final_msg = sdscatfmt(final_msg, " script: %s, on %s:%s.",
+ run_ctx->funcname,
+ err_info.source,
+ err_info.line);
+ }
+ addReplyErrorSdsEx(c, final_msg, err_info.ignore_err_stats_update? ERR_REPLY_FLAG_NO_STATS_UPDATE : 0);
+ luaErrorInformationDiscard(&err_info);
+ }
+ lua_pop(lua,1); /* Consume the Lua error */
} else {
/* On success convert the Lua return value into Redis protocol, and
* send it to * the client. */