1use base64ct::{Base64UrlUnpadded, Encoding};
8use chrono::{DateTime, Utc};
9use crc::{CRC_32_ISO_HDLC, Crc};
10use mas_iana::oauth::OAuthTokenTypeHint;
11use rand::{Rng, RngCore, distributions::Alphanumeric};
12use thiserror::Error;
13use ulid::Ulid;
14
15use crate::InvalidTransitionError;
16
/// Lifecycle state of an [`AccessToken`].
///
/// New tokens start out [`Valid`](Self::Valid) (the [`Default`]); the
/// `revoke` transition moves a valid token to [`Revoked`](Self::Revoked).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum AccessTokenState {
    /// Initial state: the token has not been revoked.
    #[default]
    Valid,
    /// The token was revoked at `revoked_at`.
    Revoked {
        revoked_at: DateTime<Utc>,
    },
}
25
26impl AccessTokenState {
27 fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
28 match self {
29 Self::Valid => Ok(Self::Revoked { revoked_at }),
30 Self::Revoked { .. } => Err(InvalidTransitionError),
31 }
32 }
33
34 #[must_use]
38 pub fn is_valid(&self) -> bool {
39 matches!(self, Self::Valid)
40 }
41
42 #[must_use]
46 pub fn is_revoked(&self) -> bool {
47 matches!(self, Self::Revoked { .. })
48 }
49}
50
/// An access token together with its lifecycle state and timestamps.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccessToken {
    /// Unique ID of the token; its string form is returned by `jti()`.
    pub id: Ulid,
    /// Current lifecycle state (valid or revoked).
    pub state: AccessTokenState,
    /// ID of the session this token belongs to (presumably the owning OAuth
    /// session — confirm against the storage layer).
    pub session_id: Ulid,
    /// The token string itself.
    pub access_token: String,
    /// When the token was created.
    pub created_at: DateTime<Utc>,
    /// When the token expires; `None` means it never expires (see
    /// `is_expired`).
    pub expires_at: Option<DateTime<Utc>>,
    /// When the token was first used, or `None` if it has not been used yet.
    pub first_used_at: Option<DateTime<Utc>>,
}
61
62impl AccessToken {
63 #[must_use]
64 pub fn jti(&self) -> String {
65 self.id.to_string()
66 }
67
68 #[must_use]
74 pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
75 self.state.is_valid() && !self.is_expired(now)
76 }
77
78 #[must_use]
86 pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
87 match self.expires_at {
88 Some(expires_at) => expires_at < now,
89 None => false,
90 }
91 }
92
93 #[must_use]
95 pub fn is_used(&self) -> bool {
96 self.first_used_at.is_some()
97 }
98
99 pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
109 self.state = self.state.revoke(revoked_at)?;
110 Ok(self)
111 }
112}
113
/// Lifecycle state of a [`RefreshToken`].
///
/// New tokens start out [`Valid`](Self::Valid) (the [`Default`]). In this
/// type's transitions, a valid token can be consumed or revoked, a consumed
/// token can be consumed again but not revoked, and a revoked token accepts no
/// further transitions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RefreshTokenState {
    /// Initial state: neither consumed nor revoked.
    #[default]
    Valid,
    /// The token was consumed at `consumed_at`. `next_refresh_token_id` holds
    /// the ID of the replacing refresh token (`Some` whenever this state was
    /// produced by `consume`).
    Consumed {
        consumed_at: DateTime<Utc>,
        next_refresh_token_id: Option<Ulid>,
    },
    /// The token was revoked at `revoked_at`.
    Revoked {
        revoked_at: DateTime<Utc>,
    },
}
126
127impl RefreshTokenState {
128 fn consume(
134 self,
135 consumed_at: DateTime<Utc>,
136 replaced_by: &RefreshToken,
137 ) -> Result<Self, InvalidTransitionError> {
138 match self {
139 Self::Valid | Self::Consumed { .. } => Ok(Self::Consumed {
140 consumed_at,
141 next_refresh_token_id: Some(replaced_by.id),
142 }),
143 Self::Revoked { .. } => Err(InvalidTransitionError),
144 }
145 }
146
147 pub fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
153 match self {
154 Self::Valid => Ok(Self::Revoked { revoked_at }),
155 Self::Consumed { .. } | Self::Revoked { .. } => Err(InvalidTransitionError),
156 }
157 }
158
159 #[must_use]
163 pub fn is_valid(&self) -> bool {
164 matches!(self, Self::Valid)
165 }
166
167 #[must_use]
169 pub fn next_refresh_token_id(&self) -> Option<Ulid> {
170 match self {
171 Self::Valid | Self::Revoked { .. } => None,
172 Self::Consumed {
173 next_refresh_token_id,
174 ..
175 } => *next_refresh_token_id,
176 }
177 }
178}
179
/// A refresh token together with its lifecycle state.
///
/// Dereferences to [`RefreshTokenState`], so state helpers such as
/// `is_valid()` can be called on it directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RefreshToken {
    /// Unique ID of the token; its string form is returned by `jti()`.
    pub id: Ulid,
    /// Current lifecycle state (valid, consumed or revoked).
    pub state: RefreshTokenState,
    /// The token string itself.
    pub refresh_token: String,
    /// ID of the session this token belongs to.
    pub session_id: Ulid,
    /// When the token was created.
    pub created_at: DateTime<Utc>,
    /// ID of an associated access token, if any (presumably the one issued
    /// alongside this refresh token — confirm).
    pub access_token_id: Option<Ulid>,
}
189
190impl std::ops::Deref for RefreshToken {
191 type Target = RefreshTokenState;
192
193 fn deref(&self) -> &Self::Target {
194 &self.state
195 }
196}
197
198impl RefreshToken {
199 #[must_use]
200 pub fn jti(&self) -> String {
201 self.id.to_string()
202 }
203
204 pub fn consume(
210 mut self,
211 consumed_at: DateTime<Utc>,
212 replaced_by: &Self,
213 ) -> Result<Self, InvalidTransitionError> {
214 self.state = self.state.consume(consumed_at, replaced_by)?;
215 Ok(self)
216 }
217
218 pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
224 self.state = self.state.revoke(revoked_at)?;
225 Ok(self)
226 }
227}
228
/// The kinds of token this module can generate and recognise.
///
/// Each kind has a distinctive three-letter prefix, which [`TokenType::check`]
/// uses to classify a token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// Access token, prefix `mat`.
    AccessToken,

    /// Refresh token, prefix `mar`.
    RefreshToken,

    /// Compat access token, prefix `mct`. Legacy `syt_` tokens and tokens
    /// that look like Synapse macaroons are also classified as this kind.
    CompatAccessToken,

    /// Compat refresh token, prefix `mcr`. Legacy `syr_` tokens are also
    /// classified as this kind.
    CompatRefreshToken,

    /// Personal access token, prefix `mpt`.
    PersonalAccessToken,
}
247
248impl std::fmt::Display for TokenType {
249 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250 match self {
251 TokenType::AccessToken => write!(f, "access token"),
252 TokenType::RefreshToken => write!(f, "refresh token"),
253 TokenType::CompatAccessToken => write!(f, "compat access token"),
254 TokenType::CompatRefreshToken => write!(f, "compat refresh token"),
255 TokenType::PersonalAccessToken => write!(f, "personal access token"),
256 }
257 }
258}
259
260impl TokenType {
261 fn prefix(self) -> &'static str {
262 match self {
263 TokenType::AccessToken => "mat",
264 TokenType::RefreshToken => "mar",
265 TokenType::CompatAccessToken => "mct",
266 TokenType::CompatRefreshToken => "mcr",
267 TokenType::PersonalAccessToken => "mpt",
268 }
269 }
270
271 fn match_prefix(prefix: &str) -> Option<Self> {
272 match prefix {
273 "mat" => Some(TokenType::AccessToken),
274 "mar" => Some(TokenType::RefreshToken),
275 "mct" | "syt" => Some(TokenType::CompatAccessToken),
276 "mcr" | "syr" => Some(TokenType::CompatRefreshToken),
277 "mpt" => Some(TokenType::PersonalAccessToken),
278 _ => None,
279 }
280 }
281
282 pub fn generate(self, rng: &mut (impl RngCore + ?Sized)) -> String {
284 let random_part: String = rng
285 .sample_iter(&Alphanumeric)
286 .take(30)
287 .map(char::from)
288 .collect();
289
290 let base = format!("{prefix}_{random_part}", prefix = self.prefix());
291 let crc = CRC.checksum(base.as_bytes());
292 let crc = base62_encode(crc);
293 format!("{base}_{crc}")
294 }
295
296 pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
302 if token.starts_with("syt_") || is_likely_synapse_macaroon(token) {
305 return Ok(TokenType::CompatAccessToken);
306 }
307 if token.starts_with("syr_") {
308 return Ok(TokenType::CompatRefreshToken);
309 }
310
311 let split: Vec<&str> = token.split('_').collect();
312 let [prefix, random_part, crc]: [&str; 3] = split
313 .try_into()
314 .map_err(|_| TokenFormatError::InvalidFormat)?;
315
316 if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
317 return Err(TokenFormatError::InvalidFormat);
318 }
319
320 let token_type =
321 TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
322 prefix: prefix.to_owned(),
323 })?;
324
325 let base = format!("{prefix}_{random_part}", prefix = token_type.prefix());
326 let expected_crc = CRC.checksum(base.as_bytes());
327 let expected_crc = base62_encode(expected_crc);
328 if crc != expected_crc {
329 return Err(TokenFormatError::InvalidCrc {
330 expected: expected_crc,
331 got: crc.to_owned(),
332 });
333 }
334
335 Ok(token_type)
336 }
337}
338
339impl PartialEq<OAuthTokenTypeHint> for TokenType {
340 fn eq(&self, other: &OAuthTokenTypeHint) -> bool {
341 matches!(
342 (self, other),
343 (
344 TokenType::AccessToken
345 | TokenType::CompatAccessToken
346 | TokenType::PersonalAccessToken,
347 OAuthTokenTypeHint::AccessToken
348 ) | (
349 TokenType::RefreshToken | TokenType::CompatRefreshToken,
350 OAuthTokenTypeHint::RefreshToken
351 )
352 )
353 }
354}
355
356fn is_likely_synapse_macaroon(token: &str) -> bool {
364 let Ok(decoded) = Base64UrlUnpadded::decode_vec(token) else {
365 return false;
366 };
367 decoded.get(4..13) == Some(b"location ")
368}
369
/// Alphabet for [`base62_encode`]: digits, then `A`-`Z`, then `a`-`z`.
const NUM: [u8; 62] = *b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";

