Code Review Stack Exchange is a question and answer site for peer programmer code reviews. It's 100% free, no registration required.

Sign up
Here's how it works:
  1. Anybody can ask a question
  2. Anybody can answer
  3. The best answers are voted up and rise to the top

I am trying to create a fully async example of a client and server using SSL.

I think these are the required assumptions:

  1. Connecting may require socket readability and writeability notifications.
  2. When the socket is readable, SSL_write may need to be called depending on the result of the last call to SSL_write.
  3. When the socket is writable, SSL_read may need to be called depending on the result of the last call to SSL_read.
  4. If the last call to SSL_connect, SSL_write, or SSL_read failed with SSL_ERROR_WANT_WRITE, then the application cannot write anything new with SSL_write until that call has been retried (with the same arguments) and has completed.
  5. On the connecting side, SSL_write cannot be called until SSL_connect succeeds.

Are there any others?

On the accepting side, how can I tell if the socket is ready for a call to SSL_write?

Here is an example that seems to work completely. Please tell me if anything is wrong with it.

  • It can be compiled with gcc ssl.c -lssl -lcrypto.
  • The client is run with ./a.out client.
  • The server is run with ./a.out server.
  • You can make a sample pem for use by the server with:
echo -e "\n\n\n\n\n\n" | /usr/bin/openssl req -x509 -nodes -days 365 -newkey rsa:1024 -keyout cert.pem -out cert.pem -config openssl.cnf
openssl x509 -in cert.pem -outform DER -out cert.pem.crt

#include <errno.h>
#include <stdio.h>
#include <string.h>

#include <sys/types.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <strings.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <unistd.h>

#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/ssl.h>

/*
 * Result of one SSL helper call, telling the select() loop in main how to
 * proceed.
 */
typedef enum ssl_action {
  CONTINUE = 0, /* go back to the top of the loop and select() again */
  BREAK = 1,    /* leave the loop */
  NEITHER = 2   /* fall through to the next readiness check */
} ACTION;

/*
 * Drive one step of the client-side TLS handshake on a non-blocking socket.
 *
 * ssl             - SSL object bound to the socket and put in connect state.
 * wants_tcp_write - set to 1 when OpenSSL must wait for the socket to become
 *                   writable; cleared otherwise so the caller does not keep
 *                   polling for writability it no longer needs.
 * connecting      - cleared to 0 once the handshake has completed.
 *
 * Returns CONTINUE while the handshake is in progress or has just finished,
 * BREAK on a fatal error.
 */
ACTION ssl_connect(SSL* ssl, int* wants_tcp_write, int* connecting) {
  printf("calling SSL_connect\n");

  /* SSL_get_error() is only reliable when the error queue is empty beforehand. */
  ERR_clear_error();
  int result = SSL_connect(ssl);
  if (result > 0) {
    printf("connected\n");
    *wants_tcp_write = 0;
    *connecting = 0;
    return CONTINUE;
  }

  /* Capture errno before anything else can overwrite it. */
  int saved_errno = errno;
  /* Both 0 and negative returns must be classified with SSL_get_error(). */
  int ssl_error = SSL_get_error(ssl, result);
  if (ssl_error == SSL_ERROR_WANT_WRITE) {
    printf("SSL_connect wants write\n");
    *wants_tcp_write = 1;
    return CONTINUE;
  }

  if (ssl_error == SSL_ERROR_WANT_READ) {
    printf("SSL_connect wants read\n");
    // wants_tcp_read is always 1; a previous WANT_WRITE no longer applies.
    *wants_tcp_write = 0;
    return CONTINUE;
  }

  unsigned long error = ERR_get_error();
  if (error == 0 && ssl_error == SSL_ERROR_SYSCALL) {
    /* I/O failure with nothing on the OpenSSL queue: report errno / EOF. */
    printf("could not SSL_connect (returned %d): %s\n", result,
        result == 0 ? "unexpected EOF" : strerror(saved_errno));
  } else {
    printf("could not SSL_connect (returned %d): %s\n", result,
        ERR_error_string(error, NULL));
  }
  return BREAK;
}

