(in-package :crypto-tests)

(defun hex-string-to-byte-array (string &key (start 0) (end nil))
  "Decode the hexadecimal digits of STRING between START and END (END
defaults to the length of STRING) into a fresh
(SIMPLE-ARRAY (UNSIGNED-BYTE 8) (*)).  Each pair of characters becomes one
octet, high nibble first; digits are matched case-insensitively.

Signals an ERROR if the range does not hold an even, non-negative number of
characters, or if any character in it is not a hex digit."
  (declare (type string string))
  (let* ((end (or end (length string)))
         (span (- end start)))
    ;; Without this check an odd span makes (/ span 2) a ratio, and
    ;; MAKE-ARRAY then fails with an unrelated type error.
    (unless (and (>= span 0) (evenp span))
      (error "Invalid hex key ~A specified: odd or negative digit count ~D"
             string span))
    (let ((key (make-array (floor span 2) :element-type '(unsigned-byte 8))))
      (declare (type (simple-array (unsigned-byte 8) (*)) key))
      (flet ((char-to-digit (char)
               ;; Case-insensitive lookup of CHAR among 0-9a-f; its index is
               ;; the digit's value.
               (let ((x (position char "0123456789abcdef" :test #'char-equal)))
                 (or x (error "Invalid hex key ~A specified" string)))))
        (loop for i from 0
              for j from start below end by 2
              do (setf (aref key i)
                       (+ (* (char-to-digit (char string j)) 16)
                          (char-to-digit (char string (1+ j))))))
        key))))

(defmacro ecb-mode-test (hexkey hexinput hexoutput)
  "Expand into a pair of RT tests for the cipher named by *TEST-CIPHER*.
The encrypt test runs ECB-MODE-TEST-GUTS with #'CRYPTO:ENCRYPT on HEXINPUT
and expects HEXOUTPUT.  The decrypt test runs #'CRYPTO:DECRYPT on HEXOUTPUT
and expects HEXINPUT.  Both use key HEXKEY.

Test names such as <cipher>.ENCRYPT.<n> are interned in the current
*PACKAGE* at macroexpansion time from *TEST-CIPHER* and *TEST-NUMBER*."
  (let ((cipher (symbol-value '*test-cipher*))
        (number (symbol-value '*test-number*)))
    (flet ((test-name (control)
             (intern (format nil control cipher number))))
      (let* ((encrypt-name (test-name "~A.ENCRYPT.~D"))
             (decrypt-name (test-name "~A.DECRYPT.~D")))
        `(progn
           ;; NOTE(review): the counter is only bumped at compile time, so
           ;; evaluating these forms without COMPILE-FILE (e.g. LOADing the
           ;; source) presumably leaves *TEST-NUMBER* unchanged -- confirm.
           (eval-when (:compile-toplevel)
             (incf *test-number*))
           (rt:deftest ,encrypt-name
               (ecb-mode-test-guts ,cipher #'crypto:encrypt
                ,hexkey ,hexinput ,hexoutput)
             t)
           (rt:deftest ,decrypt-name
               (ecb-mode-test-guts ,cipher #'crypto:decrypt
                ,hexkey ,hexoutput ,hexinput)
             t))))))

(defun frob-hex-string (cipher-name func hexkey hexinput)
  "Decode HEXKEY and HEXINPUT, build an ECB-mode CIPHER-NAME cipher with the
decoded key, and call FUNC with the cipher, the decoded input and a fresh
copy of that input as the output buffer.  Returns the output buffer."
  (let ((key-octets (hex-string-to-byte-array hexkey))
        (input-octets (hex-string-to-byte-array hexinput)))
    (let ((ecb-cipher (crypto:make-cipher cipher-name :ecb key-octets))
          (output (copy-seq input-octets)))
      (funcall func ecb-cipher input-octets output)
      output)))

(defun ecb-mode-test-guts (cipher-name func hexkey hexinput hexoutput)
  "Return T iff applying FUNC to HEXINPUT (via FROB-HEX-STRING with cipher
CIPHER-NAME and key HEXKEY) produces exactly the octets encoded by
HEXOUTPUT, and NIL otherwise."
  (let ((expected (hex-string-to-byte-array hexoutput)))
    (null (mismatch (frob-hex-string cipher-name func hexkey hexinput)
                    expected))))

;;; FIXME: the consistency cipher should really be a block cipher that has
;;; already passed its own tests earlier in the testing process.
(defvar *encryption/decryption-consistency-cipher* :aes
  "Name of the block cipher used by the encryption/decryption round-trip checks.")

(defvar *encryption/decryption-consistency-blocks* 5
  "Number of cipher blocks in the plaintext of each round-trip check.")

(defvar *encryption/decryption-consistency-operations* 1000
  "Number of random key/IV round trips tried by CHECK-ENCRYPTION/DECRYPTION-CONSISTENCY.")

(defun random-key (length)
  "Return a fresh (SIMPLE-ARRAY (UNSIGNED-BYTE 8) (*)) of LENGTH octets.
Each octet is drawn with (RANDOM 256) from *RANDOM-STATE*, in index order."
  (make-array length
              :element-type '(unsigned-byte 8)
              :initial-contents (loop repeat length collect (random 256))))

(defun run-single-consistency-check (mode key iv)
  "Round-trip an all-zero plaintext of *ENCRYPTION/DECRYPTION-CONSISTENCY-BLOCKS*
blocks through the *ENCRYPTION/DECRYPTION-CONSISTENCY-CIPHER* cipher in MODE
with KEY and IV.  The plaintext is encrypted into a scratch buffer and then
decrypted in place by a second cipher object.  Returns T iff the result
equals the plaintext."
  (let* ((cipher-name *encryption/decryption-consistency-cipher*)
         (size (* *encryption/decryption-consistency-blocks*
                  (ironclad:block-length cipher-name)))
         (plaintext (make-array size
                                :element-type '(unsigned-byte 8)
                                :initial-element 0))
         (buffer (make-array size :element-type '(unsigned-byte 8)))
         ;; One freshly constructed cipher per direction, both built from
         ;; identical MODE/KEY/IV arguments.
         (encryptor (ironclad:make-cipher cipher-name mode key iv))
         (decryptor (ironclad:make-cipher cipher-name mode key iv)))
    (ironclad:encrypt encryptor plaintext buffer)
    (ironclad:decrypt decryptor buffer buffer)
    (null (mismatch plaintext buffer))))
    
(defun check-encryption/decryption-consistency (mode)
  "Run RUN-SINGLE-CONSISTENCY-CHECK in MODE
*ENCRYPTION/DECRYPTION-CONSISTENCY-OPERATIONS* times with a random 16-octet
key and a random block-sized IV each time.  Returns T if every check
passes, and NIL as soon as one fails."
  ;; MAKE-RANDOM-STATE with no argument copies the current state, so drawing
  ;; random keys here does not advance the caller's generator.
  (let ((*random-state* (make-random-state))
        (iv-length (ironclad:block-length *encryption/decryption-consistency-cipher*)))
    ;; FIXME: the 16-octet key length should be introspected from the
    ;; cipher instead of hard-coded.
    (loop repeat *encryption/decryption-consistency-operations*
          always (run-single-consistency-check mode
                                               (random-key 16)
                                               (random-key iv-length)))))

;;; digest testing routines

(defmacro digest-test (string hexdigest)
  "Expand into three RT tests checking that the digest named by
*TEST-DIGEST* maps STRING to HEXDIGEST:
- a one-shot digest (DIGEST-TEST-ONE-SHOT-GUTS);
- a byte-at-a-time digest (DIGEST-TEST-INCREMENTAL-GUTS);
- a digest of a fill-pointered input vector (DIGEST-TEST-FILL-POINTER-GUTS).
Test names are interned in the current *PACKAGE* at macroexpansion time
from *TEST-DIGEST* and *TEST-NUMBER*."
  (let ((digest (symbol-value '*test-digest*))
        (number (symbol-value '*test-number*)))
    (flet ((test-name (control)
             (intern (format nil control digest number))))
      (let* ((oneshot-name (test-name "~A-STRING-ONESHOT.~D"))
             (incremental-name (test-name "~A-STRING-INCREMENTAL.~D"))
             (fill-pointer-name (test-name "~A-STRING-FILL-POINTER.~D")))
        `(progn
           ;; NOTE(review): bumped at compile time only; see EVAL-WHEN.
           (eval-when (:compile-toplevel)
             (incf *test-number*))
           (rt:deftest ,oneshot-name
               (digest-test-one-shot-guts ,digest ,string ,hexdigest) t)
           (rt:deftest ,incremental-name
               (digest-test-incremental-guts ,digest ,string ,hexdigest) t)
           (rt:deftest ,fill-pointer-name
               (digest-test-fill-pointer-guts ,digest ,string ,hexdigest) t))))))

(defmacro digest-bit-test (leading byte trailing hexdigest)
  "Expand into one RT test, via DIGEST-BIT-TEST-GUTS, checking that the
digest named by *TEST-DIGEST* maps LEADING zero octets, then BYTE, then
TRAILING zero octets to HEXDIGEST.  The test name is interned in the
current *PACKAGE* at macroexpansion time from *TEST-DIGEST* and
*TEST-NUMBER*."
  (let ((digest (symbol-value '*test-digest*))
        (number (symbol-value '*test-number*)))
    (let ((name (intern (format nil "~A-BIT.~D" digest number))))
      `(progn
         ;; NOTE(review): bumped at compile time only; see EVAL-WHEN.
         (eval-when (:compile-toplevel)
           (incf *test-number*))
         (rt:deftest ,name
             (digest-bit-test-guts ,digest ,leading ,byte ,trailing ,hexdigest)
           t)))))

(defun digest-test-one-shot-guts (digest string hexdigest)
  "Return T iff a single CRYPTO:DIGEST-SEQUENCE call with DIGEST over the
octets of STRING (via CRYPTO:ASCII-STRING-TO-BYTE-ARRAY) yields the octets
encoded by HEXDIGEST."
  (let ((input (crypto:ascii-string-to-byte-array string))
        (expected (hex-string-to-byte-array hexdigest)))
    (null (mismatch (crypto:digest-sequence digest input) expected))))

(defun digest-test-incremental-guts (digest string hexdigest)
  "Return T iff feeding the octets of STRING to a fresh DIGEST digester one
octet at a time (CRYPTO:UPDATE-DIGEST with a one-element :START/:END range)
and then calling CRYPTO:PRODUCE-DIGEST yields the octets encoded by
HEXDIGEST."
  (let ((input (crypto:ascii-string-to-byte-array string))
        (expected (hex-string-to-byte-array hexdigest))
        (digester (crypto:make-digest digest)))
    (dotimes (index (length input))
      (crypto:update-digest digester input :start index :end (1+ index)))
    (null (mismatch (crypto:produce-digest digester) expected))))

(defun digest-test-fill-pointer-guts (digest string hexdigest)
  "Return T iff digesting STRING's octets, stored in an octet vector whose
fill pointer is below its capacity, with DIGEST yields the octets encoded by
HEXDIGEST."
  (let* ((octets (crypto:ascii-string-to-byte-array string))
         ;; Capacity is twice the data length, presumably so that a digest
         ;; reading past the fill pointer would produce a different result.
         (input (make-array (* 2 (length octets))
                            :element-type '(unsigned-byte 8)
                            :fill-pointer 0)))
    (loop for octet across octets
          do (vector-push octet input))
    (let ((expected (hex-string-to-byte-array hexdigest)))
      (null (mismatch (crypto:digest-sequence digest input) expected)))))

(defun digest-bit-test-guts (digest leading byte trailing hexdigest)
  "Return T iff digesting an octet vector of LEADING zeros, then BYTE, then
TRAILING zeros with DIGEST yields the octets encoded by HEXDIGEST."
  (let ((input (make-array (+ leading 1 trailing)
                           :element-type '(unsigned-byte 8)
                           :initial-element 0)))
    (setf (aref input leading) byte)
    (let ((expected (hex-string-to-byte-array hexdigest)))
      (null (mismatch (crypto:digest-sequence digest input) expected)))))
