#include <ctype.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <dirent.h>
#include <dlfcn.h>
#include <pwd.h>
#include <regex.h>
#include <unistd.h>

/* BUGS
 * ----
 * 1) The "hack" of using the same .so both as PAM module
 *    and preloaded lib is probably dangerous.
 * 2) The errors should be prefixed with "userdirs:" in
 *    order not to confuse the user into thinking that the
 *    error originated in the host program.
 */

/* Configuration parsed from the USERDIRS environment variable by
 * parse_init().  root_dir stays NULL until a "root_dir=" entry has been
 * parsed; both getpwnam() and pam_sm_authenticate() treat NULL as
 * "userdirs disabled". */
static char *root_dir = NULL;
/* uid/gid stored in every synthesised userdirs record; overridable with
 * "uid=" / "gid=" entries in USERDIRS.  NOTE(review): 65534 is commonly
 * "nobody"/"nogroup" — confirm for the target platform. */
static uid_t g_uid = 65534;
static gid_t g_gid = 65534;

static void parse_init (const char *env);

/* Interposition (overriding) of libc functions that read /etc/passwd,
 * needed because PAM can't provide the working directory.
 *
 * This should be preloaded with the LD_PRELOAD variable.
 * As LD_PRELOAD doesn't propagate to suid programs,
 * /etc/ld.so.preload has to be used for "spawning" daemons
 * like exim4.
 *
 * When the "USERDIRS" environment variable is unset, it
 * should behave exactly like the original libc function.
 */

/* Capacities, excluding the terminating NUL, of the string buffers in the
 * static struct passwd that getpwnam() returns for userdirs users. */
#define GECOS_LEN 64
#define NAME_LEN 32
#define PASS_LEN 32
#define DIR_LEN 256

/* Entries under root_dir are named "GECOS(name-password)":
 *   group 1: GECOS text, 0-64 characters (in a multibyte locale this may
 *            be more than 64 bytes),
 *   group 2: user name, 1-32 of [a-zA-Z0-9._],
 *   group 3: password, stored in clear text, 1-32 alphanumerics. */
#define REGEXP "^(.{0,64})\\(([a-zA-Z0-9._]{1,32})-([a-zA-Z0-9]{1,32})\\)$"

/* Copy LEN bytes of SRC into DST, which holds CAP + 1 bytes, truncating to
 * CAP bytes and always NUL-terminating.  Regex captures are measured in
 * bytes; in a multibyte locale "." can match several bytes per character,
 * so a capture may be longer than the field it is copied into. */
static void userdirs_copy_field (char *dst, size_t cap, const char *src, size_t len)
{
	if (len > cap)
		len = cap;
	memcpy (dst, src, len);
	dst[len] = '\0';
}

/* Allocate the static record handed out for userdirs users, with every
 * string buffer zeroed.  Returns NULL, with nothing leaked, if any
 * allocation fails. */
static struct passwd *userdirs_alloc_record (void)
{
	struct passwd *rec = calloc (1, sizeof *rec);

	if (rec == NULL)
		return NULL;
	rec->pw_name = calloc (NAME_LEN + 1, 1);
	rec->pw_passwd = calloc (PASS_LEN + 1, 1);
	rec->pw_gecos = calloc (GECOS_LEN + 1, 1);
	rec->pw_dir = calloc (DIR_LEN + 1, 1);
	if (   rec->pw_name == NULL
	    || rec->pw_passwd == NULL
	    || rec->pw_gecos == NULL
	    || rec->pw_dir == NULL)
	{
		free (rec->pw_name);
		free (rec->pw_passwd);
		free (rec->pw_gecos);
		free (rec->pw_dir);
		free (rec);
		return NULL;
	}
	return rec;
}

/* Report the pending dynamic-linker error and terminate.  The message is
 * prefixed so it is not mistaken for an error of the host program. */
static void userdirs_dl_fatal (void)
{
	const char *msg = dlerror ();

	fprintf (stderr, "userdirs: %s\n",
	         msg != NULL ? msg : "unknown dynamic linker error");
	exit (1);
}

/* Interposed getpwnam().
 *
 * First defers to the libc implementation, resolved from "libc.so.6" on
 * first use.  If libc finds the user, or reports one of the errors listed
 * below, its result is returned unchanged.  Otherwise, when USERDIRS
 * configures a root_dir, the entries of root_dir are matched against
 * REGEXP and a record is synthesised from the entry whose name group
 * equals NAME (the alphabetically first one if several match):
 * pw_gecos and pw_passwd come from the entry name, pw_dir is
 * root_dir "/" entry, pw_shell is "/bin/false", and pw_uid/pw_gid come
 * from USERDIRS.
 *
 * Returns a pointer to static storage that the next successful userdirs
 * lookup overwrites.  Returns NULL with errno set to:
 *   0       not found, or userdirs disabled
 *   EINVAL  NAME is NULL or longer than NAME_LEN
 *   ENOMEM  allocation failure
 *   ENOSYS  REGEXP failed to compile
 *   ENOENT  root_dir could not be scanned
 * Exits the process if libc's getpwnam() cannot be resolved.
 */