/*
 * Read and discard whatever application data is available on a non-blocking
 * TLS connection.
 *
 * ssl                            - SSL object bound to the socket.
 * wants_tcp_write                - set to 1 when OpenSSL must wait for the
 *                                  socket to become writable before the read
 *                                  can progress.
 * call_ssl_read_instead_of_write - cleared on entry; set to 1 in that same
 *                                  case so the caller retries SSL_read (not
 *                                  SSL_write) once the socket is writable.
 *
 * Returns CONTINUE when OpenSSL is waiting on the socket, BREAK on a fatal
 * error or when the peer closes the TLS session, NEITHER after data was read.
 */
ACTION ssl_read(SSL* ssl, int* wants_tcp_write, int* call_ssl_read_instead_of_write) {
  printf("calling SSL_read\n");

  *call_ssl_read_instead_of_write = 0;

  char buffer[1024];
  for (;;) {
    /* SSL_get_error() is only reliable when the error queue is empty beforehand. */
    ERR_clear_error();
    int num = SSL_read(ssl, buffer, sizeof(buffer));
    if (num > 0) {
      printf("read %d bytes\n", num);
      /* A TLS record may be larger than buffer. The remainder is already
       * decrypted inside OpenSSL where select() cannot see it, so drain it
       * now instead of waiting for socket data that may never arrive. */
      if (SSL_pending(ssl) > 0) {
        continue;
      }
      return NEITHER;
    }

    /* Capture errno before anything else can overwrite it. */
    int saved_errno = errno;
    int ssl_error = SSL_get_error(ssl, num);
    if (ssl_error == SSL_ERROR_WANT_WRITE) {
      printf("SSL_read wants write\n");
      *wants_tcp_write = 1;
      *call_ssl_read_instead_of_write = 1;
      return CONTINUE;
    }

    if (ssl_error == SSL_ERROR_WANT_READ) {
      printf("SSL_read wants read\n");
      // wants_tcp_read is always 1;
      return CONTINUE;
    }

    if (ssl_error == SSL_ERROR_ZERO_RETURN) {
      /* Orderly close_notify from the peer; nothing is on the error queue. */
      printf("could not SSL_read (returned %d): peer closed the connection\n", num);
      return BREAK;
    }

    unsigned long error = ERR_get_error();
    if (error == 0 && ssl_error == SSL_ERROR_SYSCALL) {
      /* I/O failure with nothing on the OpenSSL queue: report errno / EOF. */
      printf("could not SSL_read (returned %d): %s\n", num,
          num == 0 ? "unexpected EOF" : strerror(saved_errno));
    } else {
      printf("could not SSL_read (returned %d): %s\n", num,
          ERR_error_string(error, NULL));
    }
    return BREAK;
  }
}

/*
 * Send (or resume sending) the pending 1024-byte message on a non-blocking
 * TLS connection.
 *
 * ssl                            - SSL object bound to the socket.
 * wants_tcp_write                - set to 1 while OpenSSL needs writability
 *                                  or bytes of the message remain; cleared
 *                                  when the message is fully sent or nothing
 *                                  is queued.
 * call_ssl_write_instead_of_read - cleared on entry; set to 1 when SSL_write
 *                                  needs readability, so the caller retries
 *                                  SSL_write (not SSL_read) when readable.
 * is_client                      - only the client originates messages.
 * should_start_a_new_write       - start a new message if none is pending.
 *
 * Returns CONTINUE when OpenSSL is waiting on the socket, BREAK on a fatal
 * error (or an inconsistent retry request), NEITHER otherwise.
 */
