CNLearn FastAPI - Adding Authentication

Today, we will add authentication endpoints. Remember how we added registration and that was it? Well, that’s not very useful. We need to authenticate somehow so that we can create features for individual profiles (my profile really… :) )

We are still somewhat basing it on the cookiecutter template here but making changes, removing stuff, adding tests, etc.

For now (this might change later on, depending on whether I’m lazy or not), we will use FastAPI’s built-in support for the OAuth2 “password” flow (the OAuth2 Resource Owner Password Credentials Grant). Warning: this grant type is discouraged by OAuth and will be removed in OAuth 2.1. Since this is a learning project and we won’t separate our application server from our authorisation server, we will use it. If you’re thinking of using it in a real project, please don’t, or at least think carefully about it first.

What do we need to add?

  • Well, we need a login endpoint
  • When somebody logs in, it will generate a JSON Web Token (JWT) they can use in all subsequent requests
  • In subsequent requests to a protected endpoint, the token will be attached to the request. The token will then be verified and, if valid, will authorise the user to access that protected endpoint

How will the user be verified in the request? We will use dependency injection: a dependency that returns the current user associated with that JWT. The JWT will get decoded (or at least we will attempt to decode it), the user’s identifier obtained from the payload (the “sub”, i.e. subject, claim of the JWT payload, which in our case holds the user’s id; more info here and here) and the request authorised.
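If “payload” and “claims” sound abstract: a JWT is just three base64url-encoded segments joined by dots (header.payload.signature), and the claims, “sub” included, live in the middle segment as JSON. The sketch below (plain standard-library Python; the helper names and the dummy token are made up for illustration) builds such a token by hand and reads the claims back out. It deliberately does not check the signature, so it is for looking only, never for authorising anything.

```python
import base64
import json


def b64url_encode(data: bytes) -> str:
    # JWT segments use base64url with the "=" padding stripped
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def b64url_decode(segment: str) -> bytes:
    # Restore the stripped padding before decoding
    padding = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + padding)


def peek_claims(token: str) -> dict:
    """Return the payload claims WITHOUT verifying the signature (illustration only)."""
    _header, payload, _signature = token.split(".")
    return json.loads(b64url_decode(payload))


# Hypothetical token built by hand; the third segment is a dummy, not a real signature
header = b64url_encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
payload = b64url_encode(json.dumps({"exp": 1700000000, "sub": "42"}).encode())
token = f"{header}.{payload}.not-a-real-signature"

claims = peek_claims(token)
print(claims["sub"])  # 42
```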

Confused? Ok let’s go bit by bit.

Creating the Login Endpoint

Let’s start with the code for logging in and obtaining a token. The login endpoint responds with a Token Pydantic schema:

from pydantic import BaseModel


class Token(BaseModel):
    access_token: str
    token_type: str

Notice it has a token_type (which will always be “bearer” for now) and an access_token (which will be the actual JSON Web Token).

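from datetime import timedelta
from typing import Any

from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession

from app import crud, schemas
from app.api.dependencies import database
from app.core.security import create_access_token
from app.settings.base import settings

router = APIRouter()


# the route name lets the tests look the URL up with app.url_path_for("user:access-token")
@router.post("/login/access-token", response_model=schemas.Token, name="user:access-token")
async def login_access_token(
    db: AsyncSession = Depends(database.get_async_session),
    form_data: OAuth2PasswordRequestForm = Depends(),
) -> Any:
    """
    OAuth2 compatible token login, get an access token for future requests
    """
    # the form's "username" field carries the user's email
    user = await crud.user.authenticate(
        db, email=form_data.username, password=form_data.password
    )
    if not user:
        raise HTTPException(status_code=400, detail="Incorrect email or password")
    elif not crud.user.is_active(user):
        raise HTTPException(status_code=400, detail="Inactive user")
    access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    return {
        "access_token": create_access_token(
            user.id, expires_delta=access_token_expires
        ),
        "token_type": "bearer",
    }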

Ok, so looking at our dependencies, we have the usual database session (makes sense, we’ll need to check the password against the hashed password) but also this OAuth2PasswordRequestForm. What is that? When posting to this endpoint, it expects username and password fields in the form data (you can see more information here). This dependency parses that form and gives us a form_data object with username and password attributes. Notice that we then use those two attributes when authenticating the user. As a reminder, authentication finds a user by email (so the form’s username field carries the email). If there is such a user, it verifies the provided password against the stored hash by hashing the provided password the same way and comparing the results. If there is no such user, or if the user is inactive, a 400 Bad Request is returned. Otherwise, we create an access token (we will look at that soon). It’s worth noting that the access token has an expiration date (we will look at this soon too).
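The authentication step boils down to “look the user up by email, then re-hash the submitted password and compare”. Here is a minimal, framework-free sketch of that idea. It is not the project’s crud code: the in-memory users dict, FakeUser and the helper names are hypothetical, and PBKDF2 from Python’s standard library is swapped in for whatever hashing scheme the project’s security helpers actually use. The key detail is that verification re-hashes the candidate password with the stored salt and compares the results in constant time.

```python
import hashlib
import hmac
import os
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

ITERATIONS = 100_000  # illustrative PBKDF2 work factor


@dataclass
class FakeUser:
    email: str
    hashed_password: str  # stored as "<salt hex>$<digest hex>"
    is_active: bool = True


def hash_password(password: str, salt: Optional[bytes] = None) -> str:
    # A fresh random salt for new passwords; the stored salt when verifying
    salt = os.urandom(16) if salt is None else salt
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"


def verify_password(plain: str, hashed: str) -> bool:
    salt_hex, digest_hex = hashed.split("$")
    # Hash the candidate with the *stored* salt, then compare in constant time
    candidate = hash_password(plain, bytes.fromhex(salt_hex)).split("$")[1]
    return hmac.compare_digest(candidate, digest_hex)


def authenticate(users: Dict[str, FakeUser], email: str, password: str) -> Optional[FakeUser]:
    user = users.get(email)  # the form's "username" field carries the email
    if user is None or not verify_password(password, user.hashed_password):
        return None
    return user


def check_login(users: Dict[str, FakeUser], email: str, password: str) -> Tuple[int, str]:
    # Same branching as login_access_token: unknown user or wrong password -> 400, inactive -> 400
    user = authenticate(users, email, password)
    if user is None:
        return 400, "Incorrect email or password"
    if not user.is_active:
        return 400, "Inactive user"
    return 200, "token would be issued here"


users = {"admin@cnlearn.app": FakeUser("admin@cnlearn.app", hash_password("thisissecret"))}
print(check_login(users, "admin@cnlearn.app", "thisissecret"))  # (200, 'token would be issued here')
print(check_login(users, "admin@cnlearn.app", "nope"))          # (400, 'Incorrect email or password')
```

check_login mirrors the two 400 branches of login_access_token; in the real endpoint, the success branch is where create_access_token gets called.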

Creating the JSON Web Token

Our JSON Web Token is an access token with an expiration date. There are two important parts to using one:

  • creating (encoding) it: embedding the needed information and the expiration date, and signing it with the secret key using a specific algorithm
  • decoding it: if the JSON Web Token can be decoded (its signature checks out and it hasn’t expired), it’s valid. If it’s a fake or tampered-with token, decoding fails. If it’s an expired token, decoding fails too.
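To make the two bullets concrete, here is a rough standard-library sketch of what HS256 signing and verification involve. It is not python-jose’s implementation (in the real code we simply call jwt.encode and jwt.decode), and InvalidToken, ExpiredToken, encode_hs256 and decode_hs256 are hypothetical names. It shows why a tampered token fails (the recomputed HMAC no longer matches) and why an expired one fails (the exp claim is checked once the signature is known to be good).

```python
import base64
import hashlib
import hmac
import json
import time
from typing import Any, Dict, Optional


class InvalidToken(Exception):
    """Malformed token or bad signature (plays the role of jose's JWTError)."""


class ExpiredToken(InvalidToken):
    """Good signature but exp is in the past (plays the role of ExpiredSignatureError)."""


def _b64url_encode(data: bytes) -> str:
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def _b64url_decode(segment: str) -> bytes:
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def _sign(signing_input: str, secret: str) -> str:
    # HS256 = HMAC-SHA256 over "<header>.<payload>" using the shared secret key
    mac = hmac.new(secret.encode(), signing_input.encode("ascii"), hashlib.sha256)
    return _b64url_encode(mac.digest())


def encode_hs256(claims: Dict[str, Any], secret: str) -> str:
    header = _b64url_encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    payload = _b64url_encode(json.dumps(claims).encode())
    signing_input = f"{header}.{payload}"
    return f"{signing_input}.{_sign(signing_input, secret)}"


def decode_hs256(token: str, secret: str, now: Optional[float] = None) -> Dict[str, Any]:
    parts = token.split(".")
    if len(parts) != 3:
        raise InvalidToken("a JWT has exactly three segments")
    header, payload, signature = parts
    # Recompute the signature; any change to header, payload or signature breaks the match
    expected = _sign(f"{header}.{payload}", secret)
    if not hmac.compare_digest(signature.encode(), expected.encode()):
        raise InvalidToken("signature verification failed")
    claims = json.loads(_b64url_decode(payload))
    # `now` can be injected so tests need not sleep; defaults to the real clock
    current = time.time() if now is None else now
    if "exp" in claims and claims["exp"] < current:
        raise ExpiredToken("token has expired")
    return claims


token = encode_hs256({"sub": "vlad", "exp": 2_000_000_000}, "wowsosecret")
print(decode_hs256(token, "wowsosecret", now=1_700_000_000))  # {'sub': 'vlad', 'exp': 2000000000}
```

The optional now argument belongs to this sketch only, not to python-jose; injecting the clock like this is one way a test could check expiry without sleeping.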

Let’s see how it gets created:

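# app/core/security.py
from datetime import datetime, timedelta, timezone
from typing import Any, Optional, Union

# JWTError and ExpiredSignatureError are used by decode_access_token below
from jose import ExpiredSignatureError, JWTError, jwt

from app.settings.base import settings

ALGORITHM = "HS256"  # assumption: HMAC-SHA256, the usual choice with a shared SECRET_KEY


def create_access_token(
    subject: Union[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
    # timezone-aware "now" (datetime.utcnow() is deprecated since Python 3.12)
    now = datetime.now(timezone.utc)
    if expires_delta:
        expire = now + expires_delta
    else:
        expire = now + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    # only two registered claims for now: expiration time and subject (the user's id)
    to_encode = {"exp": expire, "sub": str(subject)}
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt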

It takes an optional timedelta, expires_delta, or falls back to the default duration from settings (8 days by default; we’ll decrease that later, I think), and adds it to the current time to set the token’s expiration time. It then builds the token’s payload, which for now only has two of the registered claims (later on we might add more, depending on what we do with our authentication system): exp (expiration time) and sub (subject). Finally, it encodes the payload and returns the token.
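Two bits of arithmetic are hiding in there. First, 8 days expressed in minutes is 60 * 24 * 8 = 11520, which is the kind of value ACCESS_TOKEN_EXPIRE_MINUTES holds by default (the exact value in your settings may differ). Second, before signing, python-jose converts a datetime exp into a NumericDate, i.e. whole seconds since the Unix epoch, so anything below one second is dropped. The sketch below reproduces that with a fixed “now” so the numbers can be checked; expiry_numeric_date is a hypothetical helper, not part of the project.

```python
import calendar
from datetime import datetime, timedelta, timezone
from typing import Optional

ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 8  # 8 days expressed in minutes (11520)


def expiry_numeric_date(now: datetime, expires_delta: Optional[timedelta] = None) -> int:
    # Same fallback as create_access_token: a truthy expires_delta, else the default duration
    delta = expires_delta if expires_delta else timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    expire = now + delta
    # NumericDate: whole seconds since the epoch (microseconds are discarded)
    return calendar.timegm(expire.utctimetuple())


now = datetime(2022, 1, 1, tzinfo=timezone.utc)  # 1640995200 as a NumericDate
print(expiry_numeric_date(now))                             # 1641686400
print(expiry_numeric_date(now, timedelta(microseconds=1)))  # 1640995200
```

Note the second result: a 1 microsecond expiry rounds down to the very second the token was issued, which is why the 1 microsecond test case below has to wait for the clock to tick past that second before the token counts as expired.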

What about verifying it? We have a function that decodes (and thereby verifies) the token and returns the payload as a dict if all went well. Otherwise, it re-raises the exception (after logging it, which is not yet implemented).

def decode_access_token(token: str) -> dict[str, Any]:
    try:
        payload: dict[str, Any] = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
    except (JWTError, ExpiredSignatureError):
        # ExpiredSignatureError subclasses JWTError; it is listed here for readability
        # TODO: log this
        raise
    return payload

Let’s test that it works. In tests/core/test_security.py, we will have a test_encoding_decoding_tokens test function. We’ll patch external things like settings and parametrize the inputs so that several cases run through one test. Please note this is not a great test: I definitely should not sleep in it… but in the next post I will add a dependency that lets us travel to the future.

from contextlib import AbstractContextManager
from contextlib import nullcontext as does_not_raise
from datetime import timedelta
from time import sleep
from typing import Optional
from unittest import mock

import pytest
from jose import ExpiredSignatureError, JWTError

from app.core.security import create_access_token, decode_access_token


@pytest.mark.parametrize(
    ("subject", "additional_string", "expires_delta", "expired", "expectation"),
    [
        # I really hope this test does not fail...hopefully faster than 999 days
        ("vlad", "", timedelta(days=999), False, does_not_raise()),
        ("vlad", "", None, False, does_not_raise()),
        # TODO: obviously this needs to change...next post
        ("vlad", "", timedelta(microseconds=1), True, pytest.raises(ExpiredSignatureError)),
        # appending junk to the token invalidates its signature
        ("vlad", "_i_4m_a_h4ck3r", None, False, pytest.raises(JWTError)),
    ],
)
@mock.patch("app.core.security.settings")
def test_encoding_decoding_tokens(
    # the following comes from our patch
    settings_mock: mock.MagicMock,
    # the following comes from our test pytest parameters
    subject: str,
    additional_string: str,
    expires_delta: Optional[timedelta],
    expired: bool,
    expectation: AbstractContextManager,
):
    settings_mock.ACCESS_TOKEN_EXPIRE_MINUTES = 10
    settings_mock.SECRET_KEY = "wowsosecret"

    encoded_token: str = create_access_token(subject=subject, expires_delta=expires_delta)
    if expired:
        # "exp" is stored in whole seconds, so the 1 microsecond token only counts as
        # expired once the clock has moved past that second: wait just over a second
        sleep(1.1)
    with expectation:
        payload = decode_access_token(encoded_token + additional_string)
        assert payload["sub"] == subject

There are four cases we are testing: one that decodes successfully with a specific timedelta, one that decodes successfully with a default timedelta, one that expires and fails to decode and one that gets tampered with and also fails decoding.

Using the JWT

We have created (encoded) the JWT and decoded it. Now let’s actually use it. For now, we will create a simple endpoint that returns the user’s details. It reads the user making the request by using a dependency that returns the current user. How does it do that? Let’s have a look. In app/api/dependencies/user.py:

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import ExpiredSignatureError, JWTError
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession

from app import crud, schemas
from app.api.dependencies import database
from app.core import security
from app.models.user import User
from app.settings.base import settings

reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")

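Those are all the imports we need (note that I haven’t shown all of the code yet). What is that reusable_oauth2? It’s an instance of FastAPI’s OAuth2PasswordBearer class, and we will use it as a dependency shortly. (Its tokenUrl argument doesn’t create a route; it just tells clients, such as the interactive docs, where to send the username and password to get a token.) Here is its __call__ method, taken from here: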

    async def __call__(self, request: Request) -> Optional[str]:
        authorization: str = request.headers.get("Authorization")
        scheme, param = get_authorization_scheme_param(authorization)
        if not authorization or scheme.lower() != "bearer":
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED,
                    detail="Not authenticated",
                    headers={"WWW-Authenticate": "Bearer"},
                )
            else:
                return None
        return param

It looks at the incoming request and checks its headers for an “Authorization” header. The next line splits that header value into the scheme (e.g. “Bearer”) and the token part. If the header is missing or its scheme isn’t “bearer” (case-insensitive), a 401 Unauthorized is raised (auto_error defaults to True). Otherwise, the token is returned as a string. Where is the token used? Well, it’s used in the following function (which itself will be used as a dependency).

async def get_current_user(
    db: AsyncSession = Depends(database.get_async_session), token: str = Depends(reusable_oauth2)
) -> User:
    try:
        payload = security.decode_access_token(token)
        token_data = schemas.TokenPayload(**payload)
    except (JWTError, ExpiredSignatureError, ValidationError):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )
    user = await crud.user.get(db, id=token_data.sub)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user

It has two dependencies of its own: the usual get_async_session and the token from reusable_oauth2. Once both are resolved, we try to decode the token, accounting for a JWTError (a malformed or tampered-with token), an ExpiredSignatureError (an expired token) or a Pydantic ValidationError (a payload that doesn’t match the TokenPayload schema). In those cases, we return a 403 Forbidden. If, however, the token is successfully decoded, we use token_data.sub (corresponding to the id column in the database) to get the user. If there is such a user, we return it. Otherwise, we return a 404 Not Found.
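Putting the last two pieces together, the path from an Authorization header to a user (or an error) looks roughly like this. This is a framework-free sketch, not FastAPI’s code: HTTPError, BadToken, split_authorization, bearer_token, current_user and the fake decoder and user table are all hypothetical stand-ins. split_authorization is modelled on the scheme/token split done by get_authorization_scheme_param, and int() stands in for the TokenPayload validation (assuming, as create_access_token does with str(subject), that sub holds the user’s id as a string).

```python
from typing import Callable, Dict, Optional, Tuple


class HTTPError(Exception):
    """Stand-in for fastapi.HTTPException: carries a status code and a detail message."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


class BadToken(Exception):
    """Stand-in for JWTError / ExpiredSignatureError raised while decoding."""


def split_authorization(value: Optional[str]) -> Tuple[str, str]:
    # "<scheme> <param>" -> (scheme, param); empty strings when the header is absent
    if not value:
        return "", ""
    scheme, _, param = value.partition(" ")
    return scheme, param


def bearer_token(value: Optional[str]) -> str:
    # What reusable_oauth2 does with auto_error=True
    scheme, param = split_authorization(value)
    if not value or scheme.lower() != "bearer":
        raise HTTPError(401, "Not authenticated")
    return param


def current_user(authorization: Optional[str], decode: Callable[[str], dict], users: Dict[int, dict]) -> dict:
    token = bearer_token(authorization)
    try:
        payload = decode(token)
        user_id = int(payload["sub"])  # a missing or non-integer "sub" fails validation
    except (BadToken, KeyError, ValueError):
        raise HTTPError(403, "Could not validate credentials") from None
    user = users.get(user_id)
    if user is None:
        raise HTTPError(404, "User not found")
    return user


# A fake decoder and user table, just to exercise every branch
known_tokens = {"good": {"sub": "1"}, "ghost": {"sub": "99"}}


def fake_decode(token: str) -> dict:
    if token not in known_tokens:
        raise BadToken("invalid or expired token")
    return known_tokens[token]


users = {1: {"id": 1, "email": "free@cnlearn.app"}}

for header in ["Bearer good", None, "Basic abc", "Bearer tampered", "Bearer ghost"]:
    try:
        print(header, "->", current_user(header, fake_decode, users)["email"])
    except HTTPError as exc:
        print(header, "->", exc.status_code, exc.detail)
```

Running the loop prints one line per header: the user’s email for the valid token, then 401, 401, 403 and 404 for the missing header, the wrong scheme, the undecodable token and the unknown user respectively.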

Now that we have seen all these dependencies, let’s actually look at our user details endpoint. I won’t expand on it too much, nor will I add a lot of detail because it will be changed shortly.

@router.get("/login/me", name="user:me")
async def read_users_me(current_user: schemas.User = Depends(user.get_current_user)):
    return current_user

It’s a very simple GET endpoint with one dependency. get_current_user, in turn, has the get_async_session and reusable_oauth2 dependencies. I think that’s one of the things I like most about FastAPI: its dependency injection system.
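To get a feel for what the dependency injection system is doing for us, here is a toy resolver: it inspects a function’s parameters, recursively calls anything marked with Depends, and caches results so that a dependency shared by several callers runs only once per “request” (FastAPI also caches dependencies within a request by default). This is a hypothetical sketch, much simpler than FastAPI’s real implementation (no async, no request data, no overrides), and the functions are stand-ins that merely share names with ours. The extra db parameter on read_users_me exists only to show the caching.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional


class Depends:
    """Toy marker playing the role of fastapi.Depends (not FastAPI's real class)."""

    def __init__(self, dependency: Callable[..., Any]) -> None:
        self.dependency = dependency


def solve(func: Callable[..., Any], cache: Optional[Dict[Any, Any]] = None) -> Any:
    """Call func after recursively resolving every parameter whose default is Depends(...)."""
    cache = {} if cache is None else cache
    kwargs: Dict[str, Any] = {}
    for name, param in inspect.signature(func).parameters.items():
        if isinstance(param.default, Depends):
            dep = param.default.dependency
            if dep not in cache:  # each dependency runs at most once per "request"
                cache[dep] = solve(dep, cache)
            kwargs[name] = cache[dep]
    return func(**kwargs)


calls: List[str] = []  # records which dependencies actually ran


def get_async_session() -> str:
    calls.append("session")
    return "session"


def reusable_oauth2() -> str:
    calls.append("token")
    return "token-123"


def get_current_user(db: Any = Depends(get_async_session), token: Any = Depends(reusable_oauth2)) -> dict:
    return {"db": db, "token": token}


def read_users_me(current_user: Any = Depends(get_current_user), db: Any = Depends(get_async_session)) -> dict:
    return current_user


print(solve(read_users_me))  # {'db': 'session', 'token': 'token-123'}
print(calls)                 # ['session', 'token']
```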

Testing

Now we’ll write a test for this new endpoint (even though it will be updated/removed in the following post): it will serve as a good base for all subsequent endpoints that need authentication.

But first, shouldn’t we test our login endpoint? You’re absolutely right that we should. How are we gonna do that? Let’s think of the steps needed:

  1. We need to create a user object in the db (so we can log in with it)
  2. We need to hit the /login/access-token endpoint with the username/password (and application/x-www-form-urlencoded content-type in the headers of course)
  3. We check whether there is an access token in the response and whether the token is of bearer type
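Step 2 is worth seeing concretely: OAuth2PasswordRequestForm reads form fields, so the request body is URL-encoded key/value pairs rather than JSON. The standard library can show what that body looks like for the credentials used in the test (this only illustrates the encoding; the test itself lets the HTTP client build the body from the data dict).

```python
from urllib.parse import parse_qs, urlencode

form = {"username": "admin@cnlearn.app", "password": "thisissecret"}
body = urlencode(form)
print(body)  # username=admin%40cnlearn.app&password=thisissecret

# Parsing it back recovers the same fields the server-side form parsing sees
print(parse_qs(body))  # {'username': ['admin@cnlearn.app'], 'password': ['thisissecret']}
```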

User creation fixture

We don’t want to go through our user registration endpoint here, so let’s write a fixture that creates the user directly in the database for us. We want a fixture that we can call like

    email: str = "admin@cnlearn.app"
    password: str = "thisissecret"
    await create_user_object(email=email, password=password)

in our test. Let’s set it up. We want the fixture to be able to take keyword arguments for email and password (and an optional full_name) and create the user object in the database. The fixture is as follows:

@pytest_asyncio.fixture
async def create_user_object(get_async_session: AsyncSession):
    async def _create_user(email: str, password: str, full_name: Optional[str] = None) -> User:
        db: AsyncSession = get_async_session
        user_in = schemas.UserCreate(password=password, email=email, full_name=full_name)
        new_user = await user.create(db, obj_in=user_in)
        return new_user
    return _create_user

Please note it’s a function returning an inner function (the “factory as fixture” pattern, which is how you pass arguments to a pytest fixture). Note that it has a get_async_session fixture dependency of its own. It creates a UserCreate schema object, then passes that to the user CRUD create method. Finally, it returns the new user. For now, we won’t be using the User object that is returned. Let’s look at our test now.
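The “function returning an inner function” trick is plain closures, and pytest’s yield fixtures add a teardown phase on top. Here is a hypothetical, framework-free sketch of a factory that also cleans up after itself, with a list standing in for the database session and a dict standing in for the User model. In the real project this would be an async function decorated with @pytest_asyncio.fixture and would receive get_async_session instead of a list; driving the generator by hand at the bottom is roughly what pytest does around each test.

```python
from typing import Callable, Dict, Iterator, List, Optional

User = Dict[str, Optional[str]]  # stand-in for the User model


def create_user_object(db: List[User]) -> Iterator[Callable[..., User]]:
    created: List[User] = []

    def _create_user(email: str, password: str, full_name: Optional[str] = None) -> User:
        # The test calls this inner function with per-call arguments.
        # (A real implementation would hash and store the password; it is ignored here.)
        user: User = {"email": email, "full_name": full_name}
        db.append(user)
        created.append(user)
        return user

    yield _create_user  # setup finished: hand the factory to the test
    # Teardown: runs after the test, removing only the users this factory created
    for user in created:
        db.remove(user)


fake_db: List[User] = []
fixture = create_user_object(fake_db)
make_user = next(fixture)  # pytest's setup step
make_user(email="admin@cnlearn.app", password="thisissecret")
print(len(fake_db))  # 1
next(fixture, None)  # pytest's teardown step (the generator finishes)
print(len(fake_db))  # 0
```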

Testing the login endpoint

Starting with the function signature:

@pytest.mark.asyncio
async def test_login_access_token(
    client: AsyncClient,
    app: FastAPI,
    create_user_object: Callable[..., User],
    clean_users_table: Callable[[None], None],
):

We have four fixtures: the AsyncClient, the FastAPI app, create_user_object (which we will use to create a user in the database) and clean_users_table (because I don’t want to have to think about whether I’ve used this email/password combination before during our tests, so I delete all users). We then create the user:

    email: str = "admin@cnlearn.app"
    password: str = "thisissecret"
    await create_user_object(email=email, password=password)

Now let’s actually try to log in:

    login_url: str = app.url_path_for("user:access-token")
    response: Response = await client.post(
        url=login_url,
        data={
            "username": email,
            "password": password,
        },
        # we need to change the client's headers content-type
        headers={"content-type": "application/x-www-form-urlencoded"},
    )

I don’t like hard-coding endpoint paths in tests, which is why we look the URL up by its route name with app.url_path_for (we do still have a hard-coded tokenUrl from earlier; I still need to get to fixing that). An important thing is that we have to send the data as form data, which is why we changed the Content-Type header. Finally, let’s make some assertions:

    assert response.status_code == 200
    json_response: dict[str, Any] = response.json()
    assert "access_token" in json_response
    assert json_response["token_type"] == "bearer"

We test that we get a 200 OK, that the access_token is in the response and that its token_type is bearer.

What if it’s a bad email/password? We have a test for that too.

@pytest.mark.asyncio
async def test_login_access_token_fail(
    client: AsyncClient,
    app: FastAPI,
):
    # let's hit the login endpoint
    login_url: str = app.url_path_for("user:access-token")
    response: Response = await client.post(
        url=login_url,
        data={
            "username": "email@notexist.com",
            "password": "badbadpassword",
        },
        # we need to change the client's headers content-type
        headers={"content-type": "application/x-www-form-urlencoded"},
    )
    assert response.status_code == 400
    json_response: dict[str, Any] = response.json()
    assert json_response == {'detail': 'Incorrect email or password'}

Here we’re not creating a user. We’re posting bad data to the login url, and getting a 400 Bad Request response back.

Testing the User:Me endpoint

Ok, now we need to test our “user:me” endpoint, which returns information about our user. Since I don’t want to call the create user and login endpoints in the test itself, I will offload that to a fixture. All I want to do is call the “user:me” endpoint and verify that we get what we expect to get. We therefore need a fixture that will create a user (which we already do) and that will log in our user (which sounds suspiciously like what we just did in our test, nice!).

The following might look familiar:

@pytest_asyncio.fixture
async def return_logged_in_user_bearer_token(
    get_async_session: AsyncSession,
    create_user_object: Callable[..., User],
    client: AsyncClient,
    app: FastAPI,
):
    async def _get_logged_in_user_token(email: str, password: str):
        await create_user_object(email=email, password=password)
        # now let's hit the login endpoint
        login_url: str = app.url_path_for("user:access-token")
        response: Response = await client.post(
            url=login_url,
            data={
                "username": email,
                "password": password,
            },
            # we need to change the client's headers content-type
            headers={"content-type": "application/x-www-form-urlencoded"},
        )
        json_response: dict[str, str] = response.json()
        token: str = json_response["access_token"]
        return token

    return _get_logged_in_user_token

What happens in the fixture? We need to be able to pass in some arguments (I haven’t added full name for now, was lazy), so we have an inner function that we return. In that inner function, we create the user object. We log in. We return the access_token. Cool!

Now that we have this amazingly named return_logged_in_user_bearer_token fixture, let’s use it in our test_read_users_me test. First, the signature:

@pytest.mark.asyncio
async def test_read_users_me(
    client: AsyncClient,
    app: FastAPI,
    return_logged_in_user_bearer_token: Callable[..., str],
    clean_users_table: Callable[[None], None],
):

These fixtures should be familiar by now: the client, the app, our new fixture, and the one that clears our users table at the end.

The setup for the test, and the request itself:

    email: str = "free@cnlearn.app"
    password: str = "paid"
    token: str = await return_logged_in_user_bearer_token(email=email, password=password)
    headers = {"Authorization": f"Bearer {token}"}
    me_url: str = app.url_path_for("user:me")
    response: Response = await client.get(
        url=me_url,
        headers=headers,
    )
    json_response: dict[str, Any] = response.json()

Note that we are passing in the Bearer authorization header. Shall we assert various things about the response? (you know, the actual test…)

    assert response.status_code == 200
    assert json_response["email"] == email
    assert json_response["is_active"]
    assert not json_response["is_superuser"]
    assert json_response["full_name"] is None
    assert isinstance(json_response["id"], int)

We assert the response is 200 OK, the email is as expected, that the user is active, that they’re not a superuser, that full_name is None (as we didn’t set any) and that the id is an integer. I’m not testing the id value because I don’t care and I don’t want flaky tests.

Finally, let’s repeat the test but let it fail (i.e. let’s try and crack into our application).

@pytest.mark.asyncio
async def test_read_users_me_fail(
    client: AsyncClient,
    app: FastAPI,
):
    me_url: str = app.url_path_for("user:me")
    response: Response = await client.get(
        url=me_url,
    )
    assert response.status_code == 401
    json_response: dict[str, Any] = response.json()
    assert json_response == {'detail': 'Not authenticated'}

Luckily, we fail: we get a 401 Unauthorized response telling us that we are not authenticated.

I think that’s it for this post. We’ve done quite a bit. Now I need to proofread it (jk), check for mistakes (yea right) and push the code to the repo. This time, I actually created it in a branch and I will create a PR to the main branch (pretending like I have good git practice, you know?)

The commit for this post is here

In the next post we will start moving/rewriting some of the dictionary/search logic from the “app” version (yes I haven’t forgotten about it) from here to the web version. We will first implement endpoints for the two models (Word and Character) and then implement a search endpoint as well. One at a time :)