{-# LANGUAGE OverloadedLabels #-}

-- | Server secrets (currently only the JSON Web Key) and their persistence in
-- the database table 'tableJwk'.
module Mensam.Server.Secrets where

import Mensam.Server.Application.SeldaPool.Class
import Mensam.Server.Database.Schema

import Control.Monad.IO.Class
import Control.Monad.Logger.CallStack
import Control.Monad.Trans.Class
import Crypto.JOSE.JWK qualified as JOSE
import Data.Kind
import Data.Text qualified as T
import Data.Time qualified as T
import Database.Selda qualified as Selda
import Servant.Auth.Server

-- | Secrets held by the server.
type Secrets :: Type
newtype Secrets = MkSecrets
  { secretsJwk :: JOSE.JWK
  -- ^ The server's JSON Web Key.
  -- NOTE(review): presumably used by servant-auth to sign/verify JWTs — confirm.
  }

-- | Look up the JWK stored in 'tableJwk' with the highest 'dbJwk_id'.
--
-- Returns 'Nothing' when the table is empty. Otherwise the stored secret bytes
-- of that row are turned into a key with 'fromSecret'.
--
-- NOTE(review): "highest id" is treated as "latest" here; this assumes ids
-- are assigned in increasing order on insert — confirm.
jwkGetLatest ::
  ( MonadSeldaPool m
  , MonadLogger m
  ) =>
  SeldaTransactionT m (Maybe JOSE.JWK)
jwkGetLatest = do
  lift $ logDebug "Getting latest JWK."
  dbJwks <- Selda.query $ do
    dbJwk <- Selda.select tableJwk
    Selda.order (dbJwk Selda.! #dbJwk_id) Selda.Desc
    pure dbJwk
  case dbJwks of
    [] -> do
      lift $ logInfo "No JWK currently in database."
      pure Nothing
    -- Rows are sorted by descending id, so the head is the one we want.
    dbJwk : _ -> do
      lift $ logDebug "Successfully got JWKs from database. Parsing the latest JWK."
      let jwk = fromSecret $ dbJwk_jwk dbJwk
      lift $ logDebug "Successfully parsed JWK."
      lift $ logInfo "Returning JWK."
      pure $ Just jwk

-- | Make sure a JWK exists and return it.
--
-- If 'jwkGetLatest' finds a key, that key is returned unchanged and a warning
-- is logged; an existing key is never replaced. Otherwise a fresh secret is
-- produced with 'generateSecret'. It is stored in 'tableJwk' together with
-- the current time, and the key derived from it via 'fromSecret' is returned.
--
-- NOTE(review): the lookup and the insert run in the same 'SeldaTransactionT';
-- whether this check-then-insert is atomic depends on how the transaction is
-- executed by the caller — confirm.
jwkSetLatest ::
  ( MonadSeldaPool m
  , MonadLogger m
  ) =>
  SeldaTransactionT m JOSE.JWK
jwkSetLatest = do
  lift $ logDebug "Setting latest JWK."
  maybeJwk <- jwkGetLatest
  case maybeJwk of
    Just jwk -> do
      lift $ logWarn "JWK already exists. Skipping."
      lift $ logInfo "Returning an old JWK."
      pure jwk
    Nothing -> do
      lift $ logDebug "Generating a new JWK."
      secret <- liftIO generateSecret
      let jwk = fromSecret secret
      currentTime <- liftIO T.getCurrentTime
      let dbJwk =
            MkDbJwk
              { -- 'Selda.def' leaves the primary key to the database;
                -- 'Selda.insertWithPK' reports the id that was assigned.
                dbJwk_id = Selda.def
              , dbJwk_jwk = secret
              , dbJwk_created = currentTime
              }
      identifier <- Selda.insertWithPK tableJwk [dbJwk]
      lift $ logInfo $ "Inserted new JWK: " <> T.pack (show identifier)
      lift $ logInfo "Returning newly generated JWK."
      pure jwk