view main.c @ 0:17cb7cdbb8be draft default tip

Working prototype
author Ivo Smits <Ivo@UCIS.nl>
date Fri, 07 Feb 2014 23:28:39 +0100
parents
children
line wrap: on
line source

/* Copyright 2014 Ivo Smits <Ivo@UCIS.nl>. All rights reserved.
   Redistribution and use in source and binary forms, with or without modification, are
   permitted provided that the following conditions are met:

   1. Redistributions of source code must retain the above copyright notice, this list of
      conditions and the following disclaimer.

   2. Redistributions in binary form must reproduce the above copyright notice, this list
      of conditions and the following disclaimer in the documentation and/or other materials
      provided with the distribution.

   THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
   WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
   FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
   CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
   SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
   ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
   ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

   The views and conclusions contained in the software and documentation are those of the
   authors and should not be interpreted as representing official policies, either expressed
   or implied, of Ivo Smits.*/

#include <stdlib.h>
#include <stdio.h>
#include <stdbool.h>
#include <fcntl.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <poll.h>
#include <sys/socket.h>
#ifndef HAVE_NETINET_IN_H
#include <netinet/in.h>
#endif
#include <arpa/inet.h>
#include <netdb.h>
#include "include.h"

/* Lookup hook used by this file to read configuration values by name.
   It is initialised to getenv, so settings come from the environment. */
char *(*getconf)(const char *) = getenv;

/* Decode a single hexadecimal digit (either case).
   Returns its value 0-15, or 0 when c is not a hexadecimal digit. */
static unsigned char hex2bin_nibble(const char c) {
	if (c >= '0' && c <= '9') return (unsigned char)(c - '0');
	if (c >= 'a' && c <= 'f') return (unsigned char)(c - 'a' + 10);
	if (c >= 'A' && c <= 'F') return (unsigned char)(c - 'A' + 10);
	return 0;
}

/*
 * Convert 2*count hexadecimal characters from src into count bytes in dest.
 *
 * dest   receives exactly count bytes.
 * src    must provide at least 2*count readable characters; it does not need
 *        to be NUL-terminated.
 * count  number of output bytes.
 *
 * Characters that are not hexadecimal digits decode as zero, so the result
 * never depends on the previous contents of dest. The function has no way to
 * report malformed input; callers that need validation must do it themselves.
 */
static void hex2bin(unsigned char* dest, const char* src, const int count) {
	int i;
	for (i = 0; i < count; i++) {
		const unsigned char high = hex2bin_nibble(src[2 * i]);
		const unsigned char low = hex2bin_nibble(src[2 * i + 1]);
		dest[i] = (unsigned char)((high << 4) | low);
	}
}

/*
 * Set up encryption and password authentication from the configuration.
 *
 * Configuration values read through getconf():
 *   PUBLIC_KEY        64 hex characters, the remote public key.
 *   PRIVATE_KEY       64 hex characters, the local secret key.
 *   PRIVATE_KEY_FILE  used only when PRIVATE_KEY is unset. At most 64 bytes are
 *                     read: exactly 32 bytes are taken as a raw key, 64 bytes
 *                     are decoded as hex, anything else is rejected.
 *   ENCRYPT           a non-zero number enables encryption without any keys.
 *   PASSWORD          enables password authentication.
 *
 * context    connection to configure.
 * keyupdate  set to true when encryption was initialised, false otherwise.
 *
 * Returns true on success; on failure a message is printed and false returned.
 */
static bool crypto_init(connection_context* context, bool* keyupdate) {
	unsigned char cpublickey[32], csecretkey[32];
	bool hpublickey = false, hsecretkey = false;
	char* envval;
	*keyupdate = false;
	if ((envval = getconf("PUBLIC_KEY"))) {
		if (strlen(envval) != 64) return errorexit("PUBLIC_KEY length");
		hex2bin(cpublickey, envval, 32);
		hpublickey = true;
	}
	if ((envval = getconf("PRIVATE_KEY"))) {
		if (strlen(envval) != 64) return errorexit("PRIVATE_KEY length");
		hex2bin(csecretkey, envval, 32);
		hsecretkey = true;
	} else if ((envval = getconf("PRIVATE_KEY_FILE"))) {
		FILE* pkfile = fopen(envval, "rb");
		if (!pkfile) return errorexitp("Could not open PRIVATE_KEY_FILE");
		char pktextbuf[64];
		const size_t pktextsize = fread(pktextbuf, 1, sizeof(pktextbuf), pkfile);
		/* Everything needed is in the buffer now; closing here keeps the
		   length-error path below from leaking the stream. */
		fclose(pkfile);
		if (pktextsize == 32) {
			memcpy(csecretkey, pktextbuf, 32);
		} else if (pktextsize == 64) {
			hex2bin(csecretkey, pktextbuf, 32);
		} else {
			return errorexit("PRIVATE_KEY length");
		}
		hsecretkey = true;
	}
	if (hpublickey || hsecretkey || ((envval = getconf("ENCRYPT")) && atoi(envval))) {
		if (!hpublickey) fprintf(stderr, "Warning: encryption enabled but remote key not set, cryptographic authentication disabled.\n");
		/* A NULL key pointer tells connection_init_encryption() that the key was not configured. */
		if (!connection_init_encryption(context, hsecretkey ? csecretkey : NULL, hpublickey ? cpublickey : NULL)) return false;
		*keyupdate = true;
	} else {
		fprintf(stderr, "Warning: encryption disabled.\n");
	}
	if ((envval = getconf("PASSWORD"))) {
		/* A heap copy is handed to the connection; who releases it is not
		   visible from this file. */
		char* password = strdup(envval);
		if (!password) return errorexitp("Could not copy PASSWORD");
		if (!connection_init_passwordauth(context, password)) return false;
	}
	return true;
}

