{-# OPTIONS_GHC -fno-warn-orphans #-}

module Mensam.Server.Server.Auth where

import Mensam.API.Aeson.StaticText
import Mensam.API.Data.User
import Mensam.API.Data.User.Username
import Mensam.Server.Application.SeldaPool.Class
import Mensam.Server.User

import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.Logger.CallStack
import Crypto.JOSE.JWK qualified as JOSE
import Data.Kind
import Data.Password.Bcrypt
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Servant hiding (BasicAuthResult (..))
import Servant.Auth.JWT.WithSession
import Servant.Auth.Server

-- | Dispatch on the outcome of Basic authentication.
--
-- On 'Authenticated' the wrapped value is passed to the given handler.
-- Every other outcome is answered with an HTTP 401 response whose body is the
-- 'ErrorBasicAuth' constructor matching the failure:
--
-- * 'BadPassword' -> 'MkErrorBasicAuthPassword'
-- * 'NoSuchUser'  -> 'MkErrorBasicAuthUsername'
-- * 'Indefinite'  -> 'MkErrorBasicAuthIndefinite'
handleAuthBasic ::
  ( MonadLogger m
  , IsMember (WithStatus 401 ErrorBasicAuth) responses
  ) =>
  AuthResult a ->
  (a -> m (Union responses)) ->
  m (Union responses)
handleAuthBasic authResult handler = do
  logDebug "Handling result of Basic authentication."
  case authResult of
    Authenticated authenticated -> do
      logInfo "Starting handler after successful authentication."
      handler authenticated
    BadPassword -> do
      logInfo "Can't access handler, because authentication failed due to wrong password."
      respond $ WithStatus @401 MkErrorBasicAuthPassword
    NoSuchUser -> do
      logInfo "Can't access handler, because authentication failed due to not existing username."
      respond $ WithStatus @401 MkErrorBasicAuthUsername
    Indefinite -> do
      logInfo "Can't access handler, because authentication failed for some reason."
      respond $ WithStatus @401 MkErrorBasicAuthIndefinite

-- | Dispatch on the outcome of Bearer authentication.
--
-- On 'Authenticated' the wrapped value is passed to the given handler.
-- Every other outcome is answered with an HTTP 401 response carrying
-- @'MkErrorBearerAuth' ('MkStaticText' \@"indefinite")@.
--
-- 'BadPassword' and 'NoSuchUser' are treated as unexpected for Bearer
-- authentication: they are logged as errors, but still answered with the
-- same 401 response.
handleAuthBearer ::
  ( MonadLogger m
  , IsMember (WithStatus 401 ErrorBearerAuth) responses
  ) =>
  AuthResult a ->
  (a -> m (Union responses)) ->
  m (Union responses)
handleAuthBearer authResult handler = do
  logDebug "Handling result of Bearer authentication."
  case authResult of
    Authenticated authenticated -> do
      logInfo "Starting handler after successful authentication."
      handler authenticated
    BadPassword -> do
      -- The message names the case that was actually hit, so the log can be
      -- told apart from the 'NoSuchUser' branch below.
      logError "Didn't expect to handle BadPassword in Bearer authentication."
      logWarn "Returning a HTTP 401 response even though this case was unexpected."
      respond $ WithStatus @401 $ MkErrorBearerAuth $ MkStaticText @"indefinite"
    NoSuchUser -> do
      logError "Didn't expect to handle NoSuchUser in Bearer authentication."
      logWarn "Returning a HTTP 401 response even though this case was unexpected."
      respond $ WithStatus @401 $ MkErrorBearerAuth $ MkStaticText @"indefinite"
    Indefinite -> do
      logInfo "Can't access handler, because authentication failed for some reason."
      respond $ WithStatus @401 $ MkErrorBearerAuth $ MkStaticText @"indefinite"

-- JWT decoding and encoding for 'UserAuthenticated', derived with the
-- @anyclass@ strategy (i.e. using the classes' default method definitions).
-- NOTE(review): neither the classes nor 'UserAuthenticated' are defined in
-- this module, so these are presumably the orphan instances that the
-- @-fno-warn-orphans@ pragma at the top of the file is for -- confirm.
deriving anyclass instance FromJWT UserAuthenticated
deriving anyclass instance ToJWT UserAuthenticated

-- | Session check performed for an already decoded 'UserAuthenticated'.
--
-- The whole check runs in the monad hidden inside 'MkRunLoginInIO' and is
-- brought down to 'IO' with its 'runLoginInIO' field.
instance WithSession UserAuthenticated where
  -- Without a session identifier the user is accepted as-is.
  -- With one, the session is validated in a database transaction:
  --   * valid session      -> 'Authenticated'
  --   * invalid session    -> 'Indefinite'
  --   * transaction failed -> the underlying exception is rethrown
  validateSession MkRunLoginInIO {runLoginInIO} authenticated@MkUserAuthenticated {userAuthenticatedSession} =
    runLoginInIO $
      case userAuthenticatedSession of
        Nothing -> do
          logDebug "Authenticating user without session."
          pure $ Authenticated authenticated
        Just sessionIdentifier -> do
          logInfo "Validating session before authenticating user."
          seldaValidationResult <- runSeldaTransactionT $ userSessionValidate sessionIdentifier
          case seldaValidationResult of
            SeldaFailure err -> do
              -- This case should not occur under normal circumstances.
              -- The transaction in this case is just a read transaction.
              logError "Session validation failed because of a database error."
              throwM err
            SeldaSuccess validity ->
              case validity of
                SessionInvalid -> pure Indefinite
                SessionValid -> pure $ Authenticated authenticated

-- | Username/password login from HTTP Basic authentication data.
--
-- The whole login runs in the monad hidden inside 'MkRunLoginInIO' and is
-- brought down to 'IO' with its 'runLoginInIO' field.
instance FromBasicAuthData UserAuthenticated where
  -- Outcomes:
  --   * username is not valid UTF-8   -> 'NoSuchUser'
  --   * password is not valid UTF-8   -> 'BadPassword'
  --   * user does not exist           -> 'NoSuchUser'
  --   * wrong password                -> 'BadPassword'
  --   * credentials accepted          -> 'Authenticated'
  --   * database transaction failed   -> the underlying exception is rethrown
  fromBasicAuthData BasicAuthData {basicAuthUsername, basicAuthPassword} MkRunLoginInIO {runLoginInIO} =
    runLoginInIO $ do
      logInfo "Starting password authentication."
      logDebug $ "Decoding UTF-8 username: " <> T.pack (show basicAuthUsername)
      case T.decodeUtf8' basicAuthUsername of
        Left err -> do
          logInfo $ "Failed to decode username as UTF-8: " <> T.pack (show err)
          pure NoSuchUser
        Right usernameText -> do
          logDebug $ "Decoded UTF-8 username: " <> T.pack (show usernameText)
          -- NOTE(review): 'MkUsernameUnsafe' presumably skips username
          -- validation; the lookup below then decides whether the user
          -- exists -- confirm.
          let username = MkUsernameUnsafe usernameText
          logDebug "Decoding UTF-8 password."
          -- The password itself is never logged, and neither is its decoding error.
          case mkPassword <$> T.decodeUtf8' basicAuthPassword of
            Left _err -> do
              logInfo "Failed to decode password as UTF-8."
              pure BadPassword
            Right password -> do
              logDebug "Decoded UTF-8 password."
              seldaAuthResult <- runSeldaTransactionT $ userAuthenticate username password
              case seldaAuthResult of
                SeldaFailure err -> do
                  -- This case should not occur under normal circumstances.
                  -- The transaction in this case is just a read transaction.
                  logError "Authentication failed because of a database error."
                  throwM err
                SeldaSuccess authResult ->
                  case authResult of
                    Left AuthenticationErrorUserDoesNotExist -> pure NoSuchUser
                    Left AuthenticationErrorWrongPassword -> pure BadPassword
                    Right authenticated -> pure $ Authenticated authenticated

-- The configuration handed to 'fromBasicAuthData' is a 'RunLoginInIO'.
type instance BasicAuthCfg = RunLoginInIO

-- The configuration handed to 'validateSession' is a 'RunLoginInIO' as well.
type instance SessionCfg = RunLoginInIO

-- | Existential wrapper around a function that runs an authentication action
-- in 'IO'.
--
-- The monad @m@ is hidden; users of the wrapper only know that it supports
-- 'MonadIO', 'MonadLogger', 'MonadSeldaPool' and 'MonadThrow'.
-- Because @m@ is existentially quantified, 'runLoginInIO' is only reachable
-- by pattern matching on 'MkRunLoginInIO', not as a plain selector function.
type RunLoginInIO :: Type
data RunLoginInIO = forall m.
  (MonadIO m, MonadLogger m, MonadSeldaPool m, MonadThrow m) =>
  MkRunLoginInIO {runLoginInIO :: m (AuthResult UserAuthenticated) -> IO (AuthResult UserAuthenticated)}

-- | Build the 'JWTSettings' for the given key.
-- This is exactly 'defaultJWTSettings'; nothing is overridden.
mkJwtSettings :: JOSE.JWK -> JWTSettings
mkJwtSettings = defaultJWTSettings

-- | Cookie settings used by the server.
-- This is exactly 'defaultCookieSettings'; nothing is overridden.
cookieSettings :: CookieSettings
cookieSettings = defaultCookieSettings