ACTION ssl_write(SSL* ssl, int* wants_tcp_write, int* call_ssl_write_instead_of_read,
    int is_client, int should_start_a_new_write) {
  printf("calling SSL_write\n");

  /* Message state persists across calls: after WANT_READ/WANT_WRITE OpenSSL
   * requires the retry to pass the same buffer and length. The buffer is
   * never modified, so it is all zeros. */
  static const char buffer[1024];
  static int offset = 0;   /* bytes of the current message already sent */
  static int to_write = 0; /* bytes of the current message still to send */

  if (!*call_ssl_write_instead_of_read && !to_write && is_client && should_start_a_new_write) {
    offset = 0;
    to_write = (int)sizeof(buffer);
    printf("decided to write %d bytes\n", to_write);
  }

  if (*call_ssl_write_instead_of_read && !to_write) {
    printf("ssl should not have requested a write from a read if no data was waiting to be written\n");
    return BREAK;
  }

  *call_ssl_write_instead_of_read = 0;

  if (!to_write) {
    /* Nothing queued: stop polling for writability, otherwise select()
     * keeps reporting the socket writable and the loop spins. */
    *wants_tcp_write = 0;
    return NEITHER;
  }

  /* SSL_get_error() is only reliable when the error queue is empty beforehand. */
  ERR_clear_error();
  int num = SSL_write(ssl, buffer + offset, to_write);
  if (num > 0) {
    printf("wrote %d of %d bytes\n", num, to_write);
    offset += num;
    to_write -= num;
    /* A short write (only with SSL_MODE_ENABLE_PARTIAL_WRITE) leaves bytes
     * to send, so keep waiting for writability until the message is done. */
    *wants_tcp_write = to_write > 0;
    return NEITHER;
  }

  /* Capture errno before anything else can overwrite it. */
  int saved_errno = errno;
  int ssl_error = SSL_get_error(ssl, num);
  if (ssl_error == SSL_ERROR_WANT_WRITE) {
    printf("SSL_write wants write\n");
    *wants_tcp_write = 1;
    return CONTINUE;
  }

  if (ssl_error == SSL_ERROR_WANT_READ) {
    printf("SSL_write wants read\n");
    *call_ssl_write_instead_of_read = 1;
    // wants_tcp_read is always 1;
    return CONTINUE;
  }

  unsigned long error = ERR_get_error();
  if (error == 0 && ssl_error == SSL_ERROR_SYSCALL) {
    /* I/O failure with nothing on the OpenSSL queue: report errno / EOF. */
    printf("could not SSL_write (returned %d): %s\n", num,
        num == 0 ? "unexpected EOF" : strerror(saved_errno));
  } else {
    printf("could not SSL_write (returned %d): %s\n", num,
        ERR_error_string(error, NULL));
  }
  return BREAK;
}

/*
 * Run a TLS client or server on port 10000.
 *
 * The server loads ./cert.pem (certificate and private key) and blocks
 * accepting exactly one connection. The client connects to 127.0.0.1. Both
 * sides then make the connected socket non-blocking and drive OpenSSL from a
 * select() loop. The client starts sending a message whenever select() times
 * out (one idle second); the server only reads.
 *
 * Returns 1 on a setup failure, 0 once the event loop has ended.
 */
