Added blocking support for wrapped packet stream.
[silc.git] / lib / silccore / silcpacket.c
index 9da144abe00a74d185e6cf46d95a1c57825096df..f6b4c87ff7a587ffd285801456683f75aec0fc30 100644 (file)
@@ -4,7 +4,7 @@
 
   Author: Pekka Riikonen <priikone@silcnet.org>
 
-  Copyright (C) 1997 - 2006 Pekka Riikonen
+  Copyright (C) 1997 - 2007 Pekka Riikonen
 
   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
@@ -50,10 +50,10 @@ struct SilcPacketEngineStruct {
 
 /* Packet processor context */
 typedef struct SilcPacketProcessStruct {
-  SilcInt32 priority;                   /* Priority */
   SilcPacketType *types;                /* Packets to process */
   SilcPacketCallbacks *callbacks;       /* Callbacks or NULL */
   void *callback_context;
+  SilcInt32 priority;                   /* Priority */
 } *SilcPacketProcess;
 
 /* UDP remote stream tuple */
@@ -67,7 +67,7 @@ struct SilcPacketStreamStruct {
   struct SilcPacketStreamStruct *next;
   SilcPacketEngineContext sc;           /* Per scheduler context */
   SilcStream stream;                    /* Underlaying stream */
-  SilcMutex lock;                       /* Stream lock */
+  SilcMutex lock;                       /* Packet stream lock */
   SilcDList process;                    /* Packet processors, or NULL */
   SilcPacketRemoteUDP remote_udp;       /* UDP remote stream tuple, or NULL */
   void *stream_context;                         /* Stream context */
@@ -325,7 +325,7 @@ static inline SilcBool silc_packet_stream_read(SilcPacketStream ps,
       }
 
       /* See if remote packet stream exist for this sender */
-      snprintf(tuple, sizeof(tuple), "%d%s", remote_port, remote_ip);
+      silc_snprintf(tuple, sizeof(tuple), "%d%s", remote_port, remote_ip);
       silc_mutex_lock(ps->sc->engine->lock);
       if (silc_hash_table_find(ps->sc->engine->udp_remote, tuple, NULL,
                               (void *)&remote)) {
@@ -402,6 +402,8 @@ static void silc_packet_stream_io(SilcStream stream, SilcStreamStatus status,
 
   switch (status) {
   case SILC_STREAM_CAN_READ:
+    /* Reading is locked also with stream->lock because we may be reading
+       at the same time other thread is writing to same underlaying stream. */
     SILC_LOG_DEBUG(("Reading data from stream"));
 
     /* Read data from stream */
@@ -667,10 +669,16 @@ SilcPacketStream silc_packet_stream_create(SilcPacketEngine engine,
                                               silc_hash_string_compare, NULL,
                                               silc_packet_engine_hash_destr,
                                               NULL, TRUE);
+
   silc_mutex_unlock(engine->lock);
 
   /* Set IO notifier callback.  This schedules this stream for I/O. */
-  silc_stream_set_notifier(ps->stream, schedule, silc_packet_stream_io, ps);
+  if (!silc_stream_set_notifier(ps->stream, schedule, 
+                               silc_packet_stream_io, ps)) {
+    SILC_LOG_DEBUG(("Cannot set stream notifier for packet stream"));
+    silc_packet_stream_destroy(ps);
+    return NULL;
+  }
 
   return ps;
 }
@@ -769,8 +777,12 @@ void silc_packet_stream_destroy(SilcPacketStream stream)
   if (!stream)
     return;
 
-  if (silc_atomic_get_int8(&stream->refcnt) > 1) {
+  if (silc_atomic_sub_int8(&stream->refcnt, 1) > 0) {
     stream->destroyed = TRUE;
+
+    /* Close the underlaying stream */
+    if (!stream->udp && stream->stream)
+      silc_stream_close(stream->stream);
     return;
   }
 
@@ -796,7 +808,7 @@ void silc_packet_stream_destroy(SilcPacketStream stream)
   } else {
     /* Delete from UDP remote hash table */
     char tuple[64];
-    snprintf(tuple, sizeof(tuple), "%d%s", stream->remote_udp->remote_port,
+    silc_snprintf(tuple, sizeof(tuple), "%d%s", stream->remote_udp->remote_port,
             stream->remote_udp->remote_ip);
     silc_mutex_lock(stream->sc->engine->lock);
     silc_hash_table_del(stream->sc->engine->udp_remote, tuple);
@@ -1001,14 +1013,22 @@ SilcBool silc_packet_get_sender(SilcPacket packet,
 void silc_packet_stream_ref(SilcPacketStream stream)
 {
   silc_atomic_add_int8(&stream->refcnt, 1);
+  SILC_LOG_DEBUG(("Stream %p, refcnt %d->%d", stream,
+                 silc_atomic_get_int8(&stream->refcnt) - 1,
+                 silc_atomic_get_int8(&stream->refcnt)));
 }
 
 /* Unreference packet stream */
 
 void silc_packet_stream_unref(SilcPacketStream stream)
 {
-  if (silc_atomic_sub_int8(&stream->refcnt, 1) == 0)
-    silc_packet_stream_destroy(stream);
+  SILC_LOG_DEBUG(("Stream %p, refcnt %d->%d", stream,
+                 silc_atomic_get_int8(&stream->refcnt),
+                 silc_atomic_get_int8(&stream->refcnt) - 1));
+  if (silc_atomic_sub_int8(&stream->refcnt, 1) > 0)
+    return;
+  silc_atomic_add_int8(&stream->refcnt, 1);
+  silc_packet_stream_destroy(stream);
 }
 
 /* Return engine */
@@ -1398,7 +1418,7 @@ static inline SilcBool silc_packet_send_raw(SilcPacketStream stream,
 
   /* Get random padding */
   for (i = 0; i < padlen; i++) tmppad[i] =
-                                silc_rng_get_byte_fast(stream->sc->engine->rng);
+    silc_rng_get_byte_fast(stream->sc->engine->rng);
 
   silc_mutex_lock(stream->lock);
 
@@ -1628,34 +1648,24 @@ static inline SilcBool silc_packet_check_mac(SilcHmac hmac,
 /* Increments/sets counter when decrypting in counter mode. */
 
 static inline void silc_packet_receive_ctr_increment(SilcPacketStream stream,
-                                                    SilcCipher cipher,
-                                                    unsigned char *ret_iv)
+                                                    unsigned char *iv,
+                                                    unsigned char *packet_iv)
 {
-  unsigned char *iv = silc_cipher_get_iv(cipher);
   SilcUInt32 pc;
 
-  /* Increment packet counter */
-  SILC_GET32_MSB(pc, iv + 8);
-  pc++;
-  SILC_PUT32_MSB(pc, iv + 8);
+  /* If IV Included flag, set the IV from packet to block counter. */
+  if (stream->iv_included) {
+    memcpy(iv + 4, packet_iv, 8);
+  } else {
+    /* Increment packet counter */
+    SILC_GET32_MSB(pc, iv + 8);
+    pc++;
+    SILC_PUT32_MSB(pc, iv + 8);
+  }
 
   /* Reset block counter */
   memset(iv + 12, 0, 4);
 
-  /* If IV Included flag, return the 64-bit IV for inclusion in packet */
-  if (stream->iv_included) {
-    /* Get new nonce */
-    ret_iv[0] = silc_rng_get_byte_fast(stream->sc->engine->rng);
-    ret_iv[1] = ret_iv[0] + iv[4];
-    ret_iv[2] = ret_iv[0] ^ ret_iv[1];
-    ret_iv[3] = ret_iv[0] + ret_iv[2];
-    SILC_PUT32_MSB(pc, ret_iv + 4);
-    SILC_LOG_HEXDUMP(("IV"), ret_iv, 8);
-
-    /* Set new nonce to counter block */
-    memcpy(iv + 4, ret_iv, 4);
-  }
-
   SILC_LOG_HEXDUMP(("Counter Block"), iv, 16);
 }
 
@@ -1919,8 +1929,17 @@ static void silc_packet_read_process(SilcPacketStream stream)
       if (stream->iv_included) {
        /* SID, IV and sequence number is included in the ciphertext */
        sid = (SilcUInt8)inbuf->data[0];
-       memcpy(iv, inbuf->data + 1, block_len);
-       ivlen = block_len + 1;
+
+       if (silc_cipher_get_mode(cipher) == SILC_CIPHER_MODE_CTR) {
+         /* Set the CTR mode IV from packet to counter block */
+         memcpy(iv, silc_cipher_get_iv(cipher), block_len);
+         silc_packet_receive_ctr_increment(stream, iv, inbuf->data + 1);
+         ivlen = 8 + 1;
+       } else {
+         /* Get IV from packet */
+         memcpy(iv, inbuf->data + 1, block_len);
+         ivlen = block_len + 1;
+       }
        psnlen = 4;
 
        /* Check SID, and get correct decryption key */
@@ -1943,6 +1962,10 @@ static void silc_packet_read_process(SilcPacketStream stream)
        }
       } else {
        memcpy(iv, silc_cipher_get_iv(cipher), block_len);
+
+       /* If using CTR mode, increment the counter */
+       if (silc_cipher_get_mode(cipher) == SILC_CIPHER_MODE_CTR)
+         silc_packet_receive_ctr_increment(stream, iv, NULL);
       }
 
       silc_cipher_decrypt(cipher, inbuf->data + ivlen, tmp,
@@ -1955,6 +1978,7 @@ static void silc_packet_read_process(SilcPacketStream stream)
        header += 4;
       }
     } else {
+      /* Unencrypted packet */
       block_len = SILC_PACKET_MIN_HEADER_LEN;
       header = inbuf->data;
     }
@@ -2093,7 +2117,6 @@ static void silc_packet_read_process(SilcPacketStream stream)
   silc_buffer_reset(inbuf);
 }
 
-
 /****************************** Packet Waiting ******************************/
 
 /* Packet wait receive callback */
@@ -2241,3 +2264,244 @@ int silc_packet_wait(void *waiter, int timeout, SilcPacket *return_packet)
 
   return ret == TRUE ? 1 : 0;
 }
+
+/************************** Packet Stream Wrapper ***************************/
+
+/* Packet stream wrapper receive callback */
+static SilcBool
+silc_packet_wrap_packet_receive(SilcPacketEngine engine,
+                               SilcPacketStream stream,
+                               SilcPacket packet,
+                               void *callback_context,
+                               void *stream_context);
+
+const SilcStreamOps silc_packet_stream_ops;
+
+/* Packet stream wrapper context */
+typedef struct {
+  const SilcStreamOps *ops;
+  SilcPacketStream stream;
+  SilcMutex lock;
+  void *waiter;                        /* Waiter context in blocking mode */
+  SilcStreamNotifier callback;
+  void *context;
+  SilcList in_queue;
+  SilcPacketType type;
+  SilcPacketFlags flags;
+  unsigned int closed        : 1;
+  unsigned int blocking      : 1;
+} *SilcPacketWrapperStream;
+
+/* Packet wrapper callbacks */
+static SilcPacketCallbacks silc_packet_wrap_cbs =
+{
+  silc_packet_wrap_packet_receive, NULL, NULL
+};
+
+/* Packet stream wrapper receive callback, non-blocking mode */
+
+static SilcBool
+silc_packet_wrap_packet_receive(SilcPacketEngine engine,
+                               SilcPacketStream stream,
+                               SilcPacket packet,
+                               void *callback_context,
+                               void *stream_context)
+{
+  SilcPacketWrapperStream pws = callback_context;
+
+  if (!pws->closed || !pws->callback)
+    return FALSE;
+
+  silc_mutex_lock(pws->lock);
+  silc_list_add(pws->in_queue, packet);
+  silc_mutex_unlock(pws->lock);
+
+  /* Call notifier callback */
+  pws->callback((SilcStream)pws, SILC_STREAM_CAN_READ, pws->context);
+
+  return TRUE;
+}
+
+/* Read SILC packet */
+
+int silc_packet_wrap_read(SilcStream stream, unsigned char *buf,
+                         SilcUInt32 buf_len)
+{
+  SilcPacketWrapperStream pws = stream;
+  SilcPacket packet;
+  int len;
+
+  if (pws->closed)
+    return -2;
+
+  if (pws->blocking) {
+    /* Block until packet is received */
+    if ((silc_packet_wait(pws->waiter, 0, &packet)) < 0)
+      return -2;
+    if (pws->closed)
+      return -2;
+  } else {
+    /* Non-blocking mode */
+    silc_mutex_lock(pws->lock);
+    if (!silc_list_count(pws->in_queue)) {
+      silc_mutex_unlock(pws->lock);
+      return -1;
+    }
+
+    silc_list_start(pws->in_queue);
+    packet = silc_list_get(pws->in_queue);
+    silc_list_del(pws->in_queue, packet);
+    silc_mutex_unlock(pws->lock);
+  }
+
+  len = silc_buffer_len(&packet->buffer);
+  if (len > buf_len)
+    len = buf_len;
+
+  memcpy(buf, packet->buffer.data, len);
+
+  silc_packet_free(packet);
+  return len;
+}
+
+/* Write SILC packet */
+
+int silc_packet_wrap_write(SilcStream stream, const unsigned char *data,
+                          SilcUInt32 data_len)
+{
+  SilcPacketWrapperStream pws = stream;
+
+  /* Send the SILC packet */
+  if (!silc_packet_send(pws->stream, pws->type, pws->flags, data, data_len))
+    return -2;
+
+  return data_len;
+}
+
+/* Close stream */
+
+SilcBool silc_packet_wrap_close(SilcStream stream)
+{
+  SilcPacketWrapperStream pws = stream;
+
+  if (pws->closed)
+    return TRUE;
+
+  if (pws->blocking) {
+    /* Close packet waiter */
+    silc_packet_wait_uninit(pws->waiter, pws->stream);
+  } else {
+    /* Unlink */
+    if (pws->callback)
+      silc_packet_stream_unlink(pws->stream, &silc_packet_wrap_cbs, pws);
+  }
+  pws->closed = TRUE;
+
+  return TRUE;
+}
+
+/* Destroy wrapper stream */
+
+void silc_packet_wrap_destroy(SilcStream stream)
+
+{
+  SilcPacketWrapperStream pws = stream;
+  SilcPacket packet;
+
+  SILC_LOG_DEBUG(("Destroying wrapped packet stream %p", pws));
+
+  silc_stream_close(stream);
+  silc_list_start(pws->in_queue);
+  while ((packet = silc_list_get(pws->in_queue)))
+    silc_packet_free(packet);
+  if (pws->lock)
+    silc_mutex_free(pws->lock);
+  silc_packet_stream_unref(pws->stream);
+
+  silc_free(pws);
+}
+
+/* Link stream to receive packets */
+
+SilcBool silc_packet_wrap_notifier(SilcStream stream,
+                                  SilcSchedule schedule,
+                                  SilcStreamNotifier callback,
+                                  void *context)
+{
+  SilcPacketWrapperStream pws = stream;
+
+  if (pws->closed || pws->blocking)
+    return FALSE;
+
+  /* Link to receive packets */
+  if (callback)
+    silc_packet_stream_link(pws->stream, &silc_packet_wrap_cbs, pws,
+                           100000, pws->type, -1);
+  else
+    silc_packet_stream_unlink(pws->stream, &silc_packet_wrap_cbs, pws);
+
+  pws->callback = callback;
+  pws->context = context;
+
+  return TRUE;
+}
+
+/* Return schedule */
+
+SilcSchedule silc_packet_wrap_get_schedule(SilcStream stream)
+{
+  return NULL;
+}
+
+/* Wraps packet stream into SilcStream. */
+
+SilcStream silc_packet_stream_wrap(SilcPacketStream stream,
+                                   SilcPacketType type,
+                                   SilcPacketFlags flags,
+                                  SilcBool blocking_mode)
+{
+  SilcPacketWrapperStream pws;
+
+  pws = silc_calloc(1, sizeof(*pws));
+  if (!pws)
+    return NULL;
+
+  SILC_LOG_DEBUG(("Wrapping packet stream %p to stream %p", stream, pws));
+
+  pws->ops = &silc_packet_stream_ops;
+  pws->stream = stream;
+  pws->type = type;
+  pws->flags = flags;
+  pws->blocking = blocking_mode;
+
+  if (pws->blocking) {
+    /* Blocking mode.  Use packet waiter to do the thing. */
+    pws->waiter = silc_packet_wait_init(pws->stream, pws->type, -1);
+    if (!pws->waiter) {
+      silc_free(pws);
+      return NULL;
+    }
+  } else {
+    /* Non-blocking mode */
+    if (!silc_mutex_alloc(&pws->lock)) {
+      silc_free(pws);
+      return NULL;
+    }
+
+    silc_list_init(pws->in_queue, struct SilcPacketStruct, next);
+  }
+
+  silc_packet_stream_ref(stream);
+
+  return (SilcStream)pws;
+}
+
+const SilcStreamOps silc_packet_stream_ops =
+{
+  silc_packet_wrap_read,
+  silc_packet_wrap_write,
+  silc_packet_wrap_close,
+  silc_packet_wrap_destroy,
+  silc_packet_wrap_notifier,
+  silc_packet_wrap_get_schedule,
+};