/// Encode `num` as a six-character base62 string, as used for token checksums.
///
/// Six base62 digits always suffice for a `u32`, so the result is always
/// exactly six characters long.
///
/// NOTE: digits are emitted least-significant first, and the result is then
/// left-padded with `'0'`. This is not a conventional positional encoding, and
/// it is not injective: `1` and `62` both encode to `"000001"`. Tokens are
/// generated and checked with this exact encoding, so changing it would make
/// `TokenType::check` reject every token generated by the current code.
fn base62_encode(num: u32) -> String {
    let digits: String = std::iter::successors(Some(num), |&rest| Some(rest / 62))
        .take_while(|&rest| rest > 0)
        .map(|rest| char::from(NUM[(rest % 62) as usize]))
        .collect();

    format!("{digits:0>6}")
}
381
/// CRC-32 (ISO-HDLC parameters) used to checksum the `<prefix>_<random>` part
/// of generated tokens.
const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
383
/// Reasons [`TokenType::check`] can reject a token.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum TokenFormatError {
    /// The token is not three `_`-separated parts of 3, 30 and 6 characters.
    #[error("invalid token format")]
    InvalidFormat,

    /// The token has the expected layout but an unrecognised prefix.
    #[error("unknown token prefix {prefix:?}")]
    UnknownPrefix {
        /// The prefix found in the token.
        prefix: String,
    },

    /// The checksum at the end of the token does not match its contents.
    #[error("invalid crc {got:?}, expected {expected:?}")]
    InvalidCrc {
        /// The checksum computed from the token's prefix and random part.
        expected: String,
        /// The checksum found in the token.
        got: String,
    },
}
407
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rand::thread_rng;

    use super::*;

    /// Every token type, so table-style tests cannot silently skip one.
    const ALL_TOKEN_TYPES: [TokenType; 5] = [
        TokenType::AccessToken,
        TokenType::RefreshToken,
        TokenType::CompatAccessToken,
        TokenType::CompatRefreshToken,
        TokenType::PersonalAccessToken,
    ];

    #[test]
    fn test_prefix_match() {
        use TokenType::{
            AccessToken, CompatAccessToken, CompatRefreshToken, PersonalAccessToken, RefreshToken,
        };
        assert_eq!(TokenType::match_prefix("syt"), Some(CompatAccessToken));
        assert_eq!(TokenType::match_prefix("syr"), Some(CompatRefreshToken));
        assert_eq!(TokenType::match_prefix("mct"), Some(CompatAccessToken));
        assert_eq!(TokenType::match_prefix("mcr"), Some(CompatRefreshToken));
        assert_eq!(TokenType::match_prefix("mat"), Some(AccessToken));
        assert_eq!(TokenType::match_prefix("mar"), Some(RefreshToken));
        assert_eq!(TokenType::match_prefix("mpt"), Some(PersonalAccessToken));
        assert_eq!(TokenType::match_prefix("matt"), None);
        assert_eq!(TokenType::match_prefix("marr"), None);
        assert_eq!(TokenType::match_prefix("ma"), None);

        // Every canonical prefix must map back to the type that emits it.
        for t in ALL_TOKEN_TYPES {
            assert_eq!(TokenType::match_prefix(t.prefix()), Some(t));
        }
    }

    #[test]
    fn test_is_likely_synapse_macaroon() {
        // Decodes to "001blocation librepush.net\n001".
        assert!(is_likely_synapse_macaroon(
            "MDAxYmxvY2F0aW9uIGxpYnJlcHVzaC5uZXQKMDAx"
        ));

        // Decodes to "001clocation http://mybank/\n0026identifier ...".
        assert!(is_likely_synapse_macaroon(
            "MDAxY2xvY2F0aW9uIGh0dHA6Ly9teWJhbmsvCjAwMjZpZGVudGlmaWVyIHdlIHVzZWQgb3VyIHNlY3JldCBrZXkKMDAyZnNpZ25hdHVyZSDj2eApCFJsTAA5rhURQRXZf91ovyujebNCqvD2F9BVLwo"
        ));

        // Not valid URL-safe base64 (contains '.').
        assert!(!is_likely_synapse_macaroon(
            "eyJARTOhearotnaeisahtoarsnhiasra.arsohenaor.oarnsteao"
        ));
        assert!(!is_likely_synapse_macaroon("...."));
        assert!(!is_likely_synapse_macaroon("aaa"));
    }

    #[test]
    fn test_generate_and_check() {
        const COUNT: usize = 500;

        // `thread_rng` is flagged by the project's `disallowed_methods` lint
        // (presumably so production code takes an injected RNG); a test may
        // use it.
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        for t in ALL_TOKEN_TYPES {
            let tokens: HashSet<String> = (0..COUNT).map(|_| t.generate(&mut rng)).collect();

            assert_eq!(tokens.len(), COUNT, "All tokens are unique");

            for token in tokens {
                assert_eq!(TokenType::check(&token).unwrap(), t);
            }
        }
    }

    #[test]
    fn test_check_rejects_malformed_tokens() {
        // See `test_generate_and_check` for why this lint is allowed here.
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        // Wrong number of `_`-separated parts, or parts of the wrong length.
        assert_eq!(TokenType::check(""), Err(TokenFormatError::InvalidFormat));
        assert_eq!(
            TokenType::check("mat_abc"),
            Err(TokenFormatError::InvalidFormat)
        );
        assert_eq!(
            TokenType::check("mat_abc_def"),
            Err(TokenFormatError::InvalidFormat)
        );

        let token = TokenType::AccessToken.generate(&mut rng);

        // Well-formed layout, unknown prefix.
        let (_, rest) = token.split_once('_').unwrap();
        assert_eq!(
            TokenType::check(&format!("xyz_{rest}")),
            Err(TokenFormatError::UnknownPrefix {
                prefix: "xyz".to_owned(),
            })
        );

        // Known prefix, tampered checksum.
        let (base, crc) = token.rsplit_once('_').unwrap();
        let wrong_crc = if crc == "000000" { "000001" } else { "000000" };
        assert_eq!(
            TokenType::check(&format!("{base}_{wrong_crc}")),
            Err(TokenFormatError::InvalidCrc {
                expected: crc.to_owned(),
                got: wrong_crc.to_owned(),
            })
        );
    }

    #[test]
    fn test_check_accepts_legacy_synapse_prefixes() {
        assert_eq!(
            TokenType::check("syt_anything"),
            Ok(TokenType::CompatAccessToken)
        );
        assert_eq!(
            TokenType::check("syr_anything"),
            Ok(TokenType::CompatRefreshToken)
        );
    }

    #[test]
    fn test_base62_encode() {
        assert_eq!(base62_encode(0), "000000");
        assert_eq!(base62_encode(61), "00000z");
        assert_eq!(base62_encode(u32::MAX), "3CFfg4");
        // Digits are emitted least-significant first and then left-padded, so
        // the encoding is not injective. Pin it so the token checksum format
        // cannot change by accident.
        assert_eq!(base62_encode(1), "000001");
        assert_eq!(base62_encode(62), "000001");
    }

    #[test]
    fn test_token_type_matches_oauth_hint() {
        assert!(TokenType::AccessToken == OAuthTokenTypeHint::AccessToken);
        assert!(TokenType::CompatAccessToken == OAuthTokenTypeHint::AccessToken);
        assert!(TokenType::PersonalAccessToken == OAuthTokenTypeHint::AccessToken);
        assert!(TokenType::PersonalAccessToken != OAuthTokenTypeHint::RefreshToken);
        assert!(TokenType::RefreshToken == OAuthTokenTypeHint::RefreshToken);
        assert!(TokenType::CompatRefreshToken == OAuthTokenTypeHint::RefreshToken);
        assert!(TokenType::CompatRefreshToken != OAuthTokenTypeHint::AccessToken);
    }
}