Securing APIs with JSON Web Tokens (JWT)

Adding Custom Authorizers in Lambda functions

For this tutorial, we are going to protect our APIs from unauthorized access by creating a Lambda authorizer (formerly known as a custom authorizer). A Lambda authorizer is an API Gateway feature that uses a Lambda function to control access to your API.

Interested in learning more about Lambda authorizers? Please visit Configure a Lambda authorizer using the API Gateway console.
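For orientation, here is a minimal sketch of the contract a TOKEN-type Lambda authorizer works with. API Gateway passes an event containing `type`, `authorizationToken` and `methodArn`, and expects back a `principalId` plus an IAM policy document that allows or denies `execute-api:Invoke`. The sample ARN, the `user-123` principal and the `token_is_valid` check are hypothetical placeholders, not part of this project.

```python
def token_is_valid(token):
    # Hypothetical placeholder check; the real authorizer below verifies a JWT instead.
    return token == "let-me-in"


def minimal_authorizer(event, context=None):
    """Return an Allow or Deny policy for the method being invoked."""
    effect = "Allow" if token_is_valid(event.get("authorizationToken", "")) else "Deny"
    return {
        "principalId": "user-123",  # hypothetical identifier for the caller
        "policyDocument": {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Action": "execute-api:Invoke",
                    "Effect": effect,
                    "Resource": event["methodArn"],
                }
            ],
        },
    }


# Shape of a TOKEN authorizer event (values are made up).
sample_event = {
    "type": "TOKEN",
    "authorizationToken": "let-me-in",
    "methodArn": "arn:aws:execute-api:us-east-1:123456789012:abc123/Prod/GET/user",
}

print(minimal_authorizer(sample_event)["policyDocument"]["Statement"][0]["Effect"])  # Allow
```

The blueprint used below builds the same structure, just through a reusable AuthPolicy class.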

Custom Authorizers

Okay, let me first remove the Authorizer entries I declared in the template.yaml file. If you have been following from the beginning, you may have noticed that we set Authorizer to NONE but never set up a default authorizer. If you try to deploy the functions like that, it won't work: the deployment fails with an error, because NONE is only valid when the API has a DefaultAuthorizer. So, let's begin.

auth_step_1

Let me start by creating an authorizer package under user. I will follow the usual process of creating an app.py file.

auth_step_2

I will be using an existing blueprint for the Authorization process. It’s already available in AWS Lambda under the Blueprints. I am going to copy this logic into my app.py file.

auth_step_3

Below is the final version of my custom authorizer, adapted from the blueprint.

The source code is available on Github at https://github.com/mukulmantosh/ServerlessDemo

import os
import re

import jwt


def auth_token_decode(auth_token):
    """
    Checks whether the JWT is valid or not.
    Returns True if valid, else False.
    """
    try:
        jwt.decode(auth_token, os.environ['SECRET_KEY'], algorithms=["HS256"])
        return True
    except jwt.ExpiredSignatureError:
        return False
    except jwt.InvalidSignatureError:
        return False
    except jwt.InvalidTokenError:
        return False


def lambda_handler(event, context):
    token = event['authorizationToken']  # retrieve the auth token
    principal_id = 'abc123'  # placeholder; ideally a unique identifier for the user
    policy = create_policy(event['methodArn'], principal_id)
    if token:
        user_info = auth_token_decode(token)
        if user_info:
            policy.allowAllMethods()
        else:
            policy.denyAllMethods()
    else:
        policy.denyAllMethods()
    return policy.build()


def create_policy(method_arn, principal_id):
    # methodArn looks like arn:aws:execute-api:<region>:<account>:<api-id>/<stage>/<verb>/<path>
    tmp = method_arn.split(':')
    region = tmp[3]
    account_id = tmp[4]
    api_id, stage = tmp[5].split('/')[:2]
    policy = AuthPolicy(principal_id, account_id)
    policy.restApiId = api_id
    policy.region = region
    policy.stage = stage
    return policy


class HttpVerb:
    GET = 'GET'
    POST = 'POST'
    PUT = 'PUT'
    PATCH = 'PATCH'
    HEAD = 'HEAD'
    DELETE = 'DELETE'
    OPTIONS = 'OPTIONS'
    ALL = '*'


class AuthPolicy(object):
    # The AWS account id the policy will be generated for. This is used to create the method ARNs.
    awsAccountId = ''
    # The principal used for the policy, this should be a unique identifier for the end user.
    principalId = ''
    # The policy version used for the evaluation. This should always be '2012-10-17'
    version = '2012-10-17'
    # The regular expression used to validate resource paths for the policy
    # (a raw string, so '\*' is not treated as an invalid escape sequence)
    pathRegex = r'^[/.a-zA-Z0-9-\*]+$'

    '''Internal lists of allowed and denied methods.
    These are lists of objects and each object has 2 properties: A resource
    ARN and a nullable conditions statement. The build method processes these
    lists and generates the appropriate statements for the final policy.
    '''
    allowMethods = []
    denyMethods = []

    # The API Gateway API id. By default this is set to '*'
    restApiId = '*'
    # The region where the API is deployed. By default this is set to '*'
    region = '*'
    # The name of the stage used in the policy. By default this is set to '*'
    stage = '*'

    def __init__(self, principal, awsAccountId):
        self.awsAccountId = awsAccountId
        self.principalId = principal
        self.allowMethods = []
        self.denyMethods = []

    def _addMethod(self, effect, verb, resource, conditions):
        '''Adds a method to the internal lists of allowed or denied methods. Each object in
        the internal list contains a resource ARN and a condition statement. The condition
        statement can be null.'''
        if verb != '*' and not hasattr(HttpVerb, verb):
            raise NameError('Invalid HTTP verb ' + verb + '. Allowed verbs in HttpVerb class')
        resourcePattern = re.compile(self.pathRegex)
        if not resourcePattern.match(resource):
            raise NameError('Invalid resource path: ' + resource + '. Path should match ' + self.pathRegex)
        if resource[:1] == '/':
            resource = resource[1:]
        resourceArn = 'arn:aws:execute-api:{}:{}:{}/{}/{}/{}'.format(self.region, self.awsAccountId, self.restApiId,
                                                                     self.stage, verb, resource)
        if effect.lower() == 'allow':
            self.allowMethods.append({
                'resourceArn': resourceArn,
                'conditions': conditions
            })
        elif effect.lower() == 'deny':
            self.denyMethods.append({
                'resourceArn': resourceArn,
                'conditions': conditions
            })

    def _getEmptyStatement(self, effect):
        '''Returns an empty statement object prepopulated with the correct action and the
        desired effect.'''
        statement = {
            'Action': 'execute-api:Invoke',
            'Effect': effect[:1].upper() + effect[1:].lower(),
            'Resource': []
        }
        return statement

    def _getStatementForEffect(self, effect, methods):
        '''This function loops over an array of objects containing a resourceArn and
        conditions statement and generates the array of statements for the policy.'''
        statements = []
        if len(methods) > 0:
            statement = self._getEmptyStatement(effect)
            for curMethod in methods:
                if curMethod['conditions'] is None or len(curMethod['conditions']) == 0:
                    statement['Resource'].append(curMethod['resourceArn'])
                else:
                    conditionalStatement = self._getEmptyStatement(effect)
                    conditionalStatement['Resource'].append(curMethod['resourceArn'])
                    conditionalStatement['Condition'] = curMethod['conditions']
                    statements.append(conditionalStatement)
            if statement['Resource']:
                statements.append(statement)
        return statements

    def allowAllMethods(self):
        '''Adds a '*' allow to the policy to authorize access to all methods of an API'''
        self._addMethod('Allow', HttpVerb.ALL, '*', [])

    def denyAllMethods(self):
        '''Adds a '*' deny to the policy to deny access to all methods of an API'''
        self._addMethod('Deny', HttpVerb.ALL, '*', [])

    def allowMethod(self, verb, resource):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of allowed
        methods for the policy'''
        self._addMethod('Allow', verb, resource, [])

    def denyMethod(self, verb, resource):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of denied
        methods for the policy'''
        self._addMethod('Deny', verb, resource, [])

    def allowMethodWithConditions(self, verb, resource, conditions):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of allowed
        methods and includes a condition for the policy statement. More on AWS policy
        conditions here: http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition'''
        self._addMethod('Allow', verb, resource, conditions)

    def denyMethodWithConditions(self, verb, resource, conditions):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of denied
        methods and includes a condition for the policy statement. More on AWS policy
        conditions here: http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition'''
        self._addMethod('Deny', verb, resource, conditions)

    def build(self):
        '''Generates the policy document based on the internal lists of allowed and denied
        conditions. This will generate a policy with two main statements for the effect:
        one statement for Allow and one statement for Deny.
        Methods that include conditions will have their own statement in the policy.'''
        if ((self.allowMethods is None or len(self.allowMethods) == 0) and
                (self.denyMethods is None or len(self.denyMethods) == 0)):
            raise NameError('No statements defined for the policy')
        policy = {
            'principalId': self.principalId,
            'policyDocument': {
                'Version': self.version,
                'Statement': []
            }
        }
        policy['policyDocument']['Statement'].extend(self._getStatementForEffect('Allow', self.allowMethods))
        policy['policyDocument']['Statement'].extend(self._getStatementForEffect('Deny', self.denyMethods))
        return policy

If you look at the logic carefully, you will notice that I have slightly refactored the blueprint code to suit my requirements.

In lambda_handler, I retrieve the authorization token from event['authorizationToken'], which API Gateway fills in from the request header configured as the token source (by default, SAM uses the Authorization header). I then call auth_token_decode, which checks whether the token is valid. If it is, the policy allows all HTTP methods; otherwise, all methods are denied.
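To make create_policy concrete, the sketch below walks a sample methodArn (the region, account ID and API ID are made up) through the same split and rebuilds the resource ARN that allowAllMethods() ends up granting. Because that ARN wildcards both the verb and the path, a valid token unlocks every method in the stage; one reason blueprints do this is that API Gateway may cache an authorizer's policy and reuse it for other methods of the same API.

```python
def wildcard_resource_arn(method_arn):
    """Rebuild the resource ARN that AuthPolicy.allowAllMethods() ends up granting."""
    parts = method_arn.split(":")
    region, account_id = parts[3], parts[4]
    # parts[5] looks like "<api-id>/<stage>/<verb>/<resource path>"
    api_id, stage = parts[5].split("/")[:2]
    # allowAllMethods() adds verb "*" and resource "*", so every method and path matches.
    return "arn:aws:execute-api:{}:{}:{}/{}/*/*".format(region, account_id, api_id, stage)


incoming = "arn:aws:execute-api:us-east-1:123456789012:abc123/Prod/GET/user"
print(wildcard_resource_arn(incoming))
# arn:aws:execute-api:us-east-1:123456789012:abc123/Prod/*/*
```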

auth_step_4

I won’t be going into too much detail as the blueprint code is already available in AWS Lambda. For reference, visit this link.

Let’s open template.yaml and define the API, together with its authorizer, under Resources. I am going to give the API the logical ID MyApi.

auth_step_5

Under Properties, I am going to define the StageName. A stage name identifies a deployment stage of the API. You can use any text, but the most common choices are dev, prod, stage or test. You will see the stage name appear in the API Gateway URI (Uniform Resource Identifier).

Under Auth, I will give my default authorizer a name: JWTCustomAuthorizer.
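As a rough illustration, assuming the default execute-api endpoint (no custom domain), the stage name becomes the first path segment of the invoke URL. The API ID and region below are placeholders.

```python
def invoke_url(rest_api_id, region, stage, path):
    """Build the default invoke URL of a REST API deployed to a stage."""
    return "https://{}.execute-api.{}.amazonaws.com/{}/{}".format(
        rest_api_id, region, stage, path.lstrip("/")
    )


print(invoke_url("abc123", "us-east-1", "Prod", "/user/login"))
# https://abc123.execute-api.us-east-1.amazonaws.com/Prod/user/login
```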

FunctionArn is the ARN of the Lambda function that handles the authorization process. ARN stands for Amazon Resource Name, a naming convention used to uniquely identify AWS resources.
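The general ARN layout is arn:partition:service:region:account-id:resource, where the resource part may itself contain colons or slashes. A small parser sketch, using a made-up Lambda function ARN similar in shape to what !GetAtt JWTAuthFunction.Arn resolves to:

```python
def parse_arn(arn):
    """Split an ARN of the form arn:partition:service:region:account-id:resource."""
    # maxsplit=5 keeps any colons inside the resource part together.
    prefix, partition, service, region, account_id, resource = arn.split(":", 5)
    if prefix != "arn":
        raise ValueError("not an ARN: " + arn)
    return {
        "partition": partition,
        "service": service,
        "region": region,
        "account_id": account_id,
        "resource": resource,
    }


# Made-up function ARN; the real name is generated by CloudFormation.
info = parse_arn("arn:aws:lambda:us-east-1:123456789012:function:JWTAuthFunction-abc")
print(info["service"], info["resource"])
```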

As you can see in the image below, JWTAuthFunction processes the authorization once we receive a token. As usual, we provide the handler, runtime, etc. We also use an environment variable called SECRET_KEY to sign and verify our JWTs.

auth_step_6

Okay, we have defined the custom authorizer. Now let's reference the API from each function, so that each endpoint only works when a token is provided.

Below is the final template.yaml.

AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31
Description: >
  ServerlessDemo

  Sample SAM Template for ServerlessDemo

# More info about Globals: https://github.com/awslabs/serverless-application-model/blob/master/docs/globals.rst
Globals:
  Function:
    Timeout: 30
    MemorySize: 2048
    Environment:
      Variables:
        SECRET_KEY: secret-info

Resources:
  MyApi:
    Type: 'AWS::Serverless::Api'
    Properties:
      StageName: Prod
      Auth:
        DefaultAuthorizer: JWTCustomAuthorizer
        Authorizers:
          JWTCustomAuthorizer:
            FunctionArn: !GetAtt JWTAuthFunction.Arn

  JWTAuthFunction:
    Type: 'AWS::Serverless::Function'
    Properties:
      CodeUri: organizations/user/
      Handler: authorizer.app.lambda_handler
      Runtime: python3.8

  OrganizationUserCreate:
    Type: 'AWS::Serverless::Function'
    Properties:
      CodeUri: organizations/user/
      Handler: create.app.lambda_handler
      Runtime: python3.8
      Events:
        CompanyCreateUserPostAPI:
          Type: Api
          Properties:
            RestApiId: !Ref MyApi
            Path: '/user'
            Method: POST
            Auth:
              Authorizer: NONE

  OrganizationUserRead:
    Type: 'AWS::Serverless::Function'
    Properties:
      CodeUri: organizations/user/
      Handler: read.app.lambda_handler
      Runtime: python3.8
      Events:
        CompanyUserGetAPI:
          Type: Api
          Properties:
            RestApiId: !Ref MyApi
            Path: '/user'
            Method: GET

  OrganizationUserReadById:
    Type: 'AWS::Serverless::Function'
    Properties:
      CodeUri: organizations/user/
      Handler: read.app.lambda_handler
      Runtime: python3.8
      Events:
        CompanyUserGetByIdAPI:
          Type: Api
          Properties:
            RestApiId: !Ref MyApi
            Path: '/user/{Id}'
            Method: GET

  OrganizationUserUpdateById:
    Type: 'AWS::Serverless::Function'
    Properties:
      CodeUri: organizations/user/
      Handler: update.app.lambda_handler
      Runtime: python3.8
      Events:
        CompanyUserUpdateByIdAPI:
          Type: Api
          Properties:
            RestApiId: !Ref MyApi
            Path: '/user/{Id}'
            Method: PUT

  OrganizationUserDeleteById:
    Type: 'AWS::Serverless::Function'
    Properties:
      CodeUri: organizations/user/
      Handler: delete.app.lambda_handler
      Runtime: python3.8
      Events:
        CompanyUserDeleteByIdAPI:
          Type: Api
          Properties:
            RestApiId: !Ref MyApi
            Path: '/user/{Id}'
            Method: DELETE

  OrganizationUserLogin:
    Type: 'AWS::Serverless::Function'
    Properties:
      CodeUri: organizations/user/
      Handler: login.app.lambda_handler
      Runtime: python3.8
      Events:
        LoginAPI:
          Type: Api
          Properties:
            RestApiId: !Ref MyApi
            Path: '/user/login'
            Method: POST
            Auth:
              Authorizer: NONE

  OrganizationRefreshToken:
    Type: 'AWS::Serverless::Function'
    Properties:
      CodeUri: organizations/user/
      Handler: login.app.token_refresh
      Runtime: python3.8
      Events:
        RefreshTokenAPI:
          Type: Api
          Properties:
            RestApiId: !Ref MyApi
            Path: '/user/refresh-token'
            Method: POST

As you can see on my screen, the create-user API (CompanyCreateUserPostAPI in the template) refers to MyApi, and therefore to its default custom authorizer, through RestApiId. For this endpoint I have still set Authorizer to NONE, which is completely valid: it lets me exclude APIs that shouldn't require a token, such as the LoginAPI.
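Because lambda_handler passes event['authorizationToken'] straight to jwt.decode, a client calling a protected endpoint must send the bare JWT in the Authorization header, with no "Bearer " prefix. The sketch below only prepares such a request with the standard library; the URL and token are placeholders, and nothing is sent.

```python
import urllib.request


def build_protected_request(url, token):
    """Prepare a GET request that carries the raw JWT in the Authorization header."""
    # The authorizer decodes the header value as-is, so no "Bearer " prefix here.
    return urllib.request.Request(url, headers={"Authorization": token}, method="GET")


req = build_protected_request(
    "https://abc123.execute-api.us-east-1.amazonaws.com/Prod/user",  # placeholder URL
    "eyJhbGciOiJIUzI1NiJ9.e30.placeholder-signature",  # placeholder token
)
print(req.get_method(), req.get_header("Authorization"))
# To actually call the API you would pass req to urllib.request.urlopen.
```

Requests to protected endpoints that omit the header are rejected by API Gateway with 401 before the authorizer runs; a token the authorizer rejects produces the Deny policy and a 403.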

auth_step_7

Now that I am done with the template.yaml file, let me go to the codebase and create a package for the Login API.

I will follow the same process I used for the previous APIs.

auth_step_8

I am going to create a token.py file defining two functions: one that creates a new JWT, and one that refreshes a token, issuing a new token as long as the previous one hasn't expired.

JWT stands for JSON Web Token. It is a compact, self-contained way of securely transmitting information between parties as a JSON object. This information can be verified and trusted because it is digitally signed. JWTs can be signed using a secret or a public/private key pair. Note that JWT is a standard: all JWTs are tokens, but not all tokens are JWTs.

If you want more information about JWTs, visit jwt.io.
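To see what "digitally signed" means in practice, here is a standard-library-only sketch of an HS256 JWT: a base64url-encoded header and payload joined by dots, followed by an HMAC-SHA256 signature over them. It is for illustration only; the project itself uses PyJWT, the secret and claims below are made up, and unlike jwt.decode this sketch does not check exp or the alg header.

```python
import base64
import hashlib
import hmac
import json


def b64url(data):
    # JWTs use unpadded base64url encoding.
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def b64url_decode(text):
    # Restore the stripped "=" padding before decoding.
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def sign_hs256(payload, secret):
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = b64url(json.dumps(header).encode()) + "." + b64url(json.dumps(payload).encode())
    signature = hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256).digest()
    return signing_input + "." + b64url(signature)


def verify_hs256(token, secret):
    """Return the payload if the signature matches, otherwise raise ValueError."""
    signing_input, _, signature = token.rpartition(".")
    expected = hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256).digest()
    if not hmac.compare_digest(b64url(expected), signature):
        raise ValueError("invalid signature")
    return json.loads(b64url_decode(signing_input.split(".")[1]))


token = sign_hs256({"id": "42", "first_name": "Ada"}, "secret-info")
print(verify_hs256(token, "secret-info"))
```

Changing even one character of the payload, or verifying with a different secret, makes the recomputed signature differ, which is why the server can trust claims it did not store.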

Below is the final code for token.py.

import datetime
import os

import jwt


def create_access_token(result):
    # Returns a new JWT, valid for 5 minutes (300 seconds).
    jwt_info = jwt.encode({
        "id": str(result["_id"]),
        "first_name": result["first_name"],
        "last_name": result["last_name"],
        "exp": datetime.datetime.utcnow() + datetime.timedelta(seconds=300)},
        os.environ['SECRET_KEY'], algorithm="HS256")
    return jwt_info


def refresh_token(token):
    # Refresh the token if it hasn't expired.
    try:
        result = jwt.decode(token, os.environ['SECRET_KEY'], algorithms=["HS256"])
        jwt_info = jwt.encode({**result, "exp": datetime.datetime.utcnow() + datetime.timedelta(seconds=300)},
                              os.environ['SECRET_KEY'], algorithm="HS256")
        return {"status": True, "data": jwt_info, "message": None}
    except jwt.exceptions.DecodeError:
        return {"status": False, "data": None, "message": "Unable to decode data !"}
    except jwt.ExpiredSignatureError:
        return {"status": False, "data": None, "message": "Token has expired !"}
    except jwt.InvalidTokenError:
        # Any other invalid token would otherwise escape as an unhandled exception.
        return {"status": False, "data": None, "message": "Invalid token !"}

As you can see, the function create_access_token takes the user information, encodes it, and returns a long signed string. The token is signed, not encrypted: anyone holding it can read its payload, so it must not contain secrets. This token is only valid for 5 minutes.
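Because the payload is only base64url-encoded, its claims can be read without SECRET_KEY. The sketch below builds a sample token shaped like the ones create_access_token returns (made-up claims, dummy signature) and reads the claims back with the standard library. The exp claim appears as a Unix timestamp in seconds, which is how PyJWT serializes the datetime passed to it.

```python
import base64
import json


def peek_claims(token):
    """Read a JWT's payload WITHOUT verifying the signature (never trust it for auth)."""
    payload_segment = token.split(".")[1]
    padded = payload_segment + "=" * (-len(payload_segment) % 4)
    return json.loads(base64.urlsafe_b64decode(padded))


def segment(obj):
    # Unpadded base64url encoding of a JSON object, as used for JWT segments.
    return base64.urlsafe_b64encode(json.dumps(obj).encode()).rstrip(b"=").decode()


# Made-up token: real header and payload segments, dummy signature.
sample = segment({"alg": "HS256", "typ": "JWT"}) + "." + segment(
    {"id": "42", "first_name": "Ada", "last_name": "Lovelace", "exp": 1700000300}
) + ".dummy-signature"

print(peek_claims(sample)["first_name"])  # Ada
```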

OK, we are done with the tokens. Let’s move on to the validator.py file.

I am going to create a UserLoginSchema that takes email and password as required input. In the validation function, I will check whether the provided email exists in the database, and then verify the password against the hashed password stored there.

If the password verification succeeds, I will generate a new token and send it back in the response.
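The validator relies on argon2's PasswordHasher: the database stores only a hash, and verify() recomputes and compares it. The same idea can be sketched with the standard library's PBKDF2, shown here purely as a stand-in for argon2; the function names, iteration count and salt size are illustrative assumptions, not this project's code.

```python
import hashlib
import hmac
import os


def hash_password(password, iterations=100_000):
    """Return (salt, iterations, digest); only these are stored, never the password."""
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, iterations)
    return salt, iterations, digest


def verify_password(stored, candidate):
    """Recompute the hash for the candidate and compare in constant time."""
    salt, iterations, digest = stored
    attempt = hashlib.pbkdf2_hmac("sha256", candidate.encode(), salt, iterations)
    return hmac.compare_digest(attempt, digest)


stored = hash_password("correct horse")
print(verify_password(stored, "correct horse"), verify_password(stored, "wrong"))
```

With argon2-cffi, the equivalent calls are PasswordHasher().hash(password) and PasswordHasher().verify(hash, password); the latter raises VerifyMismatchError on a wrong password, which is what validate_email_password catches.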

auth_step_9

Next, I am going to create a RefreshTokenSchema that takes the token as required input. If the existing token is valid, it returns a new token in the response; otherwise, it raises a validation error.
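The refresh rule can be expressed at the level of claims, independent of signing: an unexpired token's claims are copied and given a fresh exp 300 seconds ahead, mirroring {**result, "exp": ...} in refresh_token, while an expired one is rejected. The function name refresh_claims and its injectable now parameter are hypothetical, added so the logic can be tested without waiting five minutes.

```python
import time

TOKEN_LIFETIME_SECONDS = 300  # matches the 5-minute lifetime used in token.py


def refresh_claims(claims, now=None):
    """Return a copy of claims with a new expiry, or raise if the token has expired."""
    now = time.time() if now is None else now
    if claims["exp"] <= now:
        raise ValueError("Token has expired !")
    # Copy every claim and overwrite only exp; the original dict is left untouched.
    return {**claims, "exp": int(now) + TOKEN_LIFETIME_SECONDS}


claims = {"id": "42", "exp": 1_000_300}
print(refresh_claims(claims, now=1_000_100))
```

One consequence of this design is that a client can keep a session alive indefinitely by refreshing before each token expires.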

Below is the final code for validator.py.

import argon2
from argon2 import PasswordHasher
from marshmallow import Schema, fields, ValidationError, post_load

from . import db
from . import token


class UserLoginSchema(Schema):
    email = fields.Email(required=True)
    password = fields.Str(required=True)

    @post_load
    def validate_email_password(self, data, **kwargs):
        mongo = db.MongoDBConnection()
        with mongo:
            database = mongo.connection['myDB']
            collection = database['registrations']
            result = collection.find_one({"email": data["email"]})
            if result is None:
                raise ValidationError('Sorry! You have provided invalid email.')
            else:
                ph = PasswordHasher()
                try:
                    ph.verify(result['password'], data['password'])
                    data['token'] = token.create_access_token(result)
                except argon2.exceptions.VerifyMismatchError:
                    raise ValidationError('The password is invalid.')
        return data


class RefreshTokenSchema(Schema):
    token = fields.Str(required=True)

    @post_load
    def validate_token(self, data, **kwargs):
        refresh_token = token.refresh_token(data['token'])
        if refresh_token['status']:
            data['token'] = refresh_token['data']
        else:
            raise ValidationError(refresh_token['message'])
        return data

Let’s move on to app.py, where I am going to define the Lambda handler. As before, I will parse the event body and pass it to my schema for validation and post-processing.

If the validation is successful, the handler returns the token with an HTTP 200 response; otherwise, it returns an error with a 400 status code.
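Every branch of the handler returns the same API Gateway proxy-integration shape: a statusCode plus a body that must be a string, which is why the payload goes through ujson.dumps. A small helper, sketched here with the standard json module and the hypothetical name json_response, could remove that repetition:

```python
import json


def json_response(status_code, message, data=None):
    """Build a Lambda proxy-integration response with a JSON string body."""
    body = {"message": message}
    if data is not None:
        body["data"] = data
    return {"statusCode": status_code, "body": json.dumps(body)}


ok = json_response(200, "Welcome !", {"token": "<jwt>"})
bad = json_response(400, "Error !", {"email": ["Not a valid email address."]})
print(ok["statusCode"], json.loads(ok["body"])["data"]["token"])
```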

auth_step_10

I will also define one more function, token_refresh, in the same file. It performs the same kind of operation: it takes an existing token as input and returns a new token.

Below is the final code for app.py.

import ujson
from marshmallow import ValidationError

from .utils import validator


def lambda_handler(event, context):
    try:
        body = ujson.loads(event['body'])
        result = validator.UserLoginSchema()
        res = not bool(result.validate(body))
        if res:
            return {
                "statusCode": 200,
                "body": ujson.dumps({
                    "message": "Welcome !",
                    "data": {
                        "token": result.load(body)['token']
                    }
                })
            }
        else:
            return {
                "statusCode": 400,
                "body": ujson.dumps({
                    "message": "Error !",
                    "data": result.validate(body)
                })
            }
    except ValidationError as err:
        return {
            "statusCode": 400,
            "body": ujson.dumps({
                "message": err.messages
            })
        }
    except KeyError as error:
        return {
            "statusCode": 400,
            "body": ujson.dumps({
                "message": "Something went wrong. Unable to parse data ! " + str(error)
            })
        }


def token_refresh(event, context):
    try:
        body = ujson.loads(event['body'])
        result = validator.RefreshTokenSchema()
        res = not bool(result.validate(body))
        if res:
            return {
                "statusCode": 200,
                "body": ujson.dumps({
                    "message": None,
                    "data": result.load(body)
                })
            }
        else:
            return {
                "statusCode": 400,
                "body": ujson.dumps({
                    "message": "Error !",
                    "data": result.validate(body)
                })
            }
    except ValidationError as err:
        return {
            "statusCode": 400,
            "body": ujson.dumps({
                "message": err.messages
            })
        }
    except KeyError:
        return {
            "statusCode": 400,
            "body": ujson.dumps({"message": "Something went wrong. Unable to parse data !"})
        }

Now, I am going to register the UserLogin and RefreshToken APIs in the template.yaml file.

auth_step_11

auth_step_12

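We have registered the APIs. Before testing them, there are a few issues I missed, so let me fix them. Open app.py and go to line number 19: it should be load, not loads (remove the ‘s’). The body has already been parsed with ujson.loads, so the schema needs load, which takes the parsed data, rather than loads, which expects a JSON string.

Next, I will go to the token.py file and remove the decode call in line number 15. In PyJWT 2.x, jwt.encode already returns a str, so the .decode() that PyJWT 1.x needed (its encode returned bytes) now fails.

For the refresh token function, in line number 21, I need to specify the algorithm HS256, because PyJWT 2.x requires the algorithms argument in jwt.decode. You can follow the source code on GitHub; it's already up to date, and the snippets shown above already include these fixes.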

Let’s test out the functionality. I will click on Run and then Edit Configurations.

You can see I have provided email and password.

auth_step_13
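When invoking the handler locally, the event's body must itself be a JSON string, because lambda_handler calls ujson.loads(event['body']). The sketch below builds such an event with the standard library; the helper name, email and password are placeholders, and calling the real handler would also need the project's MongoDB connection.

```python
import json


def api_gateway_event(payload):
    """Wrap a dict the way API Gateway delivers a request body to a proxy integration."""
    return {"httpMethod": "POST", "body": json.dumps(payload)}


login_event = api_gateway_event({"email": "jane@example.com", "password": "placeholder"})
print(login_event["body"])
```

The same helper works for the refresh endpoint by passing {"token": "<previous token>"} as the payload.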

Let me run the function and check out what response I get.

auth_step_14

Okay, I am receiving a new token. Now, I am going to pass this token to the RefreshToken function and in return I will receive a new token.

auth_step_15

I am receiving a new token. Both functions are working. Great!

In the upcoming tutorial, I will write unit tests to make sure our functions work correctly before deployment.