diff --git a/src/config.lua b/src/config.lua index 9bb3e6e..58f9d7d 100644 --- a/src/config.lua +++ b/src/config.lua @@ -5,7 +5,7 @@ local M = {} M.DEBUG_PCALLS = true --REST responses will contain 'module' and 'function' keys describing what was requested -M.API_INCLUDE_ENDPOINT_INFO = true +M.API_INCLUDE_ENDPOINT_INFO = false M.DEFAULT_AP_SSID = "d3d-ap-%MAC_ADDR_TAIL%" M.DEFAULT_AP_ADDRESS = "192.168.10.1" diff --git a/src/main.lua b/src/main.lua index cac9897..bbe7d39 100644 --- a/src/main.lua +++ b/src/main.lua @@ -1,3 +1,5 @@ +package.path = package.path .. ';/usr/share/lua/wifibox/?.lua' + local l = require("logger") local RequestClass = require("rest.request") local ResponseClass = require("rest.response") @@ -38,7 +40,8 @@ end local function main() local rq = RequestClass.new(postData, config.DEBUG_PCALLS) - l:info("received request of type " .. rq:getRequestMethod() .. " with arguments: " .. l:dump(rq:getAll())) + l:info("received request of type " .. rq:getRequestMethod() .. " for " .. (rq:getRequestedApiModule() or "") + .. "/" .. (rq:getRealApiFunctionName() or "") .. " with arguments: " .. l:dump(rq:getAll())) if rq:getRequestMethod() ~= "CMDLINE" then l:info("remote IP/port: " .. rq:getRemoteHost() .. "/" .. rq:getRemotePort()) l:debug("user agent: " .. rq:getUserAgent()) @@ -52,7 +55,6 @@ end end else - io.write ("Content-type: text/plain\r\n\r\n") local response, err = rq:handle() if err ~= nil then l:error(err) end diff --git a/src/rest/api/api_test.lua b/src/rest/api/api_test.lua index 57844d8..13712f4 100644 --- a/src/rest/api/api_test.lua +++ b/src/rest/api/api_test.lua @@ -5,6 +5,15 @@ local M = {} M.isApi = true +--empty or nil is equivalent to 'ANY', otherwise restrict to specified letters (command-line is always allowed) +M._access = { + _global = "GET", + success = "GET", fail = "GET", error = "GET", + read = "GET", write = "POST", readwrite = "ANY", readwrite2 = "", + echo = "GET" +} + + function M._global(request, response) local ba = request:getBlankArgument() @@ -27,6 +36,13 @@ function M.error(request, response) response:addData("url", "http://xkcd.com/1024/") end + +function M.read(request, response) response:setSuccess("this endpoint can only be accessed through GET request") end +function M.write(request, response) response:setSuccess("this endpoint can only be accessed through POST request") end +function M.readwrite(request, response) response:setSuccess("this endpoint can only be accessed through POST request") end +function M.readwrite2(request, response) response:setSuccess("this endpoint can only be accessed through POST request") end + + function M.echo(request, response) response:setSuccess("request echo") response:addData("request_data", request:getAll()) diff --git a/src/rest/request.lua b/src/rest/request.lua index b988757..8b64768 100644 --- a/src/rest/request.lua +++ b/src/rest/request.lua @@ -17,6 +17,7 @@ M.requestedApiFunction = nil M.resolvedApiFunction = nil --will contain function address, or nil M.realApiFunctionName = nil --will contain requested name, or global name, or nil M.resolutionError = nil --non-nil means function could not be resolved +M.moduleAccessTable = nil local function kvTableFromUrlEncodedString(encodedText) @@ -31,6 +32,8 @@ end local function kvTableFromArray(argArray) local args = {} + if not argArray then return args end + for _, v in ipairs(argArray) do local split = v:find("=") if split ~= nil then @@ -45,8 +48,12 @@ end --NOTE: this function ignores empty tokens (e.g. '/a//b/' yields { [1] = a, [2] = b }) local function arrayFromPath(pathText) - return pathText and pathText:split("/") or {} --FIXME: nothing returned? regardless of which sep is used - --return pathText:split("/") + return pathText and pathText:split("/") or {} +end + +--returns true if acceptable is nil or empty or 'ANY' or if it contains requested +local function matchRequestMethod(acceptable, requested) + return acceptable == nil or acceptable == '' or acceptable == 'ANY' or string.find(acceptable, requested) end @@ -69,8 +76,11 @@ local function resolveApiModule(modname) return modObj end ---returns funcobj+nil (usual), funcobj+number (global func with blank arg), or nil+errmsg (unresolvable or inaccessible) +--returns resultData+nil (usual), or nil+errmsg (unresolvable or inaccessible) +--resultData contains 'func', 'accessTable' and if found, also 'blankArg' local function resolveApiFunction(modname, funcname) + local resultData = {} + if funcname and string.find(funcname, "_") == 1 then return nil, "function names starting with '_' are preserved for internal use" end local mod, msg = resolveApiModule(modname) @@ -84,12 +94,17 @@ local function resolveApiFunction(modname, funcname) local funcNumber = tonumber(funcname) if (type(f) == "function") then - return f + resultData.func = f + resultData.accessTable = mod._access elseif funcNumber ~= nil then - return mod[GLOBAL_API_FUNCTION_NAME], funcNumber + resultData.func = mod[GLOBAL_API_FUNCTION_NAME] + resultData.accessTable = mod._access + resultData.blankArg = funcNumber else return nil, ("function '" .. funcname .. "' does not exist in API module '" .. modname .. "'") end + + return resultData end @@ -123,27 +138,24 @@ function M.new(postData, debug) if debug and self.requestMethod == "CMDLINE" then self.pathArgs = arrayFromPath(self.cmdLineArgs["p"]) end + table.remove(self.pathArgs, 1) --drop the first 'empty' field caused by the opening slash of the query string if #self.pathArgs >= 1 then self.requestedApiModule = self.pathArgs[1] end if #self.pathArgs >= 2 then self.requestedApiFunction = self.pathArgs[2] end --- if debug then --- self.requestedApiModule = self.cmdLineArgs["m"] or self.requestedApiModule --- self.requestedApiFunction = self.cmdLineArgs["f"] or self.requestedApiFunction --- end - if self.requestedApiModule == "" then self.requestedApiModule = nil end if self.requestedApiFunction == "" then self.requestedApiFunction = nil end -- Perform module/function resolution - local sfunc, sres = resolveApiFunction(self:getRequestedApiModule(), self:getRequestedApiFunction()) + local rData, errMsg = resolveApiFunction(self:getRequestedApiModule(), self:getRequestedApiFunction()) - if sfunc ~= nil then --function (possibly the global one) could be resolved - self.resolvedApiFunction = sfunc - if sres ~= nil then --apparently it was the global one, and we received a 'blank argument' - self:setBlankArgument(sres) + if rData ~= nil and rData.func ~= nil then --function (possibly the global one) could be resolved + self.resolvedApiFunction = rData.func + self.moduleAccessTable = rData.accessTable + if rData.blankArg ~= nil then --apparently it was the global one, and we received a 'blank argument' + self:setBlankArgument(rData.blankArg) self.realApiFunctionName = GLOBAL_API_FUNCTION_NAME else --resolved without blank argument but still potentially the global function, hence the _or_ construction if self:getRequestedApiFunction() ~= nil then @@ -155,38 +167,19 @@ function M.new(postData, debug) end else --instead of throwing an error, save the message for handle() which is expected to return a response anyway - self.resolutionError = sres + self.resolutionError = errMsg end return self end ---returns either GET or POST or CMDLINE -function M:getRequestMethod() - return self.requestMethod -end - -function M:getRequestedApiModule() - return self.requestedApiModule -end - -function M:getRequestedApiFunction() - return self.requestedApiFunction -end - -function M:getRealApiFunctionName() - return self.realApiFunctionName -end - -function M:getBlankArgument() - return self.blankArgument -end - -function M:setBlankArgument(arg) - self.blankArgument = arg -end - +function M:getRequestMethod() return self.requestMethod end --returns either GET or POST or CMDLINE +function M:getRequestedApiModule() return self.requestedApiModule end +function M:getRequestedApiFunction() return self.requestedApiFunction end +function M:getRealApiFunctionName() return self.realApiFunctionName end +function M:getBlankArgument() return self.blankArgument end +function M:setBlankArgument(arg) self.blankArgument = arg end function M:getRemoteHost() return self.remoteHost or "" end function M:getRemotePort() return self.remotePort or 0 end function M:getUserAgent() return self.userAgent or "" end @@ -219,13 +212,19 @@ function M:getPathData() return self.pathArgs end - --returns either a response object+nil, or response object+errmsg function M:handle() local modname = self:getRequestedApiModule() local resp = ResponseClass.new(self) if (self.resolvedApiFunction ~= nil) then --we found a function (possible the global function) + --check access type + local accessText = self.moduleAccessTable[self.realApiFunctionName] + if not matchRequestMethod(accessText, self.requestMethod) then + resp:setError("function '" .. modname .. "/" .. self.realApiFunctionName .. "' requires different request method ('" .. accessText .. "')") + return resp, "incorrect access method (" .. accessText .. " != " .. self.requestMethod .. ")" + end + --invoke the function local ok, r if config.DEBUG_PCALLS then ok, r = true, self.resolvedApiFunction(self, resp) diff --git a/src/rest/response.lua b/src/rest/response.lua index 789e1fa..7824462 100644 --- a/src/rest/response.lua +++ b/src/rest/response.lua @@ -1,4 +1,4 @@ -local JSON = (loadfile "util/JSON.lua")() +local JSON = require("util/JSON") local config = require("config") local M = {} @@ -6,6 +6,9 @@ M.__index = M local REQUEST_ID_ARGUMENT = "rq_id" +M.httpStatusCode, M.httpStatusText = nil, nil + + setmetatable(M, { __call = function(cls, ...) return cls.new(...) @@ -16,7 +19,8 @@ setmetatable(M, { function M.new(requestObject) local self = setmetatable({}, M) - self.body = {status = nil, data = {}} + self.body = { status = nil, data = {} } + self:setHttpStatus(200, "OK") if requestObject ~= nil then local rqId = requestObject:get(REQUEST_ID_ARGUMENT) @@ -31,6 +35,11 @@ function M.new(requestObject) return self end +function M:setHttpStatus(code, text) + if code ~= nil then self.httpStatusCode = code end + if text ~= nil then self.httpStatusText = text end +end + function M:setSuccess(msg) self.body.status = "success" if msg ~= "" then self.body.msg = msg end @@ -44,6 +53,8 @@ end function M:setError(msg) self.body.status = "error" if msg ~= "" then self.body.msg = msg end + + self:addData("more_info", "http://doodle3d.nl/wiki/wiki/communication-api") end --NOTE: with this method, to add nested data, it is necessary to precreate the table and add it with its root key @@ -57,6 +68,8 @@ function M:serializeAsJson() end function M:send() + io.write("Status: " .. self.httpStatusCode .. " " .. self.httpStatusText .. "\r\n") + io.write ("Content-type: text/plain\r\n\r\n") print(self:serializeAsJson()) end diff --git a/src/util/utils.lua b/src/util/utils.lua index 6ddb119..8d12215 100644 --- a/src/util/utils.lua +++ b/src/util/utils.lua @@ -2,11 +2,14 @@ local uci = require("uci").cursor() local M = {} -function string:split(sep) - local sep, fields = sep or ":", {} - local pattern = string.format("([^%s]+)", sep) - self:gsub(pattern, function(c) fields[#fields+1] = c end) - return fields +function string:split(div) + local div, pos, arr = div or ":", 0, {} + for st,sp in function() return self:find(div, pos, true) end do + table.insert(arr, self:sub(pos, st - 1)) + pos = sp + 1 + end + table.insert(arr, self:sub(pos)) + return arr end function M.toboolean(s)