# CCS Computer Science
#   Admin
#
# Douglas Thrift
#
# $Id$

from __future__ import with_statement
import common
import ldap
import os
import psycopg2
import shutil
import sys

# On Python 2.6 and later, importing MySQLdb is wrapped so that the
# "the sets module is deprecated" DeprecationWarning is not shown.
if sys.hexversion >= 0x2060000:
	import warnings

	# catch_warnings() restores the previous warning filters on exit, so the
	# 'ignore' filter only applies while MySQLdb is being imported.
	with warnings.catch_warnings():
		warnings.filterwarnings('ignore', 'the sets module is deprecated', DeprecationWarning)

		import MySQLdb
else:
	# Older interpreters: plain import, no warning filtering.
	import MySQLdb

# Host that ldap_connection() connects to and that master() compares against.
MASTER = 'zweihander.ccs.ucsb.edu'

# Directory suffix, and the subtrees used for user and group entries.
BASE = 'dc=ccs,dc=ucsb,dc=edu'
PEOPLE = 'ou=People,%s' % BASE
GROUP = 'ou=Group,%s' % BASE

# One attribute name per entry of common.SYSTEMS: 'ucsbCcs' followed by the
# capitalized system name.
SHELLS = map(lambda name: 'ucsbCcs%s' % name.capitalize(), common.SYSTEMS)

# Samba SID template; the trailing %u is filled in with a number.
SAMBA_SID = 'S-1-5-21-3739982181-3886045993-82308153-%u'

# Global python-ldap options: LDAPv3, and TLS with a mandatory server
# certificate checked against the CA certificate file below.
ldap.set_option(ldap.OPT_PROTOCOL_VERSION, 3)
ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, '/ccs/ssl/ccscert.pem')
ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND)

def _user(user):
	"""Return the distinguished name of user's entry under PEOPLE."""
	relative_name = 'uid=%s' % (user,)

	return ','.join((relative_name, PEOPLE))

def _group(group):
	"""Return the distinguished name of group's entry under GROUP."""
	relative_name = 'cn=%s' % (group,)

	return ','.join((relative_name, GROUP))

def ldap_connection():
	"""Return an LDAP connection to MASTER bound as the root entry.

	The bind password is the content of /ccs/etc/secret. The caller is
	responsible for unbinding the returned connection.
	"""
	handle = ldap.initialize('ldaps://%s' % MASTER)

	with open('/ccs/etc/secret', 'rb') as secret_file:
		credential = secret_file.read()

		handle.simple_bind_s(_user('root'), credential)

	return handle

def master():
	"""Return True when common.HOST equals MASTER."""
	on_master = common.HOST == MASTER

	return on_master

def run(errors):
	"""Report per-host errors and exit with status 1 if there are any.

	errors maps a host to its error; each pair is written to standard error
	as 'host: error'. Nothing happens when errors is empty or None.
	"""
	if not errors:
		return

	report = sys.stderr.write

	for host, problem in errors.iteritems():
		report('%s: %s\n' % (host, problem))

	raise SystemExit(1)

def error(error):
	"""Exit the program with the message 'sys.argv[0]: error'."""
	message = '%s: %s' % (sys.argv[0], error)

	raise SystemExit(message)

def eof():
	"""Print an empty line and exit with status 130."""
	# Emit the newline first, then leave with the exit status.
	print

	raise SystemExit(130)

