1use aes_gcm::{
7 aead::{rand_core::RngCore, Aead, KeyInit, OsRng, Payload},
8 Aes256Gcm, Key, Nonce,
9};
10use serde::{Deserialize, Serialize};
11use std::env;
12use thiserror::Error;
13use zeroize::Zeroize;
14
15use crate::{
16 models::SecretString,
17 utils::{base64_decode, base64_encode, EncryptionContext},
18};
19
20#[derive(Error, Debug, Clone)]
21pub enum EncryptionError {
22 #[error("Encryption failed: {0}")]
23 EncryptionFailed(String),
24 #[error("Decryption failed: {0}")]
25 DecryptionFailed(String),
26 #[error("Key derivation failed: {0}")]
27 KeyDerivationFailed(String),
28 #[error("Invalid encrypted data format: {0}")]
29 InvalidFormat(String),
30 #[error("Missing encryption key environment variable: {0}")]
31 MissingKey(String),
32 #[error("Invalid key length: expected 32 bytes, got {0}")]
33 InvalidKeyLength(usize),
34 #[error("Missing AAD for v2 decryption")]
35 MissingAAD,
36 #[error("Unsupported encryption version: {0}")]
37 UnsupportedVersion(u8),
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct EncryptedData {
43 pub nonce: String,
45 pub ciphertext: String,
47 pub version: u8,
49}
50
51#[derive(Clone)]
53pub struct FieldEncryption {
54 cipher: Aes256Gcm,
55}
56
57impl FieldEncryption {
58 pub fn new() -> Result<Self, EncryptionError> {
64 let key = Self::load_key_from_env()?;
65 let cipher = Aes256Gcm::new(&key);
66 Ok(Self { cipher })
67 }
68
69 pub fn new_with_key(key: &[u8; 32]) -> Result<Self, EncryptionError> {
71 let key = Key::<Aes256Gcm>::from(*key);
72 let cipher = Aes256Gcm::new(&key);
73 Ok(Self { cipher })
74 }
75
76 fn load_key_from_env() -> Result<Key<Aes256Gcm>, EncryptionError> {
78 let key = env::var("STORAGE_ENCRYPTION_KEY")
79 .map(|v| SecretString::new(&v))
80 .map_err(|_| {
81 EncryptionError::MissingKey("STORAGE_ENCRYPTION_KEY must be set".to_string())
82 })?;
83
84 key.as_str(|key_b64| {
85 let mut key_bytes = base64_decode(key_b64)
86 .map_err(|e| EncryptionError::KeyDerivationFailed(e.to_string()))?;
87 if key_bytes.len() != 32 {
88 key_bytes.zeroize(); return Err(EncryptionError::InvalidKeyLength(key_bytes.len()));
90 }
91
92 let key_array: [u8; 32] = key_bytes
93 .as_slice()
94 .try_into()
95 .map_err(|_| EncryptionError::InvalidKeyLength(key_bytes.len()))?;
96 Ok(Key::<Aes256Gcm>::from(key_array))
97 })
98 }
99
100 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, EncryptionError> {
102 let mut nonce_bytes = [0u8; 12];
104 OsRng.fill_bytes(&mut nonce_bytes);
105 let nonce = &Nonce::from(nonce_bytes);
106
107 let ciphertext = self
109 .cipher
110 .encrypt(nonce, plaintext)
111 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
112
113 Ok(EncryptedData {
114 nonce: base64_encode(&nonce_bytes),
115 ciphertext: base64_encode(&ciphertext),
116 version: 1,
117 })
118 }
119
120 pub fn decrypt(&self, encrypted_data: &EncryptedData) -> Result<Vec<u8>, EncryptionError> {
122 if encrypted_data.version != 1 {
123 return Err(EncryptionError::InvalidFormat(format!(
124 "Unsupported encryption version: {}",
125 encrypted_data.version
126 )));
127 }
128
129 let nonce_bytes = base64_decode(&encrypted_data.nonce)
131 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid nonce: {e}")))?;
132
133 let ciphertext_bytes = base64_decode(&encrypted_data.ciphertext)
134 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid ciphertext: {e}")))?;
135
136 if nonce_bytes.len() != 12 {
137 return Err(EncryptionError::InvalidFormat(format!(
138 "Invalid nonce length: expected 12, got {}",
139 nonce_bytes.len()
140 )));
141 }
142
143 let nonce_array: [u8; 12] = nonce_bytes
144 .as_slice()
145 .try_into()
146 .map_err(|_| EncryptionError::InvalidFormat("Invalid nonce length".to_string()))?;
147 let nonce = &Nonce::from(nonce_array);
148
149 let plaintext = self
151 .cipher
152 .decrypt(nonce, ciphertext_bytes.as_ref())
153 .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?;
154
155 Ok(plaintext)
156 }
157
158 pub fn encrypt_with_aad(
164 &self,
165 plaintext: &[u8],
166 aad: &[u8],
167 ) -> Result<EncryptedData, EncryptionError> {
168 let mut nonce_bytes = [0u8; 12];
170 OsRng.fill_bytes(&mut nonce_bytes);
171 let nonce = Nonce::from(nonce_bytes);
172
173 let ciphertext = self
175 .cipher
176 .encrypt(
177 &nonce,
178 Payload {
179 msg: plaintext,
180 aad,
181 },
182 )
183 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
184
185 Ok(EncryptedData {
186 nonce: base64_encode(&nonce_bytes),
187 ciphertext: base64_encode(&ciphertext),
188 version: 2, })
190 }
191
192 pub fn decrypt_with_aad(
197 &self,
198 encrypted_data: &EncryptedData,
199 aad: &[u8],
200 ) -> Result<Vec<u8>, EncryptionError> {
201 if encrypted_data.version != 2 {
202 return Err(EncryptionError::InvalidFormat(format!(
203 "Expected version 2 for AAD decryption, got {}",
204 encrypted_data.version
205 )));
206 }
207
208 let nonce_bytes = base64_decode(&encrypted_data.nonce)
210 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid nonce: {e}")))?;
211
212 let ciphertext_bytes = base64_decode(&encrypted_data.ciphertext)
213 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid ciphertext: {e}")))?;
214
215 if nonce_bytes.len() != 12 {
216 return Err(EncryptionError::InvalidFormat(format!(
217 "Invalid nonce length: expected 12, got {}",
218 nonce_bytes.len()
219 )));
220 }
221
222 let nonce_array: [u8; 12] = nonce_bytes
223 .as_slice()
224 .try_into()
225 .map_err(|_| EncryptionError::InvalidFormat("Invalid nonce length".to_string()))?;
226 let nonce = Nonce::from(nonce_array);
227
228 let plaintext = self
230 .cipher
231 .decrypt(
232 &nonce,
233 Payload {
234 msg: &ciphertext_bytes,
235 aad,
236 },
237 )
238 .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?;
239
240 Ok(plaintext)
241 }
242
243 pub fn decrypt_auto(
251 &self,
252 encrypted_data: &EncryptedData,
253 aad: Option<&[u8]>,
254 ) -> Result<Vec<u8>, EncryptionError> {
255 match encrypted_data.version {
256 1 => self.decrypt(encrypted_data),
257 2 => {
258 let aad = aad.ok_or(EncryptionError::MissingAAD)?;
259 self.decrypt_with_aad(encrypted_data, aad)
260 }
261 v => Err(EncryptionError::UnsupportedVersion(v)),
262 }
263 }
264
265 pub fn encrypt_string(&self, plaintext: &str) -> Result<String, EncryptionError> {
267 let encrypted_data = self.encrypt(plaintext.as_bytes())?;
268 let json_data = serde_json::to_string(&encrypted_data)
269 .map_err(|e| EncryptionError::EncryptionFailed(format!("Serialization failed: {e}")))?;
270
271 Ok(base64_encode(json_data.as_bytes()))
273 }
274
275 pub fn decrypt_string(&self, encrypted_base64: &str) -> Result<String, EncryptionError> {
277 let json_bytes = base64_decode(encrypted_base64)
279 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid base64: {e}")))?;
280
281 let encrypted_json = String::from_utf8(json_bytes).map_err(|e| {
282 EncryptionError::InvalidFormat(format!("Invalid UTF-8 in decoded data: {e}"))
283 })?;
284
285 let encrypted_data: EncryptedData = serde_json::from_str(&encrypted_json)
286 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid JSON structure: {e}")))?;
287
288 let plaintext_bytes = self.decrypt(&encrypted_data)?;
289 String::from_utf8(plaintext_bytes).map_err(|e| {
290 EncryptionError::DecryptionFailed(format!("Invalid UTF-8 in plaintext: {e}"))
291 })
292 }
293
294 pub fn generate_key() -> String {
296 let mut key = [0u8; 32];
297 OsRng.fill_bytes(&mut key);
298 let key_b64 = base64_encode(&key);
299
300 let mut key_zeroize = key;
302 key_zeroize.zeroize();
303
304 key_b64
305 }
306
307 pub fn is_configured() -> bool {
309 env::var("STORAGE_ENCRYPTION_KEY").is_ok()
310 }
311}
312
313static ENCRYPTION_INSTANCE: std::sync::OnceLock<Result<FieldEncryption, EncryptionError>> =
315 std::sync::OnceLock::new();
316
317pub fn get_encryption() -> Result<&'static FieldEncryption, &'static EncryptionError> {
319 ENCRYPTION_INSTANCE
320 .get_or_init(FieldEncryption::new)
321 .as_ref()
322}
323
324pub fn encrypt_sensitive_field(data: &str) -> Result<String, EncryptionError> {
326 if FieldEncryption::is_configured() {
327 match get_encryption() {
328 Ok(encryption) => encryption.encrypt_string(data),
329 Err(e) => Err(e.clone()),
330 }
331 } else {
332 let json_data = serde_json::to_string(data)
335 .map_err(|e| EncryptionError::EncryptionFailed(format!("JSON encoding failed: {e}")))?;
336 Ok(base64_encode(json_data.as_bytes()))
337 }
338}
339
340pub fn decrypt_sensitive_field(data: &str) -> Result<String, EncryptionError> {
342 let json_bytes = base64_decode(data)
344 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid base64: {e}")))?;
345
346 let json_str = String::from_utf8(json_bytes)
347 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid UTF-8: {e}")))?;
348
349 if FieldEncryption::is_configured() {
351 if let Ok(encryption) = get_encryption() {
352 if let Ok(encrypted_data) = serde_json::from_str::<EncryptedData>(&json_str) {
354 let plaintext_bytes = encryption.decrypt(&encrypted_data)?;
356 return String::from_utf8(plaintext_bytes).map_err(|e| {
357 EncryptionError::DecryptionFailed(format!("Invalid UTF-8 in plaintext: {e}"))
358 });
359 }
360 }
361 }
362
363 serde_json::from_str(&json_str)
366 .map_err(|e| EncryptionError::DecryptionFailed(format!("Invalid JSON string: {e}")))
367}
368
369pub fn encrypt_sensitive_field_with_aad(data: &str) -> Result<String, EncryptionError> {
375 let aad = EncryptionContext::get().ok_or_else(|| {
377 EncryptionError::EncryptionFailed("EncryptionContext not set".to_string())
378 })?;
379
380 if FieldEncryption::is_configured() {
381 match get_encryption() {
382 Ok(encryption) => {
383 let encrypted_data =
384 encryption.encrypt_with_aad(data.as_bytes(), aad.as_bytes())?;
385 let json_data = serde_json::to_string(&encrypted_data).map_err(|e| {
386 EncryptionError::EncryptionFailed(format!("Serialization failed: {e}"))
387 })?;
388 Ok(base64_encode(json_data.as_bytes()))
389 }
390 Err(e) => Err(e.clone()),
391 }
392 } else {
393 let json_data = serde_json::to_string(data)
396 .map_err(|e| EncryptionError::EncryptionFailed(format!("JSON encoding failed: {e}")))?;
397 Ok(base64_encode(json_data.as_bytes()))
398 }
399}
400
401pub fn decrypt_sensitive_field_auto(data: &str) -> Result<String, EncryptionError> {
408 let aad = EncryptionContext::get();
410
411 let json_bytes = base64_decode(data)
413 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid base64: {e}")))?;
414
415 let json_str = String::from_utf8(json_bytes)
416 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid UTF-8: {e}")))?;
417
418 if FieldEncryption::is_configured() {
420 if let Ok(encryption) = get_encryption() {
421 if let Ok(encrypted_data) = serde_json::from_str::<EncryptedData>(&json_str) {
423 let aad_bytes = aad.as_deref().map(|s| s.as_bytes());
425 let plaintext_bytes = encryption.decrypt_auto(&encrypted_data, aad_bytes)?;
426 return String::from_utf8(plaintext_bytes).map_err(|e| {
427 EncryptionError::DecryptionFailed(format!("Invalid UTF-8 in plaintext: {e}"))
428 });
429 }
430 }
431 }
432
433 serde_json::from_str(&json_str)
436 .map_err(|e| EncryptionError::DecryptionFailed(format!("Invalid JSON string: {e}")))
437}
438
/// Generates a new random 32-byte key, base64-encoded, in the format that
/// `FieldEncryption::new` reads from `STORAGE_ENCRYPTION_KEY`.
///
/// Thin wrapper around [`FieldEncryption::generate_key`].
pub fn generate_encryption_key() -> String {
    FieldEncryption::generate_key()
}
443
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Environment variable read by `FieldEncryption::new` / `is_configured`.
    const KEY_VAR: &str = "STORAGE_ENCRYPTION_KEY";

    /// Serializes every test in this module that reads or mutates `KEY_VAR`.
    ///
    /// The process environment is shared by all test threads. Without this
    /// lock, a test that removes the key can race one that sets it. For
    /// example, the fallback tests could see a key and emit ciphertext instead
    /// of the expected plaintext encoding.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` for the duration of a test and restores the original
    /// value of `KEY_VAR` when dropped, even if the test panics.
    struct EnvKeyGuard {
        saved: Option<String>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvKeyGuard {
        fn acquire() -> Self {
            // A panicking env test poisons the lock; the protected data is
            // `()`, so it is safe to continue with the inner guard.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = env::var(KEY_VAR).ok();
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvKeyGuard {
        fn drop(&mut self) {
            // `drop` runs before the `_lock` field is released, so the restore
            // happens while the lock is still held.
            match self.saved.take() {
                Some(value) => env::set_var(KEY_VAR, value),
                None => env::remove_var(KEY_VAR),
            }
        }
    }

    #[test]
    fn test_encrypt_decrypt_data() {
        let key = [0u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = b"This is a secret message!";
        let encrypted = encryption.encrypt(plaintext).unwrap();
        let decrypted = encryption.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let key = [1u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = "Sensitive API key: sk-1234567890abcdef";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_different_keys_produce_different_results() {
        let encryption1 = FieldEncryption::new_with_key(&[1u8; 32]).unwrap();
        let encryption2 = FieldEncryption::new_with_key(&[2u8; 32]).unwrap();

        let plaintext = "secret";
        let encrypted1 = encryption1.encrypt_string(plaintext).unwrap();
        let encrypted2 = encryption2.encrypt_string(plaintext).unwrap();

        assert_ne!(encrypted1, encrypted2);

        assert_eq!(encryption1.decrypt_string(&encrypted1).unwrap(), plaintext);
        assert_eq!(encryption2.decrypt_string(&encrypted2).unwrap(), plaintext);

        // A key must not open ciphertext sealed under a different key.
        assert!(encryption1.decrypt_string(&encrypted2).is_err());
        assert!(encryption2.decrypt_string(&encrypted1).is_err());
    }

    #[test]
    fn test_nonce_uniqueness() {
        let encryption = FieldEncryption::new_with_key(&[3u8; 32]).unwrap();

        let plaintext = "same message";
        let encrypted1 = encryption.encrypt_string(plaintext).unwrap();
        let encrypted2 = encryption.encrypt_string(plaintext).unwrap();

        // Fresh random nonces make repeated encryptions differ.
        assert_ne!(encrypted1, encrypted2);

        assert_eq!(encryption.decrypt_string(&encrypted1).unwrap(), plaintext);
        assert_eq!(encryption.decrypt_string(&encrypted2).unwrap(), plaintext);
    }

    #[test]
    fn test_invalid_encrypted_data() {
        let encryption = FieldEncryption::new_with_key(&[4u8; 32]).unwrap();

        // Not base64 at all.
        assert!(encryption.decrypt_string("invalid base64!").is_err());

        // Valid base64, but not JSON.
        assert!(encryption
            .decrypt_string(&base64_encode(b"not json"))
            .is_err());

        // Valid JSON with the wrong shape.
        let invalid_json_b64 = base64_encode(b"{\"wrong\": \"structure\"}");
        assert!(encryption.decrypt_string(&invalid_json_b64).is_err());

        // Right shape, but the fields are not real nonce/ciphertext data.
        assert!(encryption
            .decrypt_string(&base64_encode(
                b"{\"nonce\":\"test\",\"ciphertext\":\"test\",\"version\":1}"
            ))
            .is_err());
    }

    #[test]
    fn test_generate_key() {
        let key1 = FieldEncryption::generate_key();
        let key2 = FieldEncryption::generate_key();

        assert_ne!(key1, key2);

        assert!(base64_decode(&key1).is_ok());
        assert!(base64_decode(&key2).is_ok());

        assert_eq!(base64_decode(&key1).unwrap().len(), 32);
        assert_eq!(base64_decode(&key2).unwrap().len(), 32);
    }

    #[test]
    fn test_env_key_loading() {
        let _env = EnvKeyGuard::acquire();

        let test_key = FieldEncryption::generate_key();
        env::set_var(KEY_VAR, &test_key);

        let encryption = FieldEncryption::new().unwrap();
        let plaintext = "test message";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);

        env::remove_var(KEY_VAR);
        assert!(FieldEncryption::new().is_err());
        // `_env` restores the pre-test value of KEY_VAR on drop.
    }

    /// Exercises the `FieldEncryption` string and byte APIs together.
    #[test]
    fn test_high_level_encryption_functions() {
        let encryption = FieldEncryption::new_with_key(&[8u8; 32]).unwrap();

        let plaintext = "sensitive data";

        let encoded = encryption.encrypt_string(plaintext).unwrap();
        let decoded = encryption.decrypt_string(&encoded).unwrap();
        assert_eq!(plaintext, decoded);

        assert!(base64_decode(&encoded).is_ok());

        let encrypted_data = encryption.encrypt(plaintext.as_bytes()).unwrap();
        let decrypted_bytes = encryption.decrypt(&encrypted_data).unwrap();
        assert_eq!(plaintext.as_bytes(), decrypted_bytes.as_slice());
    }

    #[test]
    fn test_fallback_when_encryption_disabled() {
        let _env = EnvKeyGuard::acquire();
        env::remove_var(KEY_VAR);

        let plaintext = "fallback test";

        let encoded = encrypt_sensitive_field(plaintext).unwrap();
        let decoded = decrypt_sensitive_field(&encoded).unwrap();
        assert_eq!(plaintext, decoded);

        // Without a key the value is only JSON-quoted and base64-wrapped.
        let expected_json = serde_json::to_string(plaintext).unwrap();
        let expected_b64 = base64_encode(expected_json.as_bytes());
        assert_eq!(encoded, expected_b64);
    }

    #[test]
    fn test_core_encryption_methods() {
        let encryption = FieldEncryption::new_with_key(&[9u8; 32]).unwrap();
        let plaintext = "core encryption test";

        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);

        assert!(base64_decode(&encrypted).is_ok());
        assert!(!encrypted.contains("nonce"));
        assert!(!encrypted.contains("ciphertext"));
        assert!(!encrypted.contains("{"));
    }

    #[test]
    fn test_base64_encoding_hides_structure() {
        let encryption = FieldEncryption::new_with_key(&[7u8; 32]).unwrap();

        let plaintext = "secret message";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();

        assert!(base64_decode(&encrypted).is_ok());

        // The JSON envelope must not be visible in the stored string.
        for marker in ["nonce", "ciphertext", "version", "{", "}"] {
            assert!(!encrypted.contains(marker), "leaked marker {marker:?}");
        }

        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_encrypt_decrypt_with_aad() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();
        let plaintext = b"secret";
        let aad = b"oz-relayer:signer:test-id";

        let encrypted = encryption.encrypt_with_aad(plaintext, aad).unwrap();
        assert_eq!(encrypted.version, 2);

        let decrypted = encryption.decrypt_with_aad(&encrypted, aad).unwrap();
        assert_eq!(plaintext, decrypted.as_slice());
    }

    #[test]
    fn test_wrong_aad_fails() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();

        let encrypted = encryption.encrypt_with_aad(b"secret", b"key-A").unwrap();
        let result = encryption.decrypt_with_aad(&encrypted, b"key-B");

        assert!(
            matches!(result, Err(EncryptionError::DecryptionFailed(_))),
            "Expected DecryptionFailed error for AAD mismatch, got {result:?}"
        );
    }

    #[test]
    fn test_v1_backwards_compatibility() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();
        let plaintext = b"secret";

        let encrypted = encryption.encrypt(plaintext).unwrap();
        assert_eq!(encrypted.version, 1);

        let decrypted = encryption.decrypt_auto(&encrypted, None).unwrap();
        assert_eq!(plaintext, decrypted.as_slice());
    }

    #[test]
    fn test_v2_requires_aad() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();
        let plaintext = b"secret";
        let aad = b"storage-key";

        let encrypted = encryption.encrypt_with_aad(plaintext, aad).unwrap();
        assert_eq!(encrypted.version, 2);

        let result = encryption.decrypt_auto(&encrypted, None);
        assert!(matches!(result, Err(EncryptionError::MissingAAD)));

        let decrypted = encryption.decrypt_auto(&encrypted, Some(aad)).unwrap();
        assert_eq!(plaintext, decrypted.as_slice());
    }

    #[test]
    fn test_decrypt_auto_unsupported_version() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();

        let invalid_data = EncryptedData {
            nonce: base64_encode(&[0u8; 12]),
            ciphertext: base64_encode(b"fake"),
            version: 99,
        };

        let result = encryption.decrypt_auto(&invalid_data, None);
        assert!(matches!(
            result,
            Err(EncryptionError::UnsupportedVersion(99))
        ));
    }

    /// Exercises the v2 primitives that `encrypt_sensitive_field_with_aad` uses.
    #[test]
    fn test_encrypt_sensitive_field_with_aad() {
        let encryption = FieldEncryption::new_with_key(&[11u8; 32]).unwrap();

        let plaintext = b"sensitive-api-key";
        let aad = b"oz-relayer:signer:my-signer-id";

        let encrypted = encryption.encrypt_with_aad(plaintext, aad).unwrap();
        assert_eq!(encrypted.version, 2);

        let decrypted = encryption.decrypt_auto(&encrypted, Some(aad)).unwrap();
        assert_eq!(plaintext, decrypted.as_slice());
    }

    /// Exercises the v1 path of `decrypt_auto` that
    /// `decrypt_sensitive_field_auto` relies on.
    #[test]
    fn test_decrypt_sensitive_field_auto_v1_compat() {
        let encryption = FieldEncryption::new_with_key(&[12u8; 32]).unwrap();

        let plaintext = b"legacy-secret";

        let encrypted = encryption.encrypt(plaintext).unwrap();
        assert_eq!(encrypted.version, 1);

        let decrypted = encryption.decrypt_auto(&encrypted, None).unwrap();
        assert_eq!(plaintext, decrypted.as_slice());

        let decrypted_with_aad = encryption
            .decrypt_auto(&encrypted, Some(b"ignored"))
            .unwrap();
        assert_eq!(plaintext, decrypted_with_aad.as_slice());
    }

    #[test]
    fn test_ciphertext_swap_prevention() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();

        let secret_a = b"secret-for-signer-a";
        let secret_b = b"secret-for-signer-b";
        let aad_a = b"oz-relayer:signer:signer-a";
        let aad_b = b"oz-relayer:signer:signer-b";

        let encrypted_a = encryption.encrypt_with_aad(secret_a, aad_a).unwrap();
        let encrypted_b = encryption.encrypt_with_aad(secret_b, aad_b).unwrap();

        // Ciphertext moved to another record's context must not decrypt.
        let swap_result = encryption.decrypt_with_aad(&encrypted_a, aad_b);
        assert!(swap_result.is_err());

        let correct_a = encryption.decrypt_with_aad(&encrypted_a, aad_a).unwrap();
        let correct_b = encryption.decrypt_with_aad(&encrypted_b, aad_b).unwrap();

        assert_eq!(secret_a, correct_a.as_slice());
        assert_eq!(secret_b, correct_b.as_slice());
    }

    #[test]
    fn test_decrypt_with_aad_version_mismatch() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();

        let encrypted_v1 = encryption.encrypt(b"secret").unwrap();
        assert_eq!(encrypted_v1.version, 1);

        match encryption.decrypt_with_aad(&encrypted_v1, b"some-aad") {
            Err(EncryptionError::InvalidFormat(msg)) => {
                assert!(msg.contains("Expected version 2"));
                assert!(msg.contains("got 1"));
            }
            other => panic!("Expected InvalidFormat error for version mismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_decrypt_with_aad_invalid_nonce_base64() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();

        let invalid_data = EncryptedData {
            nonce: "not-valid-base64!!!".to_string(),
            ciphertext: base64_encode(b"fake"),
            version: 2,
        };

        match encryption.decrypt_with_aad(&invalid_data, b"aad") {
            Err(EncryptionError::InvalidFormat(msg)) => assert!(msg.contains("Invalid nonce")),
            other => {
                panic!("Expected InvalidFormat error for invalid nonce base64, got {other:?}")
            }
        }
    }

    #[test]
    fn test_decrypt_with_aad_invalid_ciphertext_base64() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();

        let invalid_data = EncryptedData {
            nonce: base64_encode(&[0u8; 12]),
            ciphertext: "not-valid-base64!!!".to_string(),
            version: 2,
        };

        match encryption.decrypt_with_aad(&invalid_data, b"aad") {
            Err(EncryptionError::InvalidFormat(msg)) => {
                assert!(msg.contains("Invalid ciphertext"))
            }
            other => {
                panic!("Expected InvalidFormat error for invalid ciphertext base64, got {other:?}")
            }
        }
    }

    #[test]
    fn test_decrypt_with_aad_invalid_nonce_length() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();

        // 8-byte nonce instead of the required 12.
        let invalid_data = EncryptedData {
            nonce: base64_encode(&[0u8; 8]),
            ciphertext: base64_encode(b"fake-ciphertext"),
            version: 2,
        };

        match encryption.decrypt_with_aad(&invalid_data, b"aad") {
            Err(EncryptionError::InvalidFormat(msg)) => {
                assert!(msg.contains("Invalid nonce length"))
            }
            other => {
                panic!("Expected InvalidFormat error for invalid nonce length, got {other:?}")
            }
        }
    }

    #[test]
    fn test_decrypt_auto_v2_wrong_aad() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();
        let plaintext = b"secret";
        let correct_aad = b"correct-aad";
        let wrong_aad = b"wrong-aad";

        let encrypted = encryption.encrypt_with_aad(plaintext, correct_aad).unwrap();

        let result = encryption.decrypt_auto(&encrypted, Some(wrong_aad));
        assert!(
            matches!(result, Err(EncryptionError::DecryptionFailed(_))),
            "Expected DecryptionFailed error for wrong AAD, got {result:?}"
        );
    }

    #[test]
    fn test_encrypt_with_aad_empty_plaintext() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();
        let aad = b"context";

        let encrypted = encryption.encrypt_with_aad(b"", aad).unwrap();
        assert_eq!(encrypted.version, 2);

        let decrypted = encryption.decrypt_with_aad(&encrypted, aad).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_encrypt_with_aad_empty_aad() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();
        let plaintext = b"secret";

        let encrypted = encryption.encrypt_with_aad(plaintext, b"").unwrap();
        assert_eq!(encrypted.version, 2);

        let decrypted = encryption.decrypt_with_aad(&encrypted, b"").unwrap();
        assert_eq!(plaintext, decrypted.as_slice());

        let result = encryption.decrypt_with_aad(&encrypted, b"some-aad");
        assert!(result.is_err());
    }

    #[test]
    fn test_encrypt_with_aad_large_data() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();
        let large_plaintext = vec![0xABu8; 10_000];
        let aad = b"large-data-context";

        let encrypted = encryption.encrypt_with_aad(&large_plaintext, aad).unwrap();
        assert_eq!(encrypted.version, 2);

        let decrypted = encryption.decrypt_with_aad(&encrypted, aad).unwrap();
        assert_eq!(large_plaintext, decrypted);
    }

    #[test]
    fn test_encrypt_with_aad_nonce_uniqueness() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();
        let plaintext = b"same message";
        let aad = b"same-aad";

        let encrypted1 = encryption.encrypt_with_aad(plaintext, aad).unwrap();
        let encrypted2 = encryption.encrypt_with_aad(plaintext, aad).unwrap();

        assert_ne!(encrypted1.nonce, encrypted2.nonce);
        assert_ne!(encrypted1.ciphertext, encrypted2.ciphertext);

        assert_eq!(
            encryption.decrypt_with_aad(&encrypted1, aad).unwrap(),
            plaintext
        );
        assert_eq!(
            encryption.decrypt_with_aad(&encrypted2, aad).unwrap(),
            plaintext
        );
    }

    #[test]
    fn test_encrypt_sensitive_field_with_aad_fallback() {
        let _env = EnvKeyGuard::acquire();
        env::remove_var(KEY_VAR);

        let plaintext = "fallback-secret";
        let aad = "context-aad".to_string();

        let encoded = EncryptionContext::with_aad_sync(aad.clone(), || {
            encrypt_sensitive_field_with_aad(plaintext).unwrap()
        });
        let decoded = EncryptionContext::with_aad_sync(aad, || {
            decrypt_sensitive_field_auto(&encoded).unwrap()
        });
        assert_eq!(plaintext, decoded);

        // Without a key the value is only JSON-quoted and base64-wrapped.
        let expected_json = serde_json::to_string(plaintext).unwrap();
        let expected_b64 = base64_encode(expected_json.as_bytes());
        assert_eq!(encoded, expected_b64);
    }

    #[test]
    fn test_encrypt_sensitive_field_with_aad_wrong_aad_on_decrypt() {
        let encryption = FieldEncryption::new_with_key(&[10u8; 32]).unwrap();

        let plaintext = b"sensitive-data";
        let correct_aad = b"correct-context";
        let wrong_aad = b"wrong-context";

        let encrypted = encryption.encrypt_with_aad(plaintext, correct_aad).unwrap();
        assert_eq!(encrypted.version, 2);

        let result = encryption.decrypt_auto(&encrypted, Some(wrong_aad));
        assert!(
            matches!(result, Err(EncryptionError::DecryptionFailed(_))),
            "Expected DecryptionFailed error for wrong AAD, got {result:?}"
        );

        let decrypted = encryption
            .decrypt_auto(&encrypted, Some(correct_aad))
            .unwrap();
        assert_eq!(plaintext, decrypted.as_slice());
    }

    #[test]
    fn test_decrypt_auto_v1_ignores_aad() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();
        let plaintext = b"secret";

        let encrypted = encryption.encrypt(plaintext).unwrap();
        assert_eq!(encrypted.version, 1);

        let decrypted = encryption
            .decrypt_auto(&encrypted, Some(b"any-aad"))
            .unwrap();
        assert_eq!(plaintext, decrypted.as_slice());

        let decrypted2 = encryption
            .decrypt_auto(&encrypted, Some(b"different-aad"))
            .unwrap();
        assert_eq!(plaintext, decrypted2.as_slice());
    }

    #[test]
    fn test_decrypt_with_aad_tampered_ciphertext() {
        let encryption = FieldEncryption::new_with_key(&[0u8; 32]).unwrap();
        let plaintext = b"secret";
        let aad = b"context";

        let mut encrypted = encryption.encrypt_with_aad(plaintext, aad).unwrap();

        // Flip every bit of the first ciphertext byte.
        let mut ciphertext_bytes = base64_decode(&encrypted.ciphertext).unwrap();
        if let Some(first) = ciphertext_bytes.first_mut() {
            *first ^= 0xFF;
        }
        encrypted.ciphertext = base64_encode(&ciphertext_bytes);

        let result = encryption.decrypt_with_aad(&encrypted, aad);
        assert!(
            matches!(result, Err(EncryptionError::DecryptionFailed(_))),
            "Expected DecryptionFailed error for tampered ciphertext, got {result:?}"
        );
    }
}