updates.
[silc.git] / lib / silcske / silcske.c
index e8e777a41351515fbe1050c9f1df6e7e47a78c1d..978739224bed989658ad40d099649f9a90e52739 100644 (file)
@@ -4,7 +4,7 @@
 
   Author: Pekka Riikonen <priikone@poseidon.pspt.fi>
 
-  Copyright (C) 2000 Pekka Riikonen
+  Copyright (C) 2000 - 2001 Pekka Riikonen
 
   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
@@ -48,13 +48,9 @@ void silc_ske_free(SilcSKE ske)
     if (ske->start_payload)
       silc_ske_payload_start_free(ske->start_payload);
 
-    /* Free KE1 payload */
+    /* Free KE payload */
     if (ske->ke1_payload)
-      silc_ske_payload_one_free(ske->ke1_payload);
-
-    /* Free KE2 payload */
-    if (ske->ke2_payload)
-      silc_ske_payload_two_free(ske->ke2_payload);
+      silc_ske_payload_ke_free(ske->ke1_payload);
 
     /* Free rest */
     if (ske->prop) {
@@ -222,17 +218,18 @@ SilcSKEStatus silc_ske_initiator_phase_1(SilcSKE ske,
 
 /* This function creates random number x, such that 1 < x < q and 
    computes e = g ^ x mod p and sends the result to the remote end in 
-   Key Exchange Payload. */
+   Key Exchange Payload. */
 
 SilcSKEStatus silc_ske_initiator_phase_2(SilcSKE ske,
                                         SilcPublicKey public_key,
+                                        SilcPrivateKey private_key,
                                         SilcSKESendPacketCb send_packet,
                                         void *context)
 {
   SilcSKEStatus status = SILC_SKE_STATUS_OK;
   SilcBuffer payload_buf;
   SilcInt *x, e;
-  SilcSKEOnePayload *payload;
+  SilcSKEKEPayload *payload;
   unsigned int pk_len;
 
   SILC_LOG_DEBUG(("Start"));
@@ -257,10 +254,15 @@ SilcSKEStatus silc_ske_initiator_phase_2(SilcSKE ske,
   silc_mp_init(&e);
   silc_mp_powm(&e, &ske->prop->group->generator, x, 
               &ske->prop->group->group);
-  
-  /* Encode the result to Key Exchange 1 Payload. */
+
+  /* Encode the result to Key Exchange Payload. */
+
   payload = silc_calloc(1, sizeof(*payload));
-  payload->e = e;
+  ske->ke1_payload = payload;
+
+  payload->x = e;
+
+  /* Get public key */
   payload->pk_data = silc_pkcs_public_key_encode(public_key, &pk_len);
   if (!payload->pk_data) {
     silc_mp_clear(x);
@@ -272,7 +274,32 @@ SilcSKEStatus silc_ske_initiator_phase_2(SilcSKE ske,
   }
   payload->pk_len = pk_len;
   payload->pk_type = SILC_SKE_PK_TYPE_SILC;
-  status = silc_ske_payload_one_encode(ske, payload, &payload_buf);
+
+  /* Compute signature data if we are doing mutual authentication */
+  if (ske->start_payload->flags & SILC_SKE_SP_FLAG_MUTUAL) {
+    unsigned char hash[32], sign[1024];
+    unsigned int hash_len, sign_len;
+
+    SILC_LOG_DEBUG(("We are doing mutual authentication"));
+    SILC_LOG_DEBUG(("Computing HASH value"));
+
+    /* Compute the hash value */
+    memset(hash, 0, sizeof(hash));
+    silc_ske_make_hash(ske, hash, &hash_len, TRUE);
+
+    SILC_LOG_DEBUG(("Signing HASH_i value"));
+    
+    /* Sign the hash value */
+    silc_pkcs_private_key_data_set(ske->prop->pkcs, private_key->prv, 
+                                  private_key->prv_len);
+    silc_pkcs_sign(ske->prop->pkcs, hash, hash_len, sign, &sign_len);
+    payload->sign_data = silc_calloc(sign_len, sizeof(unsigned char));
+    memcpy(payload->sign_data, sign, sign_len);
+    memset(sign, 0, sizeof(sign));
+    payload->sign_len = sign_len;
+  }
+
+  status = silc_ske_payload_ke_encode(ske, payload, &payload_buf);
   if (status != SILC_SKE_STATUS_OK) {
     silc_mp_clear(x);
     silc_free(x);
@@ -283,7 +310,6 @@ SilcSKEStatus silc_ske_initiator_phase_2(SilcSKE ske,
     return status;
   }
 
-  ske->ke1_payload = payload;
   ske->x = x;
 
   /* Send the packet. */
@@ -295,19 +321,19 @@ SilcSKEStatus silc_ske_initiator_phase_2(SilcSKE ske,
   return status;
 }
 
-/* Receives Key Exchange Payload from responder consisting responders
+/* Receives Key Exchange Payload from responder consisting responders
    public key, f, and signature. This function verifies the public key,
    computes the secret shared key and verifies the signature. */
 
 SilcSKEStatus silc_ske_initiator_finish(SilcSKE ske,
-                                       SilcBuffer ke2_payload,
+                                       SilcBuffer ke_payload,
                                        SilcSKEVerifyCb verify_key,
                                        void *verify_context,
                                        SilcSKECb callback,
                                        void *context)
 {
   SilcSKEStatus status = SILC_SKE_STATUS_OK;
-  SilcSKETwoPayload *payload;
+  SilcSKEKEPayload *payload;
   SilcPublicKey public_key = NULL;
   SilcInt *KEY;
   unsigned char hash[32];
@@ -316,7 +342,7 @@ SilcSKEStatus silc_ske_initiator_finish(SilcSKE ske,
   SILC_LOG_DEBUG(("Start"));
 
   /* Decode the payload */
-  status = silc_ske_payload_two_decode(ske, ke2_payload, &payload);
+  status = silc_ske_payload_ke_decode(ske, ke_payload, &payload);
   if (status != SILC_SKE_STATUS_OK) {
     ske->status = status;
     return status;
@@ -328,7 +354,7 @@ SilcSKEStatus silc_ske_initiator_finish(SilcSKE ske,
   /* Compute the shared secret key */
   KEY = silc_calloc(1, sizeof(*KEY));
   silc_mp_init(KEY);
-  silc_mp_powm(KEY, &payload->f, ske->x, &ske->prop->group->group);
+  silc_mp_powm(KEY, &payload->x, ske->x, &ske->prop->group->group);
   ske->KEY = KEY;
 
   SILC_LOG_DEBUG(("Verifying public key"));
@@ -349,7 +375,7 @@ SilcSKEStatus silc_ske_initiator_finish(SilcSKE ske,
   SILC_LOG_DEBUG(("Public key is authentic"));
 
   /* Compute the hash value */
-  status = silc_ske_make_hash(ske, hash, &hash_len);
+  status = silc_ske_make_hash(ske, hash, &hash_len, FALSE);
   if (status != SILC_SKE_STATUS_OK)
     goto err;
 
@@ -357,7 +383,7 @@ SilcSKEStatus silc_ske_initiator_finish(SilcSKE ske,
   memcpy(ske->hash, hash, hash_len);
   ske->hash_len = hash_len;
 
-  SILC_LOG_DEBUG(("Verifying signature"));
+  SILC_LOG_DEBUG(("Verifying signature (HASH_i)"));
 
   /* Verify signature */
   silc_pkcs_public_key_data_set(ske->prop->pkcs, public_key->pk, 
@@ -384,7 +410,7 @@ SilcSKEStatus silc_ske_initiator_finish(SilcSKE ske,
 
  err:
   memset(hash, 'F', sizeof(hash));
-  silc_ske_payload_two_free(payload);
+  silc_ske_payload_ke_free(payload);
   ske->ke2_payload = NULL;
 
   silc_mp_clear(ske->KEY);
@@ -417,6 +443,7 @@ SilcSKEStatus silc_ske_responder_start(SilcSKE ske, SilcRng rng,
                                       SilcSocketConnection sock,
                                       char *version,
                                       SilcBuffer start_payload,
+                                      int mutual_auth,
                                       SilcSKECb callback,
                                       void *context)
 {
@@ -439,6 +466,12 @@ SilcSKEStatus silc_ske_responder_start(SilcSKE ske, SilcRng rng,
      compute the HASH value. */
   ske->start_payload_copy = silc_buffer_copy(start_payload);
 
+  /* Force the mutual authentication flag if we want to do it. */
+  if (mutual_auth) {
+    SILC_LOG_DEBUG(("Force mutual authentication"));
+    remote_payload->flags |= SILC_SKE_SP_FLAG_MUTUAL;
+  }
+
   /* Parse and select the security properties from the payload */
   payload = silc_calloc(1, sizeof(*payload));
   status = silc_ske_select_security_properties(ske, version,
@@ -551,33 +584,87 @@ SilcSKEStatus silc_ske_responder_phase_1(SilcSKE ske,
   return status;
 }
 
-/* This function receives the Key Exchange Payload from the initiator.
+/* This function receives the Key Exchange Payload from the initiator.
    After processing the payload this then selects random number x,
    such that 1 < x < q and computes f = g ^ x mod p. This then puts
-   the result f to a Key Exchange Payload which is later processed
+   the result f to a Key Exchange Payload which is later processed
    in ske_responder_finish function. The callback function should
    not touch the payload (it should merely call the ske_responder_finish
    function). */
 
 SilcSKEStatus silc_ske_responder_phase_2(SilcSKE ske,
-                                        SilcBuffer ke1_payload,
+                                        SilcBuffer ke_payload,
+                                        SilcSKEVerifyCb verify_key,
+                                        void *verify_context,
                                         SilcSKECb callback,
                                         void *context)
 {
   SilcSKEStatus status = SILC_SKE_STATUS_OK;
-  SilcSKEOnePayload *one_payload;
-  SilcSKETwoPayload *two_payload;
+  SilcSKEKEPayload *recv_payload, *send_payload;
   SilcInt *x, f;
 
   SILC_LOG_DEBUG(("Start"));
 
-  /* Decode Key Exchange Payload */
-  status = silc_ske_payload_one_decode(ske, ke1_payload, &one_payload);
+  /* Decode Key Exchange Payload */
+  status = silc_ske_payload_ke_decode(ske, ke_payload, &recv_payload);
   if (status != SILC_SKE_STATUS_OK) {
     ske->status = status;
     return status;
   }
 
+  ske->ke1_payload = recv_payload;
+
+  /* Verify the received public key and verify the signature if we are
+     doing mutual authentication. */
+  if (ske->start_payload->flags & SILC_SKE_SP_FLAG_MUTUAL) {
+    SilcPublicKey public_key = NULL;
+    unsigned char hash[32];
+    unsigned int hash_len;
+
+    SILC_LOG_DEBUG(("We are doing mutual authentication"));
+    SILC_LOG_DEBUG(("Verifying public key"));
+    
+    if (!silc_pkcs_public_key_decode(recv_payload->pk_data, 
+                                    recv_payload->pk_len, 
+                                    &public_key)) {
+      status = SILC_SKE_STATUS_UNSUPPORTED_PUBLIC_KEY;
+      return status;
+    }
+
+    if (verify_key) {
+      status = (*verify_key)(ske, recv_payload->pk_data, recv_payload->pk_len,
+                            recv_payload->pk_type, verify_context);
+      if (status != SILC_SKE_STATUS_OK)
+       return status;
+    }
+
+    SILC_LOG_DEBUG(("Public key is authentic"));
+
+    /* Compute the hash value */
+    status = silc_ske_make_hash(ske, hash, &hash_len, TRUE);
+    if (status != SILC_SKE_STATUS_OK)
+      return status;
+
+    SILC_LOG_DEBUG(("Verifying signature"));
+    
+    /* Verify signature */
+    silc_pkcs_public_key_data_set(ske->prop->pkcs, public_key->pk, 
+                                 public_key->pk_len);
+    if (silc_pkcs_verify(ske->prop->pkcs, recv_payload->sign_data, 
+                        recv_payload->sign_len, hash, hash_len) == FALSE) {
+      
+      SILC_LOG_DEBUG(("Signature don't match"));
+      
+      status = SILC_SKE_STATUS_INCORRECT_SIGNATURE;
+      return status;
+    }
+    
+    SILC_LOG_DEBUG(("Signature is Ok"));
+    
+    silc_pkcs_public_key_free(public_key);
+    memset(hash, 'F', hash_len);
+  }
+
   /* Create the random number x, 1 < x < q. */
   x = silc_calloc(1, sizeof(*x));
   silc_mp_init(x);
@@ -599,11 +686,10 @@ SilcSKEStatus silc_ske_responder_phase_2(SilcSKE ske,
               &ske->prop->group->group);
   
   /* Save the results for later processing */
-  two_payload = silc_calloc(1, sizeof(*two_payload));
-  two_payload->f = f;
+  send_payload = silc_calloc(1, sizeof(*send_payload));
+  send_payload->x = f;
   ske->x = x;
-  ske->ke1_payload = one_payload;
-  ske->ke2_payload = two_payload;
+  ske->ke2_payload = send_payload;
 
   /* Call the callback. */
   if (callback)
@@ -614,7 +700,7 @@ SilcSKEStatus silc_ske_responder_phase_2(SilcSKE ske,
 
 /* This function computes the secret shared key KEY = e ^ x mod p, and, 
    a hash value to be signed and sent to the other end. This then
-   encodes Key Exchange Payload and sends it to the other end. */
+   encodes Key Exchange Payload and sends it to the other end. */
 
 SilcSKEStatus silc_ske_responder_finish(SilcSKE ske,
                                        SilcPublicKey public_key,
@@ -641,7 +727,7 @@ SilcSKEStatus silc_ske_responder_finish(SilcSKE ske,
   /* Compute the shared secret key */
   KEY = silc_calloc(1, sizeof(*KEY));
   silc_mp_init(KEY);
-  silc_mp_powm(KEY, &ske->ke1_payload->e, ske->x, 
+  silc_mp_powm(KEY, &ske->ke1_payload->x, ske->x, 
               &ske->prop->group->group);
   ske->KEY = KEY;
 
@@ -661,7 +747,7 @@ SilcSKEStatus silc_ske_responder_finish(SilcSKE ske,
 
   /* Compute the hash value */
   memset(hash, 0, sizeof(hash));
-  status = silc_ske_make_hash(ske, hash, &hash_len);
+  status = silc_ske_make_hash(ske, hash, &hash_len, FALSE);
   if (status != SILC_SKE_STATUS_OK)
     goto err;
 
@@ -680,9 +766,9 @@ SilcSKEStatus silc_ske_responder_finish(SilcSKE ske,
   memset(sign, 0, sizeof(sign));
   ske->ke2_payload->sign_len = sign_len;
 
-  /* Encode the Key Exchange Payload */
-  status = silc_ske_payload_two_encode(ske, ske->ke2_payload,
-                                      &payload_buf);
+  /* Encode the Key Exchange Payload */
+  status = silc_ske_payload_ke_encode(ske, ske->ke2_payload,
+                                     &payload_buf);
   if (status != SILC_SKE_STATUS_OK)
     goto err;
 
@@ -698,7 +784,7 @@ SilcSKEStatus silc_ske_responder_finish(SilcSKE ske,
   silc_mp_clear(ske->KEY);
   silc_free(ske->KEY);
   ske->KEY = NULL;
-  silc_ske_payload_two_free(ske->ke2_payload);
+  silc_ske_payload_ke_free(ske->ke2_payload);
 
   if (status == SILC_SKE_STATUS_OK)
     return SILC_SKE_STATUS_ERROR;
@@ -1223,11 +1309,15 @@ SilcSKEStatus silc_ske_create_rnd(SilcSKE ske, SilcInt n,
   return status;
 }
 
-/* Creates a hash value HASH as defined in the SKE protocol. */
+/* Creates a hash value HASH as defined in the SKE protocol. If the
+   `initiator' is TRUE then this function is used to create the HASH_i
+   hash value defined in the protocol. If it is FALSE then this is used
+   to create the HASH value defined by the protocol. */
 
 SilcSKEStatus silc_ske_make_hash(SilcSKE ske, 
                                 unsigned char *return_hash,
-                                unsigned int *return_hash_len)
+                                unsigned int *return_hash_len,
+                                int initiator)
 {
   SilcSKEStatus status = SILC_SKE_STATUS_OK;
   SilcBuffer buf;
@@ -1237,47 +1327,79 @@ SilcSKEStatus silc_ske_make_hash(SilcSKE ske,
 
   SILC_LOG_DEBUG(("Start"));
 
-  e = silc_mp_mp2bin(&ske->ke1_payload->e, 0, &e_len);
-  f = silc_mp_mp2bin(&ske->ke2_payload->f, 0, &f_len);
-  KEY = silc_mp_mp2bin(ske->KEY, 0, &KEY_len);
-
-  buf = silc_buffer_alloc(ske->start_payload_copy->len + 
-                         ske->pk_len + e_len + f_len + KEY_len);
-  silc_buffer_pull_tail(buf, SILC_BUFFER_END(buf));
+  if (initiator == FALSE) {
+    e = silc_mp_mp2bin(&ske->ke1_payload->x, 0, &e_len);
+    f = silc_mp_mp2bin(&ske->ke2_payload->x, 0, &f_len);
+    KEY = silc_mp_mp2bin(ske->KEY, 0, &KEY_len);
+    
+    buf = silc_buffer_alloc(ske->start_payload_copy->len + 
+                           ske->pk_len + e_len + f_len + KEY_len);
+    silc_buffer_pull_tail(buf, SILC_BUFFER_END(buf));
+    
+    /* Format the buffer used to compute the hash value */
+    ret = 
+      silc_buffer_format(buf,
+                        SILC_STR_UI_XNSTRING(ske->start_payload_copy->data,
+                                             ske->start_payload_copy->len),
+                        SILC_STR_UI_XNSTRING(ske->pk, ske->pk_len),
+                        SILC_STR_UI_XNSTRING(e, e_len),
+                        SILC_STR_UI_XNSTRING(f, f_len),
+                        SILC_STR_UI_XNSTRING(KEY, KEY_len),
+                        SILC_STR_END);
+    if (ret == -1) {
+      silc_buffer_free(buf);
+      memset(e, 0, e_len);
+      memset(f, 0, f_len);
+      memset(KEY, 0, KEY_len);
+      silc_free(e);
+      silc_free(f);
+      silc_free(KEY);
+      return SILC_SKE_STATUS_ERROR;
+    }
 
-  /* Format the buffer used to compute the hash value */
-  ret = silc_buffer_format(buf,
-                          SILC_STR_UI_XNSTRING(ske->start_payload_copy->data,
-                                               ske->start_payload_copy->len),
-                          SILC_STR_UI_XNSTRING(ske->pk, ske->pk_len),
-                          SILC_STR_UI_XNSTRING(e, e_len),
-                          SILC_STR_UI_XNSTRING(f, f_len),
-                          SILC_STR_UI_XNSTRING(KEY, KEY_len),
-                          SILC_STR_END);
-  if (ret == -1) {
-    silc_buffer_free(buf);
     memset(e, 0, e_len);
     memset(f, 0, f_len);
     memset(KEY, 0, KEY_len);
     silc_free(e);
     silc_free(f);
     silc_free(KEY);
-    return SILC_SKE_STATUS_ERROR;
+  } else {
+    e = silc_mp_mp2bin(&ske->ke1_payload->x, 0, &e_len);
+
+    buf = silc_buffer_alloc(ske->start_payload_copy->len + 
+                           ske->pk_len + e_len);
+    silc_buffer_pull_tail(buf, SILC_BUFFER_END(buf));
+    
+    /* Format the buffer used to compute the hash value */
+    ret = 
+      silc_buffer_format(buf,
+                        SILC_STR_UI_XNSTRING(ske->start_payload_copy->data,
+                                             ske->start_payload_copy->len),
+                        SILC_STR_UI_XNSTRING(ske->pk, ske->pk_len),
+                        SILC_STR_UI_XNSTRING(e, e_len),
+                        SILC_STR_END);
+    if (ret == -1) {
+      silc_buffer_free(buf);
+      memset(e, 0, e_len);
+      silc_free(e);
+      return SILC_SKE_STATUS_ERROR;
+    }
+
+    memset(e, 0, e_len);
+    silc_free(e);
   }
 
   /* Make the hash */
   silc_hash_make(ske->prop->hash, buf->data, buf->len, return_hash);
   *return_hash_len = ske->prop->hash->hash->hash_len;
 
-  SILC_LOG_HEXDUMP(("Hash"), return_hash, *return_hash_len);
+  if (initiator == FALSE) {
+    SILC_LOG_HEXDUMP(("HASH"), return_hash, *return_hash_len);
+  } else {
+    SILC_LOG_HEXDUMP(("HASH_i"), return_hash, *return_hash_len);
+  }
 
   silc_buffer_free(buf);
-  memset(e, 0, e_len);
-  memset(f, 0, f_len);
-  memset(KEY, 0, KEY_len);
-  silc_free(e);
-  silc_free(f);
-  silc_free(KEY);
 
   return status;
 }