mas_storage_pg/personal/
access_token.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4// Please see LICENSE files in the repository root for full details.
5
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use mas_data_model::{
9    Clock,
10    personal::{PersonalAccessToken, session::PersonalSession},
11};
12use mas_storage::personal::PersonalAccessTokenRepository;
13use rand::RngCore;
14use sha2::{Digest, Sha256};
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, tracing::ExecuteExt as _};
20
21/// An implementation of [`PersonalAccessTokenRepository`] for a PostgreSQL
22/// connection
23pub struct PgPersonalAccessTokenRepository<'c> {
24    conn: &'c mut PgConnection,
25}
26
27impl<'c> PgPersonalAccessTokenRepository<'c> {
28    /// Create a new [`PgPersonalAccessTokenRepository`] from an active
29    /// PostgreSQL connection
30    pub fn new(conn: &'c mut PgConnection) -> Self {
31        Self { conn }
32    }
33}
34
35struct PersonalAccessTokenLookup {
36    personal_access_token_id: Uuid,
37    personal_session_id: Uuid,
38    created_at: DateTime<Utc>,
39    expires_at: Option<DateTime<Utc>>,
40    revoked_at: Option<DateTime<Utc>>,
41}
42
43impl From<PersonalAccessTokenLookup> for PersonalAccessToken {
44    fn from(value: PersonalAccessTokenLookup) -> Self {
45        Self {
46            id: Ulid::from(value.personal_access_token_id),
47            session_id: Ulid::from(value.personal_session_id),
48            created_at: value.created_at,
49            expires_at: value.expires_at,
50            revoked_at: value.revoked_at,
51        }
52    }
53}
54
55#[async_trait]
56impl PersonalAccessTokenRepository for PgPersonalAccessTokenRepository<'_> {
57    type Error = DatabaseError;
58
59    #[tracing::instrument(
60        name = "db.personal_access_token.lookup",
61        skip_all,
62        fields(
63            db.query.text,
64            personal_access_token.id = %id,
65        ),
66        err,
67    )]
68    async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalAccessToken>, Self::Error> {
69        let res = sqlx::query_as!(
70            PersonalAccessTokenLookup,
71            r#"
72                SELECT personal_access_token_id
73                     , personal_session_id
74                     , created_at
75                     , expires_at
76                     , revoked_at
77
78                FROM personal_access_tokens
79
80                WHERE personal_access_token_id = $1
81            "#,
82            Uuid::from(id),
83        )
84        .traced()
85        .fetch_optional(&mut *self.conn)
86        .await?;
87
88        let Some(res) = res else { return Ok(None) };
89
90        Ok(Some(res.into()))
91    }
92
93    #[tracing::instrument(
94        name = "db.personal_access_token.find_by_token",
95        skip_all,
96        fields(
97            db.query.text,
98        ),
99        err,
100    )]
101    async fn find_by_token(
102        &mut self,
103        access_token: &str,
104    ) -> Result<Option<PersonalAccessToken>, Self::Error> {
105        let token_sha256 = Sha256::digest(access_token.as_bytes()).to_vec();
106
107        let res = sqlx::query_as!(
108            PersonalAccessTokenLookup,
109            r#"
110                SELECT personal_access_token_id
111                     , personal_session_id
112                     , created_at
113                     , expires_at
114                     , revoked_at
115
116                FROM personal_access_tokens
117
118                WHERE access_token_sha256 = $1
119            "#,
120            &token_sha256,
121        )
122        .traced()
123        .fetch_optional(&mut *self.conn)
124        .await?;
125
126        let Some(res) = res else { return Ok(None) };
127
128        Ok(Some(res.into()))
129    }
130
131    #[tracing::instrument(
132        name = "db.personal_access_token.add",
133        skip_all,
134        fields(
135            db.query.text,
136            personal_access_token.id,
137            %session.id,
138        ),
139        err,
140    )]
141    async fn add(
142        &mut self,
143        rng: &mut (dyn RngCore + Send),
144        clock: &dyn Clock,
145        session: &PersonalSession,
146        access_token: &str,
147        expires_after: Option<chrono::Duration>,
148    ) -> Result<PersonalAccessToken, Self::Error> {
149        let created_at = clock.now();
150        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
151        tracing::Span::current().record("personal_access_token.id", tracing::field::display(id));
152
153        let token_sha256 = Sha256::digest(access_token.as_bytes()).to_vec();
154
155        let expires_at = expires_after.map(|expires_after| created_at + expires_after);
156
157        sqlx::query!(
158            r#"
159                INSERT INTO personal_access_tokens
160                    (personal_access_token_id, personal_session_id, access_token_sha256, created_at, expires_at)
161                VALUES ($1, $2, $3, $4, $5)
162            "#,
163            Uuid::from(id),
164            Uuid::from(session.id),
165            &token_sha256,
166            created_at,
167            expires_at,
168        )
169        .traced()
170        .execute(&mut *self.conn)
171        .await?;
172
173        Ok(PersonalAccessToken {
174            id,
175            session_id: session.id,
176            created_at,
177            expires_at,
178            revoked_at: None,
179        })
180    }
181
182    #[tracing::instrument(
183        name = "db.personal_access_token.revoke",
184        skip_all,
185        fields(
186            db.query.text,
187            %access_token.id,
188            personal_session.id = %access_token.session_id,
189        ),
190        err,
191    )]
192    async fn revoke(
193        &mut self,
194        clock: &dyn Clock,
195        mut access_token: PersonalAccessToken,
196    ) -> Result<PersonalAccessToken, Self::Error> {
197        let revoked_at = clock.now();
198        let res = sqlx::query!(
199            r#"
200                UPDATE personal_access_tokens
201                SET revoked_at = $2
202                WHERE personal_access_token_id = $1
203            "#,
204            Uuid::from(access_token.id),
205            revoked_at,
206        )
207        .traced()
208        .execute(&mut *self.conn)
209        .await?;
210
211        DatabaseError::ensure_affected_rows(&res, 1)?;
212
213        access_token.revoked_at = Some(revoked_at);
214        Ok(access_token)
215    }
216}