int main(int argc, char** argv) {
  if (argc != 2 || (strcmp(argv[1], "client") && strcmp(argv[1], "server"))) {
    printf("need parameter 'client' or 'server'\n");
    return 1;
  }

  int is_client = !strcmp(argv[1], "client");
  int port = 10000;

  /* Everything released at "cleanup"; -1 / NULL means "not acquired yet". */
  int rc = 1;
  int server = -1;
  int sockfd = -1;
  SSL* ssl = NULL;

  SSL_library_init();
  SSL_load_error_strings();

  SSL_CTX* ssl_ctx = SSL_CTX_new(is_client ?
      SSLv23_client_method() :
      SSLv23_server_method());
  if (!ssl_ctx) {
    printf("could not SSL_CTX_new\n");
    return 1;
  }

  if (is_client) {
    sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd < 0) {
      printf("could not create socket\n");
      goto cleanup;
    }
  } else {
    /* The same PEM file holds both the certificate and its private key. */
    const char* certificate = "./cert.pem";
    if (SSL_CTX_use_certificate_file(ssl_ctx, certificate, SSL_FILETYPE_PEM) != 1) {
      printf("could not SSL_CTX_use_certificate_file\n");
      goto cleanup;
    }

    if (SSL_CTX_use_PrivateKey_file(ssl_ctx, certificate, SSL_FILETYPE_PEM) != 1) {
      printf("could not SSL_CTX_use_PrivateKey_file\n");
      goto cleanup;
    }

    server = socket(AF_INET, SOCK_STREAM, 0);
    if (server < 0) {
      printf("could not create socket\n");
      goto cleanup;
    }

    int on = 1;
    if (setsockopt(server, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on))) {
      printf("could not setsockopt\n");
      goto cleanup;
    }

    // Bind on any interface.
    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    addr.sin_addr.s_addr = htonl(INADDR_ANY);

    if (bind(server, (struct sockaddr*)(&addr), sizeof(addr))) {
      printf("could not bind\n");
      goto cleanup;
    }

    if (listen(server, 1)) {
      printf("could not listen\n");
      goto cleanup;
    }

    /* Blocking accept: the server serves exactly one client. */
    sockfd = accept(server, NULL, NULL);
    if (sockfd < 0) {
      printf("could not create accept\n");
      goto cleanup;
    }

    /* The listening socket is not needed after the single accept. */
    close(server);
    server = -1;
  }

  ssl = SSL_new(ssl_ctx);
  if (!ssl) {
    printf("could not SSL_new\n");
    goto cleanup;
  }

  // Set the socket to be non blocking. F_SETFL signals failure only with -1.
  int flags = fcntl(sockfd, F_GETFL, 0);
  if (flags == -1 || fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) == -1) {
    printf("could not fcntl\n");
    goto cleanup;
  }

  int one = 1;
  if (setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one))) {
    printf("could not setsockopt\n");
    goto cleanup;
  }

  if (is_client) {
    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

    /* Non-blocking connect normally reports EINPROGRESS; the TLS handshake
     * below surfaces a failed connection. */
    if (connect(sockfd, (struct sockaddr*)(&addr), sizeof(addr)) && errno != EINPROGRESS) {
      printf("could not connect\n");
      goto cleanup;
    }
  }

  if (!SSL_set_fd(ssl, sockfd)) {
    printf("could not SSL_set_fd\n");
    goto cleanup;
  }

  /* The client drives SSL_connect explicitly; the server's handshake runs
   * implicitly inside its first SSL_read. */
  int connecting = 1;
  if (is_client) {
    SSL_set_connect_state(ssl);
  } else {
    SSL_set_accept_state(ssl);
    connecting = 0;
  }

  fd_set read_fds, write_fds;
  int wants_tcp_read = 1, wants_tcp_write = is_client;
  int call_ssl_read_instead_of_write = 0;
  int call_ssl_write_instead_of_read = 0;

  for (;;) {
    printf("selecting\n");
    FD_ZERO(&read_fds);
    FD_ZERO(&write_fds);
    if (wants_tcp_read) {
      FD_SET(sockfd, &read_fds);
    }

    if (wants_tcp_write) {
      FD_SET(sockfd, &write_fds);
    }

    /* select() may modify the timeout, so rebuild it every iteration. */
    struct timeval timeout = { 1, 0 };

    int ready = select(sockfd + 1, &read_fds, &write_fds, NULL, &timeout);
    if (ready < 0) {
      /* On error the fd sets are unspecified and must not be inspected. */
      if (errno == EINTR) {
        continue;
      }
      printf("could not select\n");
      break;
    }

    if (ready > 0) {
      if (FD_ISSET(sockfd, &read_fds)) {
        printf("readable\n");

        if (connecting) {
          ACTION action = ssl_connect(ssl, &wants_tcp_write, &connecting);
          if (action == CONTINUE) {
            continue;
          } else if (action == BREAK) {
            break;
          }
        } else {
          ACTION action;
          /* A write stalled on WANT_READ must be retried before new reads. */
          if (call_ssl_write_instead_of_read) {
            action = ssl_write(ssl, &wants_tcp_write, &call_ssl_write_instead_of_read, is_client, 0);
          } else {
            action = ssl_read(ssl, &wants_tcp_write, &call_ssl_read_instead_of_write);
          }

          if (action == CONTINUE) {
            continue;
          } else if (action == BREAK) {
            break;
          }
        }
      }

      if (FD_ISSET(sockfd, &write_fds)) {
        printf("writable\n");

        if (connecting) {
          wants_tcp_write = 0;

          ACTION action = ssl_connect(ssl, &wants_tcp_write, &connecting);
          if (action == CONTINUE) {
            continue;
          } else if (action == BREAK) {
            break;
          }
        } else {
          ACTION action;
          /* A read stalled on WANT_WRITE must be retried before new writes. */
          if (call_ssl_read_instead_of_write) {
            action = ssl_read(ssl, &wants_tcp_write, &call_ssl_read_instead_of_write);
          } else {
            action = ssl_write(ssl, &wants_tcp_write, &call_ssl_write_instead_of_read, is_client, 0);
          }

          if (action == CONTINUE) {
            continue;
          } else if (action == BREAK) {
            break;
          }
        }
      }
    } else if (is_client && !connecting && !call_ssl_write_instead_of_read) {
      /* Idle timeout on a connected client: start (or resume) a message. */
      ACTION action = ssl_write(ssl, &wants_tcp_write, &call_ssl_write_instead_of_read, is_client, 1);
      if (action == CONTINUE) {
        continue;
      } else if (action == BREAK) {
        break;
      }
    }
  }

  rc = 0;

