Securing APIs with JSON Web Tokens (JWT)

Adding Custom Authorizers in Lambda functions

In this tutorial, we are going to protect our APIs from unauthorized access by creating a Lambda authorizer (formerly known as a custom authorizer). It is an API Gateway feature that uses a Lambda function to control access to your API.
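To make the contract concrete, here is a minimal sketch of the response that API Gateway expects back from a token-based Lambda authorizer. The helper name build_authorizer_response, the sample event, and the ARN are illustrative assumptions, not part of the project code.

```python
def build_authorizer_response(principal_id, effect, resource):
    """Return the response shape API Gateway expects from a Lambda authorizer."""
    return {
        "principalId": principal_id,
        "policyDocument": {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Action": "execute-api:Invoke",
                    "Effect": effect,      # "Allow" or "Deny"
                    "Resource": resource,  # ARN of the method being invoked
                }
            ],
        },
    }


# Hypothetical TOKEN authorizer event; the token and ARN are placeholders.
sample_event = {
    "type": "TOKEN",
    "authorizationToken": "placeholder-token",
    "methodArn": "arn:aws:execute-api:us-east-1:123456789012:abcdef1234/dev/GET/users",
}

response = build_authorizer_response("user-1", "Deny", sample_event["methodArn"])
```

The blueprint used later in this tutorial builds the same structure, only with more flexibility around verbs, resources, and conditions.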

Want to know more about custom authorizers? See "Configure a Lambda authorizer using the API Gateway console" in the API Gateway documentation.

Custom Authorizers

Okay, let me first remove the "Authorizer" entry that I declared in the template.yaml file. If you have been following along from the beginning, you may have noticed that we set Authorizer to "None" without ever setting up a default authorizer. If you try to deploy the functions in that state, the deployment fails with an error because no authorizer has been set up. So, let’s begin.

auth_step_1

Let me start by creating an authorizer package under "user". I will follow the usual process, starting with an app.py file.

auth_step_2

I will be using an existing blueprint for the authorization process. It is already available in AWS Lambda under "Blueprints". I am going to copy this logic into my app.py file.

auth_step_3

Below, I am providing the final snippet of the custom authorizer blueprint.

The source code is available on GitHub at https://github.com/mukulmantosh/ServerlessDemo

import os
import re
import jwt


def auth_token_decode(auth_token):
    """
    Checks whether JWT Token is valid or not.
    If valid returns True else False
    """
    try:
        jwt.decode(auth_token, os.environ['SECRET_KEY'], algorithms=["HS256"])
        return True
    except jwt.ExpiredSignatureError:
        return False
    except jwt.InvalidSignatureError:
        return False
    except jwt.InvalidTokenError:
        return False


def lambda_handler(event, context):
    token = event['authorizationToken']  # retrieve the Auth token

    principal_id = 'abc123'  # fake

    policy = create_policy(event['methodArn'], principal_id)

    if event['authorizationToken']:
        user_info = auth_token_decode(token)
        if user_info:
            policy.allowAllMethods()
        else:
            policy.denyAllMethods()
    else:
        policy.denyAllMethods()

    return policy.build()


def create_policy(method_arn, principal_id):
    tmp = method_arn.split(':')
    region = tmp[3]
    account_id = tmp[4]
    api_id, stage = tmp[5].split('/')[:2]

    policy = AuthPolicy(principal_id, account_id)
    policy.restApiId = api_id
    policy.region = region
    policy.stage = stage

    return policy


class HttpVerb:
    GET = 'GET'
    POST = 'POST'
    PUT = 'PUT'
    PATCH = 'PATCH'
    HEAD = 'HEAD'
    DELETE = 'DELETE'
    OPTIONS = 'OPTIONS'
    ALL = '*'


class AuthPolicy(object):
    # The AWS account id the policy will be generated for. This is used to create the method ARNs.
    awsAccountId = ''
    # The principal used for the policy, this should be a unique identifier for the end user.
    principalId = ''
    # The policy version used for the evaluation. This should always be '2012-10-17'
    version = '2012-10-17'
    # The regular expression used to validate resource paths for the policy
    pathRegex = r'^[/.a-zA-Z0-9-\*]+$'

    '''Internal lists of allowed and denied methods.

    These are lists of objects and each object has 2 properties: A resource
    ARN and a nullable conditions statement. The build method processes these
    lists and generates the appropriate statements for the final policy.
    '''
    allowMethods = []
    denyMethods = []

    # The API Gateway API id. By default this is set to '*'
    restApiId = '*'
    # The region where the API is deployed. By default this is set to '*'
    region = '*'
    # The name of the stage used in the policy. By default this is set to '*'
    stage = '*'

    def __init__(self, principal, awsAccountId):
        self.awsAccountId = awsAccountId
        self.principalId = principal
        self.allowMethods = []
        self.denyMethods = []

    def _addMethod(self, effect, verb, resource, conditions):
        '''Adds a method to the internal lists of allowed or denied methods. Each object in
        the internal list contains a resource ARN and a condition statement. The condition
        statement can be null.'''
        if verb != '*' and not hasattr(HttpVerb, verb):
            raise NameError('Invalid HTTP verb ' + verb + '. Allowed verbs in HttpVerb class')
        resourcePattern = re.compile(self.pathRegex)
        if not resourcePattern.match(resource):
            raise NameError('Invalid resource path: ' + resource + '. Path should match ' + self.pathRegex)

        if resource[:1] == '/':
            resource = resource[1:]

        resourceArn = 'arn:aws:execute-api:{}:{}:{}/{}/{}/{}'.format(self.region, self.awsAccountId, self.restApiId,
                                                                     self.stage, verb, resource)

        if effect.lower() == 'allow':
            self.allowMethods.append({
                'resourceArn': resourceArn,
                'conditions': conditions
            })
        elif effect.lower() == 'deny':
            self.denyMethods.append({
                'resourceArn': resourceArn,
                'conditions': conditions
            })

    def _getEmptyStatement(self, effect):
        '''Returns an empty statement object prepopulated with the correct action and the
        desired effect.'''
        statement = {
            'Action': 'execute-api:Invoke',
            'Effect': effect[:1].upper() + effect[1:].lower(),
            'Resource': []
        }

        return statement

    def _getStatementForEffect(self, effect, methods):
        '''This function loops over an array of objects containing a resourceArn and
        conditions statement and generates the array of statements for the policy.'''
        statements = []

        if len(methods) > 0:
            statement = self._getEmptyStatement(effect)

            for curMethod in methods:
                if curMethod['conditions'] is None or len(curMethod['conditions']) == 0:
                    statement['Resource'].append(curMethod['resourceArn'])
                else:
                    conditionalStatement = self._getEmptyStatement(effect)
                    conditionalStatement['Resource'].append(curMethod['resourceArn'])
                    conditionalStatement['Condition'] = curMethod['conditions']
                    statements.append(conditionalStatement)

            if statement['Resource']:
                statements.append(statement)

        return statements

    def allowAllMethods(self):
        '''Adds a '*' allow to the policy to authorize access to all methods of an API'''
        self._addMethod('Allow', HttpVerb.ALL, '*', [])

    def denyAllMethods(self):
        '''Adds a '*' deny to the policy to deny access to all methods of an API'''
        self._addMethod('Deny', HttpVerb.ALL, '*', [])

    def allowMethod(self, verb, resource):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of allowed
        methods for the policy'''
        self._addMethod('Allow', verb, resource, [])

    def denyMethod(self, verb, resource):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of denied
        methods for the policy'''
        self._addMethod('Deny', verb, resource, [])

    def allowMethodWithConditions(self, verb, resource, conditions):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of allowed
        methods and includes a condition for the policy statement. More on AWS policy
        conditions here: http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition'''
        self._addMethod('Allow', verb, resource, conditions)

    def denyMethodWithConditions(self, verb, resource, conditions):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of denied
        methods and includes a condition for the policy statement. More on AWS policy
        conditions here: http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition'''
        self._addMethod('Deny', verb, resource, conditions)

    def build(self):
        '''Generates the policy document based on the internal lists of allowed and denied
        conditions. This will generate a policy with two main statements for the effect:
        one statement for Allow and one statement for Deny.
        Methods that include conditions will have their own statement in the policy.'''
        if ((self.allowMethods is None or len(self.allowMethods) == 0) and
                (self.denyMethods is None or len(self.denyMethods) == 0)):
            raise NameError('No statements defined for the policy')

        policy = {
            'principalId': self.principalId,
            'policyDocument': {
                'Version': self.version,
                'Statement': []
            }
        }

        policy['policyDocument']['Statement'].extend(self._getStatementForEffect('Allow', self.allowMethods))
        policy['policyDocument']['Statement'].extend(self._getStatementForEffect('Deny', self.denyMethods))

        return policy

If you read the logic carefully, you will notice that I have slightly refactored the blueprint code to fit my requirements.

In "lambda_handler", I retrieve the authorization token from event['authorizationToken'], which API Gateway fills from the request headers. I then call "auth_token_decode", which checks whether the token is valid. If it is, the policy allows all HTTP methods; otherwise, all methods are denied.

auth_step_4
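As a rough illustration of what the blueprint does with the method ARN, the sketch below splits the ARN the same way create_policy does and builds the wildcard resource that allowAllMethods and denyAllMethods end up with. The function name wildcard_resource_arn and the sample ARN are made up for this example.

```python
def wildcard_resource_arn(method_arn):
    """Derive the 'all methods, all paths' resource ARN from an incoming method ARN."""
    # arn:aws:execute-api:<region>:<account-id>:<api-id>/<stage>/<verb>/<path>
    parts = method_arn.split(":")
    region = parts[3]
    account_id = parts[4]
    api_id, stage = parts[5].split("/")[:2]

    # Verb and path are both replaced by '*', mirroring _addMethod(effect, '*', '*', []).
    return "arn:aws:execute-api:{}:{}:{}/{}/*/*".format(region, account_id, api_id, stage)


# Hypothetical method ARN, as API Gateway would pass it in event['methodArn'].
sample_arn = "arn:aws:execute-api:us-east-1:123456789012:abcdef1234/dev/GET/users"
resource = wildcard_resource_arn(sample_arn)
```

Because the resulting policy covers every method in the stage, a single valid token grants access to all protected APIs, which matches the allow-all or deny-all behaviour described above.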

I won’t be going into too much detail as the blueprint code is already available in AWS Lambda. For reference, visit this link.

Let’s open template.yaml, where I am going to define the authorizer under "Resources". I will name the resource "MyApi".

auth_step_5

Under "Properties", I am going to define the "StageName". It names an API stage; you can use any text, but common choices are dev, prod, stage or test. The stage name appears in the API Gateway URI (Uniform Resource Identifier).

Under "Auth", I will give my default authorizer a name: "JWTCustomAuthorizer".

"FunctionArn" points to the Lambda function that handles the authorization process. ARN stands for Amazon Resource Name, a naming convention that uniquely identifies an AWS resource.

As you can see in the image below, "JWTAuthFunction" processes the authorization once we receive a token. As usual, we provide the handler, runtime, and so on. We also use an environment variable called "SECRET_KEY" for encoding and decoding our JWTs.

auth_step_6

Okay, we have defined the custom authorizer. Let us now reference it from each function so that each API only works when a token is provided.

Below is the final template.yaml.

{ % include "./demos/template.yaml" % }

As you can see on my screen, "CreateUserAPI" refers to the custom authorizer through "RestApiId". Until now, I had set the Authorizer to None, which is completely valid. It lets me exclude APIs that should not require an authorizer, such as the LoginAPI.

auth_step_7

Now that I am done with the template.yaml file, let me go to the codebase and create a package for the Login API.

I will follow the same process that I used for the previous APIs.

auth_step_8

I am going to create a token.py file with two functions: one creates a new JWT, and the other refreshes a token, issuing a new one as long as the previous token has not yet expired.

JWT stands for JSON Web Token. It is a self-contained way for securely transmitting information between parties as a JSON object. This information can be verified and trusted because it is digitally signed. JWTs can be signed using a secret or a public/private key pair. Again, JWT is a standard, meaning that all JWTs are tokens, but not all tokens are JWTs.

If you want to get more information about JWT then visit jwt.io.
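To see what "digitally signed" means in practice, here is a simplified sketch of HS256 signing and verification that uses only the Python standard library. It is for illustration only: the function names are my own, it ignores claims such as exp, and the project code uses PyJWT rather than anything like this.

```python
import base64
import hashlib
import hmac
import json


def b64url(data):
    """Base64url-encode bytes without padding, as JWT segments require."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def sign_hs256(payload, secret):
    """Build header.payload.signature, signing with HMAC-SHA256."""
    header = {"alg": "HS256", "typ": "JWT"}
    segments = [
        b64url(json.dumps(header, separators=(",", ":")).encode("utf-8")),
        b64url(json.dumps(payload, separators=(",", ":")).encode("utf-8")),
    ]
    signing_input = ".".join(segments).encode("utf-8")
    signature = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest()
    return ".".join(segments + [b64url(signature)])


def verify_hs256(token, secret):
    """Recompute the signature and compare it in constant time."""
    try:
        header_b64, payload_b64, signature_b64 = token.split(".")
    except ValueError:
        return False  # not three dot-separated segments
    signing_input = (header_b64 + "." + payload_b64).encode("utf-8")
    expected = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest()
    return hmac.compare_digest(b64url(expected).encode("utf-8"), signature_b64.encode("utf-8"))


# Placeholder payload and secret, used only for this demonstration.
demo_token = sign_hs256({"id": "42"}, "demo-secret")
is_valid = verify_hs256(demo_token, "demo-secret")
```

Changing either the payload or the secret makes verify_hs256 return False, which is the property the authorizer relies on.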

Below is the final code for token.py.

import datetime
import os

import jwt


def create_access_token(result):
    # Returns a new JWT, signed with HS256 and valid for 5 minutes.
    expires_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=300)
    payload = {"id": str(result["_id"]),
               "first_name": result["first_name"],
               "last_name": result["last_name"],
               "exp": expires_at}
    jwt_info = jwt.encode(payload, os.environ['SECRET_KEY'], algorithm="HS256")
    return jwt_info


def refresh_token(token):
    # Refresh Token if the token hasn't expired.
    try:
        result = jwt.decode(token, os.environ['SECRET_KEY'], algorithms=["HS256"])
        # Re-issue the same claims with a fresh 5-minute expiry.
        new_exp = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=300)
        jwt_info = jwt.encode({**result, "exp": new_exp}, os.environ['SECRET_KEY'], algorithm="HS256")

        return {"status": True, "data": jwt_info, "message": None}
    except jwt.ExpiredSignatureError:
        return {"status": False, "data": None, "message": "Token has expired !"}
    except jwt.InvalidTokenError:
        # Covers malformed tokens, bad signatures and other invalid claims.
        return {"status": False, "data": None, "message": "Unable to decode data !"}

As you can see, "create_access_token" takes the user information, encodes it, and returns a long signed string. Note that the token is signed, not encrypted: anyone holding it can read the payload, but cannot alter it without the secret key. The token is valid for only 5 minutes.

OK, we are done with the tokens. Let’s move to the validator.py file.

I am going to create "UserLoginSchema", which takes email and password as required inputs. In the validation function, I check whether the provided email exists in the database. For the password, I verify it against the hashed password stored in the database.

If the password verification succeeds, I generate a new token and send it back in the response.

auth_step_9

Next, I am going to create a "RefreshTokenSchema", which takes the token as a required input. If the existing token is valid, it returns a new token in the response; otherwise, it raises a validation error.
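The refresh rule comes down to simple time arithmetic on the exp claim. The sketch below assumes the 300-second lifetime used in token.py; the helper names are hypothetical, and in the real code PyJWT performs the expiry check when it decodes the token.

```python
import datetime

TOKEN_LIFETIME = datetime.timedelta(seconds=300)  # same 5-minute lifetime as token.py


def expiry_timestamp(issued_at):
    """Return the exp claim (seconds since the epoch) for a token issued at issued_at."""
    return int((issued_at + TOKEN_LIFETIME).timestamp())


def can_refresh(exp, now):
    """A token can be refreshed only while it has not yet expired."""
    return now.timestamp() < exp


# Fixed example time so the arithmetic is easy to follow.
issued = datetime.datetime(2024, 1, 1, 12, 0, tzinfo=datetime.timezone.utc)
exp = expiry_timestamp(issued)

refresh_after_4_min = can_refresh(exp, issued + datetime.timedelta(minutes=4))
refresh_after_6_min = can_refresh(exp, issued + datetime.timedelta(minutes=6))
```

A token issued at 12:00 can be exchanged for a new one at 12:04, but at 12:06 it has expired and the user has to log in again.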

Below is the final code for validator.py.

import argon2
from argon2 import PasswordHasher
from marshmallow import Schema, fields, ValidationError, post_load

from . import db
from . import token


class UserLoginSchema(Schema):
    email = fields.Email(required=True)
    password = fields.Str(required=True)

    @post_load
    def validate_email_password(self, data, **kwargs):
        mongo = db.MongoDBConnection()
        with mongo:
            database = mongo.connection['myDB']
            collection = database['registrations']
            result = collection.find_one({"email": data["email"]})
            if result is None:
                raise ValidationError('Sorry! You have provided invalid email.')
            else:
                ph = PasswordHasher()
                try:
                    ph.verify(result['password'], data['password'])
                    data['token'] = token.create_access_token(result)
                except argon2.exceptions.VerifyMismatchError:
                    raise ValidationError('The password is invalid.')

        return data


class RefreshTokenSchema(Schema):
    token = fields.Str(required=True)

    @post_load
    def validate_token(self, data, **kwargs):
        refresh_token = token.refresh_token(data['token'])
        if refresh_token['status']:
            data['token'] = refresh_token['data']
        else:
            raise ValidationError(refresh_token['message'])

        return data

Let’s move to app.py, where I am going to define the Lambda handler. As before, I parse the event body and pass it to my schema for validation and post-processing.

If the validation succeeds, the handler returns the token with a 200 HTTP response; otherwise, it returns an error with a 400 status code.

auth_step_10
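The request and response handling described above can be sketched with two small helpers. They use the standard json module instead of ujson, and the names make_response and parse_body are assumptions for this example rather than functions from the project. The sketch also treats a missing or malformed body as a client error, a case that would otherwise likely surface as an unhandled exception.

```python
import json


def make_response(status_code, payload):
    """Wrap a payload in the statusCode/body shape that the handlers return."""
    return {"statusCode": status_code, "body": json.dumps(payload)}


def parse_body(event):
    """Return the decoded JSON body, or None if it is missing or malformed."""
    raw = event.get("body")
    if not raw:
        return None
    try:
        return json.loads(raw)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a subclass of ValueError.
        return None


# Hypothetical events, shaped like the ones API Gateway sends to the handler.
good = parse_body({"body": '{"email": "user@example.com"}'})
bad = parse_body({"body": "{not json"})
error_response = make_response(400, {"message": "Unable to parse data !"})
```

A handler built on these helpers could return make_response(400, ...) as soon as parse_body gives back None, before any schema validation runs.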

I will also define one more function, "token_refresh", in the same file. It performs the same kind of operation: it takes an existing token as input and returns a new one.

Below is the final code for app.py.

import ujson
from marshmallow import ValidationError

from .utils import validator


def lambda_handler(event, context):
    try:
        body = ujson.loads(event['body'])
        result = validator.UserLoginSchema()
        res = not bool(result.validate(body))

        if res:
            return {
                "statusCode": 200,
                "body": ujson.dumps({
                    "message": "Welcome !",
                    "data": {
                        "token": result.load(body)['token']
                    }
                })
            }
        else:
            return {
                "statusCode": 400,
                "body": ujson.dumps({
                    "message": "Error !",
                    "data": result.validate(body)
                })
            }
    except ValidationError as err:
        return {
            "statusCode": 400,
            "body": ujson.dumps({
                "message": err.messages
            })
        }
    except KeyError as error:
        return {
            "statusCode": 400,
            "body": ujson.dumps({
                "message": "Something went wrong. Unable to parse data ! " + str(error)
            })
        }


def token_refresh(event, context):
    try:
        body = ujson.loads(event['body'])
        result = validator.RefreshTokenSchema()
        res = not bool(result.validate(body))

        if res:
            return {
                "statusCode": 200,
                "body": ujson.dumps({
                    "message": None,
                    "data": result.load(body)
                })
            }
        else:
            return {
                "statusCode": 400,
                "body": ujson.dumps({
                    "message": "Error !",
                    "data": result.validate(body)
                })
            }

    except ValidationError as err:
        return {
            "statusCode": 400,
            "body": ujson.dumps({
                "message": err.messages
            })
        }

    except KeyError:
        return {
            "statusCode": 400,
            "body": ujson.dumps({"message": "Something went wrong. Unable to parse data !"})
        }

Now, I am going to register the UserLogin & RefreshToken API in the template.yaml file.

auth_step_11

auth_step_12

We have registered the APIs. Before testing them, I need to fix a few issues that I missed earlier. Open app.py and go to line 19: the schema method should be "load", not "loads", so remove the ‘s’.

Next, in token.py, I will remove the "decode" call on line 15. Since PyJWT 2.0, jwt.encode already returns a string, so there is nothing left to decode.

For the refresh token function, on line 21, I need to specify the algorithm "HS256", because PyJWT 2.x requires the algorithms argument when decoding. You can follow the source code on GitHub, which is already up to date.

Let’s test out the functionality. I will click on Run and then Edit Configurations.

You can see I have provided "email" and "password".

auth_step_13

Let me run the function and check out what response I get.

auth_step_14

Okay, I am receiving a new token. Now, I am going to pass this token to the RefreshToken function, which should return a new token in response.

auth_step_15
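For reference, the two test inputs can be sketched in Python as follows. The email, password, and token values are placeholders, and build_refresh_event is a hypothetical helper; only the "body" key and the field names come from the handlers above.

```python
import json

# Hypothetical credentials; replace them with a user that exists in your database.
login_event = {
    "body": json.dumps({"email": "user@example.com", "password": "placeholder-password"})
}


def build_refresh_event(login_response):
    """Take the token out of a login response and wrap it in a refresh event."""
    token = json.loads(login_response["body"])["data"]["token"]
    return {"body": json.dumps({"token": token})}


# Stand-in for what lambda_handler returns on success; the token is a placeholder.
fake_login_response = {
    "statusCode": 200,
    "body": json.dumps({"message": "Welcome !", "data": {"token": "header.payload.signature"}}),
}

refresh_event = build_refresh_event(fake_login_response)
```

Passing login_event to lambda_handler and the resulting refresh event to token_refresh reproduces the two runs shown in the screenshots.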

I am receiving a new token. Both functions are working great!

In the upcoming tutorial, I will write unit tests to make sure our functions work correctly before deployment.