/* Socket address storage that can hold either an IPv4 or an IPv6 address.
   All members overlay the same memory. */
typedef union {
	/* Generic view: used to read sa_family and to pass the address to bind()/connect(). */
	struct sockaddr any;
	/* IPv4 view (AF_INET). */
	struct sockaddr_in ip4;
	/* IPv6 view (AF_INET6). */
	struct sockaddr_in6 ip6;
} sockaddr_any;

/*
 * Store a port number (given in host byte order) in an IPv4 or IPv6 socket
 * address, selected by the address family already present in *sa.
 *
 * Returns 0 on success. For any other address family an error message is
 * printed and -1 is returned.
 */
static int sockaddr_set_port(sockaddr_any* sa, int port) {
	const unsigned short netport = htons((unsigned short)port);
	switch (sa->any.sa_family) {
	case AF_INET:
		sa->ip4.sin_port = netport;
		return 0;
	case AF_INET6:
		sa->ip6.sin6_port = netport;
		return 0;
	default:
		/* errorexit() returns false (0), which callers testing for a non-zero
		   result would mistake for success, so report failure explicitly. */
		errorexit("Unknown address family");
		return -1;
	}
}

/*
 * Fill *dst from the first entry of a getaddrinfo() result and set its port
 * from the configuration variable named portvar (2998 when it is unset).
 *
 * The caller must already have checked that ai->ai_addrlen fits in *dst.
 * Returns 0 on success, non-zero when the port could not be stored.
 */
static int sockaddr_from_addrinfo(sockaddr_any* dst, const struct addrinfo* ai, unsigned short af, const char* portvar) {
	char* envval;
	int port = 2998;
	memset(dst, 0, sizeof(*dst));
	dst->any.sa_family = af;
	memcpy(dst, ai->ai_addr, ai->ai_addrlen);
	if ((envval = getconf(portvar))) port = atoi(envval);
	return sockaddr_set_port(dst, port);
}

/*
 * Create the transport for the connection.
 *
 * Configuration values read through getconf():
 *   LOCAL_ADDRESS / LOCAL_PORT    optional address and port to bind to.
 *   REMOTE_ADDRESS / REMOTE_PORT  address and port to connect to.
 * Both ports default to 2998.
 *
 * When neither address is configured, the connection is initialised on file
 * descriptors 0 and 1 (presumably for use under (x)inetd, as the error message
 * below suggests). Otherwise a TCP socket is created, optionally bound, and
 * connected to the remote address; a local address without a remote address
 * is rejected because server mode is not supported.
 *
 * Returns true on success; on failure a message is printed, every resource
 * acquired here is released, and false is returned.
 */