cleanup:
  /* SSL_set_fd() does not take ownership of the descriptor; close it separately. */
  if (ssl) {
    SSL_free(ssl);
  }
  if (sockfd >= 0) {
    close(sockfd);
  }
  if (server >= 0) {
    close(server);
  }
  SSL_CTX_free(ssl_ctx);

  return rc;
}
share|improve this question

migrated from stackoverflow.com Oct 24 '15 at 15:18

This question came from our site for professional and enthusiast programmers.

    
I realized I forgot about SSL_pending – John Cashew Oct 24 '15 at 3:46
    
I also did not include the library cleanup calls shown here: wiki.openssl.org/index.php/Library_Initialization – John Cashew Oct 24 '15 at 18:45

Also kind of minor:

In the ssl_connect function, the last return NEITHER; is unreachable. In ssl_read and ssl_write, I suggest you implement it similarly: move return NEITHER; into the else.

share|improve this answer
    
Thanks for the feedback. Unfortunately it is not a bug. – John Cashew Dec 4 '15 at 0:38

This is kinda minor, but still sticks out to me:

typedef enum {
    CONTINUE,
    BREAK,
    NEITHER
} ACTION;

Here, ACTION is the type name, so just the first letter should be capitalized. This will also help with differentiating this type name from the actual enum values.

typedef enum {
    CONTINUE,
    BREAK,
    NEITHER
} Action;
share|improve this answer
    
Thanks for reading! Honestly I am mostly looking for feedback on the ssl calls since the library is quite tricky to use. Almost every example I can find does not work for a general purpose socket. – John Cashew Dec 2 '15 at 22:49

Your Answer

 
discard

By posting your answer, you agree to the privacy policy and terms of service.

Not the answer you're looking for? Browse other questions tagged or ask your own question.