struct passwd *getpwnam (const char *name)
{
	static struct passwd *(*real_getpwnam) (const char *name) = NULL;
	static struct passwd *pwd = NULL;
	static regex_t *preg = NULL;
	struct passwd *sys_pwd;
	regmatch_t mt[4];
	struct dirent **namelist;
	size_t name_len, root_len;
	int n, found;

	// Resolve the real getpwnam() on first use.
	if (real_getpwnam == NULL)
	{
		void *lib_handle = dlopen ("libc.so.6", RTLD_LAZY);

		if (lib_handle == NULL)
			userdirs_dl_fatal ();
		(void) dlerror ();	// clear any stale error before dlsym()
		real_getpwnam = dlsym (lib_handle, "getpwnam");
		if (real_getpwnam == NULL)
			userdirs_dl_fatal ();
	}

	// Ask libc first.  "Not found" is reported as NULL with errno left
	// unchanged, so errno must be cleared beforehand.  Otherwise a stale
	// value from earlier code would be mistaken for a lookup error.
	errno = 0;
	sys_pwd = real_getpwnam (name);
	if (sys_pwd != NULL)
		return sys_pwd;
	if (   errno == EINTR
	    || errno == EIO
	    || errno == EMFILE
	    || errno == ENFILE
	    || errno == ENOMEM
	    || errno == ERANGE )
	{
		// A genuine error occurred; bail out with libc's errno.
		return NULL;
	}

	// The user wasn't found, so we now try userdirs...
	if (root_dir == NULL)
	{
		// Get and parse initialization values.
		parse_init (getenv ("USERDIRS"));

		if (root_dir == NULL)
		{
			errno = 0;
			return NULL;
		}
	}

	if (name == NULL)
	{
		errno = EINVAL;
		return NULL;
	}
	name_len = strlen (name);
	if (name_len > NAME_LEN)
	{
		errno = EINVAL;
		return NULL;
	}

	if (pwd == NULL)
	{
		pwd = userdirs_alloc_record ();
		if (pwd == NULL)
		{
			errno = ENOMEM;
			return NULL;
		}
	}

	if (preg == NULL)
	{
		// Compile the regexp used for userdirs parsing.
		regex_t *re = malloc (sizeof *re);

		if (re == NULL)
		{
			errno = ENOMEM;
			return NULL;
		}
		if (regcomp (re, REGEXP, REG_EXTENDED) != 0)
		{
			// Do not cache a half-initialised regex_t; the next call
			// retries compilation instead of using it.
			free (re);
			errno = ENOSYS;
			return NULL;
		}
		preg = re;
	}

	// Scan directories for the supplied username.
	n = scandir (root_dir, &namelist, NULL, alphasort);
	if (n < 0)
	{
		errno = ENOENT;
		return NULL;
	}

	root_len = strlen (root_dir);
	found = 0;
	// Entries are visited from last to first and every match overwrites
	// the record, so the alphabetically first matching entry wins.
	while (n--)
	{
		// FIXME: Filter out non-directories...
		const char *dirname = namelist[n]->d_name;

		if (   regexec (preg, dirname, 4, mt, 0) == 0
		    && root_len + strlen (dirname) + 1 <= DIR_LEN
		    && (size_t) (mt[2].rm_eo - mt[2].rm_so) == name_len
		    && strncmp (name, dirname + mt[2].rm_so, name_len) == 0
		    // A password that doesn't fit (possible only in a multibyte
		    // locale) is rejected rather than truncated: truncation would
		    // make a prefix of the real password valid.
		    && (size_t) (mt[3].rm_eo - mt[3].rm_so) <= PASS_LEN)
		{
			found = 1;
			memcpy (pwd->pw_name, name, name_len + 1);
			userdirs_copy_field (pwd->pw_gecos, GECOS_LEN,
			                     dirname + mt[1].rm_so,
			                     (size_t) (mt[1].rm_eo - mt[1].rm_so));
			userdirs_copy_field (pwd->pw_passwd, PASS_LEN,
			                     dirname + mt[3].rm_so,
			                     (size_t) (mt[3].rm_eo - mt[3].rm_so));
			// Length was checked above, so this never truncates.
			snprintf (pwd->pw_dir, DIR_LEN + 1, "%s/%s", root_dir, dirname);
			// Re-assert the fixed fields on every hit.  The record is
			// shared, and a caller may have modified a previous result.
			pwd->pw_shell = "/bin/false";
			pwd->pw_uid = g_uid;
			pwd->pw_gid = g_gid;
		}
		free (namelist[n]);
	}
	free (namelist);

	if (!found)
	{
		errno = 0;
		return NULL;
	}
	return pwd;
}

/* PAM authentication functions.
 * Uses the local getpwnam from above.
 */

#define PAM_SM_AUTH
#define PAM_SM_ACCOUNT
#include <security/pam_modules.h>

