{-# LANGUAGE OverloadedLabels #-}

module Mensam.Server.Secrets where

import Mensam.Server.Application.SeldaPool.Class
import Mensam.Server.Database.Schema

import Control.Monad.IO.Class
import Control.Monad.Logger.CallStack
import Control.Monad.Trans.Class
import Crypto.JOSE.JWK qualified as JOSE
import Data.Kind
import Data.Text qualified as T
import Data.Time qualified as T
import Database.Selda qualified as Selda
import Servant.Auth.Server

-- | Secret material held by the server.
type Secrets :: Type
newtype Secrets = MkSecrets
  { secretsJwk :: JOSE.JWK
  -- ^ The server's JSON Web Key (presumably used for signing/verifying
  -- auth tokens — NOTE(review): confirm against callers).
  }

-- | Fetch the JWK stored under the highest @dbJwk_id@ in 'tableJwk'
-- (presumably the most recently inserted one).
--
-- The stored secret bytes ('dbJwk_jwk') are turned into a 'JOSE.JWK' with
-- 'fromSecret'. Returns 'Nothing' when the table contains no rows.
--
-- Runs inside an existing Selda transaction; log messages are emitted in the
-- underlying monad via 'lift'.
jwkGetLatest ::
  ( MonadSeldaPool m
  , MonadLogger m
  ) =>
  SeldaTransactionT m (Maybe JOSE.JWK)
jwkGetLatest = do
  lift $ logDebug "Getting latest JWK."
  -- Only the newest row is ever used, so let the database apply both the
  -- ordering and a LIMIT of one instead of loading every stored key.
  dbJwks <- Selda.query $ Selda.limit 0 1 $ do
    dbJwk <- Selda.select tableJwk
    Selda.order (dbJwk Selda.! #dbJwk_id) Selda.Desc
    pure dbJwk
  case dbJwks of
    [] -> do
      lift $ logInfo "No JWK currently in database."
      pure Nothing
    dbJwk : _ -> do
      lift $ logDebug "Successfully got JWKs from database. Parsing the latest JWK."
      let jwk = fromSecret $ dbJwk_jwk dbJwk
      lift $ logDebug "Successfully parsed JWK."
      lift $ logInfo "Returning JWK."
      pure $ Just jwk

-- | Return the current JWK, creating and storing one first if none exists.
--
-- If 'jwkGetLatest' finds a key, that key is returned unchanged and nothing
-- is written (a warning is logged). Despite the name, an existing key is
-- never replaced.
--
-- Otherwise a new secret is produced with 'generateSecret'. It is inserted
-- into 'tableJwk' together with the current time, with the primary key left
-- to 'Selda.def'. The JWK derived from it via 'fromSecret' is returned.
jwkSetLatest ::
  ( MonadSeldaPool m
  , MonadLogger m
  ) =>
  SeldaTransactionT m JOSE.JWK
jwkSetLatest = do
  lift $ logDebug "Setting latest JWK."
  existing <- jwkGetLatest
  case existing of
    Nothing -> do
      lift $ logDebug "Generating a new JWK."
      freshSecret <- liftIO generateSecret
      now <- liftIO T.getCurrentTime
      newId <-
        Selda.insertWithPK
          tableJwk
          [ MkDbJwk
              { dbJwk_created = now
              , dbJwk_jwk = freshSecret
              , dbJwk_id = Selda.def
              }
          ]
      lift $ logInfo $ "Inserted new JWK: " <> T.pack (show newId)
      lift $ logInfo "Returning newly generated JWK."
      -- The stored bytes and the returned key come from the same secret.
      pure $ fromSecret freshSecret
    Just storedJwk -> do
      lift $ logWarn "JWK already exists. Skipping."
      lift $ logInfo "Returning an old JWK."
      pure storedJwk