def adduser(user, name, password):
	connection = ldap_connection()
	uid = max(map(lambda user: int(user[1]['uidNumber'][0]), connection.search_s(PEOPLE, ldap.SCOPE_ONELEVEL, '(&(uid=*)(!(uid=root)))', ('uidNumber',)))) + 1
	gid = uid
	samba_gid = gid + 1000
	home = os.path.join('/home', user)

	connection.add_s(_user(user), [
		('objectclass', ['top', 'account', 'posixAccount', 'shadowAccount', 'ucsbCcsLoginShells', 'sambaSamAccount']),
		('cn', name),
		('uid', user),
		('uidNumber', str(uid)),
		('gidNumber', str(gid)),
		('homeDirectory', home),
		('loginShell', 'bash'),
	] + zip(SHELLS, dict(common.SHELLS)['bash']) + [
		('sambaAcctFlags', '[U          ]'),
		('sambaSID', SAMBA_SID % uid),
		('sambaPrimaryGroupSID', SAMBA_SID % samba_gid),
	])
	connection.add_s(_group(user), [
		('objectclass', ['top', 'posixGroup', 'sambaGroupMapping']),
		('cn', user),
		('gidNumber', str(gid)),
		('sambaSID', SAMBA_SID % samba_gid),
		('sambaGroupType', '4'),
	])

	for group in ('wheel', 'fuse', 'operator'):
		connection.modify_s(_group(group), [(ldap.MOD_ADD, 'memberUid', user)])

	connection.unbind_s()
	os.umask(0022)
	os.mkdir(home)
	os.chown(home, uid, gid)

	for skel in ('/usr/share/skel', '/ccs/skel'):
		for source, directories, files in os.walk(skel):
			destination = os.path.join(home, source[len(skel):])

			for directory in directories:
				target = os.path.join(destination, directory[3:] if directory.startswith('dot') else directory)

				os.mkdir(target)
				shutil.copymode(os.path.join(source, directory), target)
				os.chown(target, uid, gid)

			for file in files:
				target = os.path.join(destination, file[3:] if file.startswith('dot') else file)

				shutil.copy(os.path.join(source, file), target)
				os.chown(target, uid, gid)

	db = psycopg2.connect(database = 'postgres')
	cursor = db.cursor()

	cursor.execute('create user %s with createdb' % user)
	db.commit()

	passwd(user, None, password)

def chfn(user, name):
	"""Replace the cn attribute of user's LDAP entry with name.

	LDAP errors propagate to the caller; the connection is unbound either way.
	"""
	connection = ldap_connection()

	try:
		connection.modify_s(_user(user), [(ldap.MOD_REPLACE, 'cn', name)])
	finally:
		# Release the connection even when the modification is rejected.
		connection.unbind_s()

def chsh(user, shell, shells):
	"""Set the login shell attributes of user's LDAP entry.

	shell  -- a name from common.SHELLS, or 'custom'.
	shells -- values paired, in order, with the attribute names in SHELLS.
	          Only used when shell is 'custom'; otherwise it is replaced by
	          the values common.SHELLS holds for shell.

	When shell is 'custom' but shells equals the values of an entry of
	common.SHELLS (the last entry is not considered), that entry's name is
	stored as loginShell instead of 'custom'.

	Raises KeyError for a shell name that is not in common.SHELLS; LDAP errors
	propagate to the caller. The connection is unbound either way.
	"""
	if shell != 'custom':
		shells = dict(common.SHELLS)[shell]
	else:
		# No break: if several entries match, the last one wins.
		for _shell, _shells in common.SHELLS[:-1]:
			if shells == _shells:
				shell = _shell

	attributes = [('loginShell', shell)] + list(zip(SHELLS, shells))
	changes = [(ldap.MOD_REPLACE, key, value) for key, value in attributes]
	connection = ldap_connection()

	try:
		connection.modify_s(_user(user), changes)
	finally:
		# Release the connection even when the modification is rejected.
		connection.unbind_s()

def passwd(user, old_password, new_password):
	"""Change user's password in LDAP and then in MySQL.

	old_password is handed to the LDAP password change as-is; adduser() in
	this file passes None for it.

	In MySQL, an existing account named user gets its password updated.
	Otherwise the user is granted all privileges, from 'localhost' and from
	'%', on databases matching `<user>\\_%`, identified by new_password.

	LDAP and MySQL errors propagate to the caller; both connections are
	released either way. If the MySQL step fails, the LDAP password has
	already been changed.
	"""
	connection = ldap_connection()

	try:
		connection.passwd_s(_user(user), old_password, new_password)
	finally:
		connection.unbind_s()

	with open('/ccs/etc/secret', 'rb') as secret:
		db = MySQLdb.connect(passwd = secret.read(), db = 'mysql')

	try:
		cursor = db.cursor()

		try:
			cursor.execute('select count(User) from user where User = %s', (user,))

			if cursor.fetchone()[0]:
				cursor.execute('update user set Password = PASSWORD(%s) where User = %s', (new_password, user))
				cursor.execute('flush privileges')
			else:
				# NOTE(review): escape_string() escapes for string literals,
				# not for backtick-quoted identifiers -- confirm user names
				# are validated before they reach this point.
				# %% survives parameter substitution as a literal % wildcard.
				cursor.executemany('grant all on `' + db.escape_string(user) + r'\_%%`.* to %s@%s identified by %s', map(lambda host: (user, host, new_password), ('localhost', '%')))
		finally:
			cursor.close()
	finally:
		db.close()