PAM_EXTERN int pam_sm_authenticate (pam_handle_t *pamh, int flags, int argc, const char **argv)
{
	struct pam_conv *conv;
	struct pam_message msg;
	const struct pam_message *msgp;
	struct pam_response *resp;
	
	const char *user;
	char password [PASS_LEN + 1];
	struct passwd *pwd;
	int res;
	int a;
	
	if (root_dir == NULL)
	{
		// Get and parse initialization values.
		parse_init (getenv ("USERDIRS"));
		
		if (root_dir == NULL)
			return PAM_SYSTEM_ERR;
	}
	
	/* identify user */
	res = pam_get_user (pamh, &user, NULL);
	if (res != PAM_SUCCESS)
		return PAM_SYSTEM_ERR;
	
	/* get password */
	res = pam_get_item (pamh, PAM_CONV, (const void **) &conv);
	if (res != PAM_SUCCESS)
		return PAM_SYSTEM_ERR;
	
	msg.msg_style = PAM_PROMPT_ECHO_OFF;
	msg.msg = "";
	msgp = &msg;
	resp = NULL;
	
	res = (*conv->conv) (1, &msgp, &resp, conv->appdata_ptr);
	if (res != PAM_SUCCESS)
		return PAM_AUTH_ERR;
	
	bzero (password, PASS_LEN + 1);
	strncpy (password, resp->resp, PASS_LEN);
	free (resp->resp);
	free (resp);
	
	/* compare passwords */
	pwd = getpwnam (user);	// local function
	if (pwd == NULL)
		return PAM_AUTH_ERR;
	
	if (strcmp (pwd->pw_passwd, "x") == 0)
		return PAM_AUTH_ERR;	// *May* come from /etc/passwd (FIXME)

	if (strcmp (password, pwd->pw_passwd) != 0)
		return PAM_AUTH_ERR;

	return PAM_SUCCESS;
}

/* PAM credential hook.  This module establishes no credentials: every
 * argument is ignored and PAM_SUCCESS is always returned. */
PAM_EXTERN int pam_sm_setcred (pam_handle_t *pamh, int flags, int argc, const char **argv)
{
	(void) pamh;
	(void) flags;
	(void) argc;
	(void) argv;

	return PAM_SUCCESS;
}

/* Parse a decimal id occupying exactly [S, STOP).  On success store it in
 * *OUT and return 1.  For an empty, non-numeric or out-of-range field,
 * return 0 and leave *OUT untouched.  May modify errno. */
static int userdirs_parse_id (const char *s, const char *stop, unsigned long *out)
{
	char *endp;
	unsigned long v;

	// strtoul() would accept leading blanks and a minus sign (wrapping
	// "-1" to a huge id); ids must be plain digits.
	if (s == stop || !isdigit ((unsigned char) *s))
		return 0;
	errno = 0;
	v = strtoul (s, &endp, 10);
	if (errno == ERANGE || endp != stop)
		return 0;
	*out = v;
	return 1;
}

/* Parse the USERDIRS configuration string ENV, a ':'-separated list of
 *   root_dir=PATH   directory scanned for userdirs entries
 *   uid=N           uid given to userdirs users
 *   gid=N           gid given to userdirs users
 * and store the results in root_dir, g_uid and g_gid.  Parsing stops at
 * the first unrecognised entry.  Malformed or out-of-range ids are
 * ignored, leaving the previous value.  A NULL ENV is a no-op.
 *
 * ENV is ignored entirely in set-user-ID/set-group-ID processes.  There
 * the environment is controlled by the less privileged invoking user, who
 * could otherwise inject a root_dir and "uid=0" and create arbitrary
 * privileged accounts.  errno is preserved.
 */
static void parse_init (const char *env)
{
	const char *a, *b, *end;
	unsigned long id;
	int saved_errno;

	if (env == NULL)
		return;

	if (getuid () != geteuid () || getgid () != getegid ())
		return;

	saved_errno = errno;
	a = env;
	end = a + strlen (a);

	while (a < end)
	{
		// The value runs up to the next ':' or the end of the string.
		// The keys contain no ':', so searching from the key start finds
		// the same separator.
		b = strchr (a, ':');
		if (b == NULL)
			b = end;

		if (strncmp (a, "root_dir=", 9) == 0)
		{
			size_t len = (size_t) (b - (a + 9));
			char *dir = malloc (len + 1);

			if (dir != NULL)
			{
				memcpy (dir, a + 9, len);
				dir[len] = '\0';
				free (root_dir);	// a repeated key replaces, not leaks
				root_dir = dir;
			}
		}
		else if (strncmp (a, "uid=", 4) == 0)
		{
			// Round-trip through uid_t to reject values that don't fit.
			if (   userdirs_parse_id (a + 4, b, &id)
			    && (unsigned long) (uid_t) id == id)
				g_uid = (uid_t) id;
		}
		else if (strncmp (a, "gid=", 4) == 0)
		{
			if (   userdirs_parse_id (a + 4, b, &id)
			    && (unsigned long) (gid_t) id == id)
				g_gid = (gid_t) id;
		}
		else
			break;

		// b points at ':' or at the terminating NUL; in the latter case
		// a becomes end + 1 and the loop ends.
		a = b + 1;
	}

	errno = saved_errno;
}