static bool socket_init(connection_context* context) {
	struct addrinfo *ai_local = NULL, *ai_remote = NULL;
	sockaddr_any addr;
	unsigned short af = 0;
	bool ok = false;
	char* envval;
	int sa_size;
	int sfd = -1;
	int ret;

	fprintf(stderr, "Initializing socket...\n");
	if ((envval = getconf("LOCAL_ADDRESS"))) {
		if ((ret = getaddrinfo(envval, NULL, NULL, &ai_local))) {
			/* Do not trust the output pointer after a failed lookup. */
			ai_local = NULL;
			errorexitf("getaddrinfo(LOCAL_ADDRESS)", gai_strerror(ret));
			goto done;
		}
		if (!ai_local) {
			errorexit("LOCAL_ADDRESS lookup failed");
			goto done;
		}
		if (ai_local->ai_addrlen > sizeof(sockaddr_any)) {
			errorexit("Resolved LOCAL_ADDRESS is too big");
			goto done;
		}
		af = ai_local->ai_family;
	}
	if ((envval = getconf("REMOTE_ADDRESS"))) {
		if ((ret = getaddrinfo(envval, NULL, NULL, &ai_remote))) {
			ai_remote = NULL;
			errorexitf("getaddrinfo(REMOTE_ADDRESS)", gai_strerror(ret));
			goto done;
		}
		if (!ai_remote) {
			errorexit("REMOTE_ADDRESS lookup failed");
			goto done;
		}
		if (ai_remote->ai_addrlen > sizeof(sockaddr_any)) {
			errorexit("Resolved REMOTE_ADDRESS is too big");
			goto done;
		}
		if (af && af != ai_remote->ai_family) {
			errorexit("Address families do not match");
			goto done;
		}
		af = ai_remote->ai_family;
	}
	if (!af) {
		/* No address configured: run over the already-open descriptors 0 and 1. */
		ok = connection_init_socket(context, 0, 1);
		goto done;
	}

	sa_size = sizeof(sockaddr_any);
	if (af == AF_INET) sa_size = sizeof(struct sockaddr_in);
	else if (af == AF_INET6) sa_size = sizeof(struct sockaddr_in6);

	sfd = socket(af, SOCK_STREAM, IPPROTO_TCP);
	if (sfd < 0) {
		errorexitp("Could not create socket");
		goto done;
	}
	if (ai_local) {
		if (sockaddr_from_addrinfo(&addr, ai_local, af, "LOCAL_PORT")) goto done;
		if (bind(sfd, &addr.any, sa_size)) {
			errorexitp("Could not bind socket");
			goto done;
		}
	}
	if (!ai_remote) {
		errorexit("REMOTE_ADDRESS not specified and server mode is currently not supported :-(. Please use (x)inetd or similar.");
		goto done;
	}
	if (sockaddr_from_addrinfo(&addr, ai_remote, af, "REMOTE_PORT")) goto done;
	if (connect(sfd, &addr.any, sa_size)) {
		errorexitp("Could not connect socket");
		goto done;
	}
	ok = connection_init_socket(context, sfd, sfd);
	/* The descriptor has been handed to connection_init_socket(); it is not
	   closed here regardless of the result (ownership on failure is not
	   visible from this file). */
	sfd = -1;

done:
	if (sfd >= 0) close(sfd);
	if (ai_local) freeaddrinfo(ai_local);
	if (ai_remote) freeaddrinfo(ai_remote);
	return ok;
}

static bool mainA() {
	connection_context context;
	tunnel_context tunnel;
	bool keyupdate = false;
	if (!connection_init(&context)) return false;
	if (!socket_init(&context)) return false;
	if (!crypto_init(&context, &keyupdate)) return false;
	if (!connection_init_done(&context)) return false;
	while (!context.local_tunnelready) if (!connection_read(&context)) return false;
	if (keyupdate) connection_update_key(&context);

	if (!tunnel_init(&tunnel)) return false;
	context.tunnel = &tunnel;
	tunnel.connection = &context;

	struct pollfd fds[2];
	fds[0].fd = context.recv_socket;
	fds[0].events = POLLIN;
	fds[1].fd = tunnel.fd;
	fds[1].events = POLLIN;

	while (true) {
		int len = poll(fds, 2, 10000);
		if (len < 0) return errorexitp("poll failed");
		else if (fds[0].revents & (POLLERR | POLLHUP | POLLNVAL)) return errorexit("poll error on socket");
		else if (fds[1].revents & (POLLHUP | POLLNVAL)) return errorexit("poll error on tap device");
		if (len == 0) {
			if (keyupdate) {
				if (!context.key_updated) return errorexit("key update timed out");
				if (!connection_update_key(&context)) return false;
			} else {
				if (!connection_ping(&context)) return errorexit("ping timed out");
			}
		}
		if (fds[0].revents & POLLERR) return errorexitp("poll error on socket");
		if (fds[0].revents & POLLIN) if (!connection_read(&context)) return false;
		if (fds[1].revents & POLLIN) if (!tunnel_read(&tunnel)) return false;
	}
	return -1;
}

/* Program entry point: returns 0 when mainA() reports success, -1 otherwise. */
int main() {
	const bool succeeded = mainA();
	if (succeeded) return 0;
	return -1;
}

/* Print text followed by a newline on stderr.
   Always returns false (0) so callers can write `return errorexit(...)`. */
int errorexit(const char* text) {
	const int failure = false;
	(void)fprintf(stderr, "%s\n", text);
	return failure;
}
/* Print "text: error" followed by a newline on stderr.
   Always returns false (0) so callers can write `return errorexitf(...)`. */
int errorexitf(const char* text, const char* error) {
	const int failure = false;
	(void)fprintf(stderr, "%s: %s\n", text, error);
	return failure;
}
/* Report text together with the description of the current errno via perror().
   Always returns false so callers can write `return errorexitp(...)`. */
bool errorexitp(const char* text) {
	const bool failure = false;
	perror(text);
	return failure;
}