# Copyright 1999-2014. Parallels IP Holdings GmbH. All Rights Reserved.
""" plesk_mail_handlers - Base classes from creating Plesk mail handlers in Python """

import sys
import os
import pwd
import logging
import logging.handlers
import email
import email.utils
from email.generator import Generator
from cStringIO import StringIO

import plesk_log; plesk_log.setDefaultLevel('warning')

from plesk_outgoing_mail_db import OutgoingMailLimitsDb


log = None

def getLogger(name, syslog=True):
    """ Fetch the plesk_log logger called `name`, store it in the module-global `log` and return it.

        With syslog enabled (the default) a SysLogHandler writing to /dev/log under the 'mail'
        facility is attached and propagation to ancestor loggers is switched off. When the
        PLESK_DEBUG environment variable equals '1', syslog records also carry the source
        file name and line number.
    """
    global log
    log = plesk_log.getLogger(name)
    if not syslog:
        return log

    # Log to syslog LOG_MAIL only
    line_format = '%(name)s[%(process)d]: %(levelname)s %(message)s'
    if os.getenv('PLESK_DEBUG') == '1':
        line_format += ' [at %(filename)s:%(lineno)d]'

    mail_syslog = logging.handlers.SysLogHandler(address='/dev/log', facility='mail')
    mail_syslog.setFormatter(logging.Formatter(line_format))
    log.addHandler(mail_syslog)
    log.propagate = False
    return log


class PleskMailHandler(object):
    """ Base class for Plesk mail handlers.

        A handler is started as: <executable> <CONTEXT> <FROM> [<RECIPIENT> ...].
        It reports back by writing directives to stderr (see data(), reply(), log())
        and terminates with one of the HOOK_* exit codes through the exit_*() methods.
    """
    # Mail handler exit codes (taken from mhandlers_types.h enum mh_status_t)
    HOOK_PASS    = 0
    HOOK_STOP    = 1
    HOOK_DEFER   = 2
    HOOK_REJECT  = 3
    HOOK_LOG     = 4
    HOOK_SKIP    = 5   # skip checks and return PASS
    HOOK_UNKNOWN = 6   # status undefined
    HOOK_ERROR   = 255 # shell handlers convert (-1) into one byte number

    # Mail handler argument indicies (taken from mhandlers_types.h enum mh_args_t)
    MH_ARGV_EXECUTABLE      = 0
    MH_ARGV_CONTEXT         = 1
    MH_ARGV_FROM            = 2
    MH_ARGV_RCPT_LIST_BEGIN = 3

    # Environment variables
    ENV_SENDMAIL_CONTEXT = "PPP_INSIDE_SENDMAIL"
    ENV_SMTP_AUTH_USER = "SMTP_AUTHORIZED"
    ENV_OVERRIDE_SENDER = "PPP_OVERRIDE_SENDER"
    ENV_SHORTNAMES_ALLOWED = "SHORTNAMES"

    def __init__(self, argv=None):
        """ Parse handler command line arguments.

            argv defaults to sys.argv, read at call time rather than at class definition time.
            At least executable, context and envelope sender must be present; the recipient
            list may be empty. Raises Exception with a usage message otherwise.
            The module-level logger must have been set up by getLogger() beforehand.
        """
        if argv is None:
            argv = sys.argv
        try:
            self.executable = argv[self.MH_ARGV_EXECUTABLE]
            self.context = argv[self.MH_ARGV_CONTEXT]
            self.from_ = argv[self.MH_ARGV_FROM]
            self.recipients = argv[self.MH_ARGV_RCPT_LIST_BEGIN:]
        except IndexError:
            # Report the program name from the arguments actually being parsed.
            program = argv[0] if argv else sys.argv[0]
            raise Exception("Not enough arguments. Usage: %s <CONTEXT> <FROM> <RECIPIENTS_LIST>" % program)
        log.debug("%s: executable='%s', context='%s', from='%s', recipients=%s",
                  self.__class__.__name__, self.executable, self.context, self.from_, self.recipients)

    def run(self):
        """ Process the message. Does nothing here; meant to be overridden. """
        pass

    def exit_skip(self):
        """ Report SKIP on stderr and exit with HOOK_SKIP. """
        sys.stderr.write("SKIP")
        sys.exit(self.HOOK_SKIP)

    def exit_pass(self):
        """ Report PASS on stderr and exit with HOOK_PASS. """
        sys.stderr.write("PASS")
        sys.exit(self.HOOK_PASS)

    def exit_reject(self):
        """ Report REJECT on stderr and exit with HOOK_REJECT. """
        sys.stderr.write("REJECT")
        sys.exit(self.HOOK_REJECT)

    def exit_defer(self):
        """ Report DEFER on stderr and exit with HOOK_DEFER. """
        sys.stderr.write("DEFER")
        sys.exit(self.HOOK_DEFER)

    def data(self, msg):
        """ Write a 'DATA <msg>' directive line to stderr. """
        sys.stderr.write("DATA %s\n" % msg)

    def reply(self, rcode, xcode, message):
        """ Emit an SMTP reply directive 'DATA REPLY[:rcode[:xcode]] message'.

            rcode and xcode are strings (e.g. "554", "5.7.0"); either may be empty,
            but xcode requires rcode. Raises ValueError otherwise.
        """
        if not rcode and xcode:
            raise ValueError("rcode should not be empty when xcode is specified")
        r = "REPLY"
        if rcode:
            r += ":" + rcode
        if xcode:
            r += ":" + xcode
        r += " " + message
        self.data(r)

    def change_envelope_sender(self, address):
        """ Change envelope sender address. Supported only under Postfix in before-queue. No effect with QMail.

            Raises ValueError if address is empty or lacks '@'.
        """
        if not address or '@' not in address:
            raise ValueError("Cannot change envelope sender to invalid address: %s" % address)
        self.data('CHMFROM:' + address)

    def log(self, msg):
        """ Write a 'LOG <msg>' directive line to stderr (unrelated to the module-level logger). """
        sys.stderr.write("LOG %s\n" % msg)

    def is_in_sendmail_context(self):
        """ True when the PPP_INSIDE_SENDMAIL environment variable is set to a non-empty value. """
        return bool(os.getenv(self.ENV_SENDMAIL_CONTEXT))


class LimitHandlerMessages(object):
    """ Texts passed to PleskMailHandler.reply() by the limit handlers.

        Entries containing '%s' are format strings and have to be %-interpolated before use.
    """

    # --- Malformed or forged handler environment -------------------------------------------
    REJECT_MALFORMED_BOTH_VHOSTID_USERID = ("Your message could not be sent due to domain misconfiguration. "
                                            "Contact your service provider to resolve this issue.")
    REJECT_MALFORMED_EMPTY_VHOSTID = ("Your message could not be sent. "
                                      "The sender's domain is not registered in Panel, or is misconfigured.")
    REJECT_MALFORMED_INVALID_VHOSTID = ("Your message could not be sent. "
                                        "The sender's domain is not configured properly.")

    # --- Sending permissions and limits (REJECT_MAIL_NOT_ALLOWED takes a user name) ----------
    REJECT_MAIL_NOT_ALLOWED = ("Your message could not be sent. "
                               "The user %s is not allowed to send email.")
    REJECT_SENDMAIL_NOT_ALLOWED = ("The message could not be sent. "
                                   "You are not allowed to use sendmail utility.")
    REJECT_OUTGOING_LIMIT_EXCEEDED = ("Your message could not be sent. "
                                      "The limit on the number of allowed outgoing messages was exceeded. "
                                      "Try again later.")

    # --- SMTP authentication (takes the short account name) ----------------------------------
    REJECT_INVALID_SHORT_SMTP_AUTH = ("Could not locate mail account for a given short mail account name '%s'. "
                                      "Try using full mail account name for authentication.")

    # --- Sender address validation -----------------------------------------------------------
    REJECT_ENVELOPE_SENDER_SMTP_AUTH_MISMATCH = "Envelope sender address does not match authenticated user name."
    REJECT_SENDER_HEADER_DOMAIN_MISMATCH = ("Sender address found in the message headers "
                                            "does not match sender domain.")
    REJECT_TOO_MANY_SENDER_ADDRESSES = "The message does not conform to RFC2822: too many sender addresses."

    # --- Temporary failures ------------------------------------------------------------------
    DEFER_INTERNAL_ERROR = ("Your message could not be sent due to an internal error. "
                            "See details in system mail log.")


class LimitHandler(PleskMailHandler):
    """ Base class for Plesk mail handlers that enforce outgoing mail limits and sender checks.

        run() parses the message from stdin, obtains its X-PPP-Message-ID, opens the outgoing mail
        limits DB and hands everything to handle_message(). Subclasses override preamble() and
        handle_message(); the latter must finish by calling one of the exit_*() methods.
    """
    HEADER_PLESK_MESSAGE_ID = "X-PPP-Message-ID"

    ENV_VHOSTID = "PPP_SENDER_VHOST_ID"
    ENV_USERID = "PPP_SENDER_USER_ID"

    # Template of per-object slots used by limits/ids structures; always copy() before filling in.
    _DEFAULT_OBJECTS_DICT = {
        'whitelist': None,
        'mail': None,
        'domain': None,
        'subscription': None,
    }

    # When True, run() replaces any incoming X-PPP-Message-ID header with a freshly generated one.
    _generate_plesk_message_id = False
    # Lazily opened DB connection, see _db().
    __db = None
    # Cache for get_smtp_auth_user(); False means "not computed yet", because None is a valid result.
    __smtp_auth_user = False

    def _db(self):
        """ Return the outgoing mail limits DB, opening it on first use.

            An in-memory DB is used when the handler context contains "in-memory".
        """
        if self.__db is None:
            log.debug("Initializing connection to backend DB")
            if "in-memory" in self.context:
                self.__db = OutgoingMailLimitsDb(OutgoingMailLimitsDb.IN_MEMORY)
            else:
                self.__db = OutgoingMailLimitsDb()
        return self.__db

    def run(self):
        """ Handler entry point: parse the message from stdin, determine its ID, open the DB and
            dispatch to handle_message().

            Exceptions are re-raised unless the context contains 'defer-on-internal-errors', in which
            case the message is deferred with a 451 4.7.1 reply.
        """
        try:
            self.preamble()

            self.message = email.message_from_file(sys.stdin)

            if self._generate_plesk_message_id:
                self.message_id = self._generate_message_id()
            else:
                self.message_id = self._get_message_id()

            db = self._db()

            self.handle_message(db, self.message, self.message_id)
            raise Exception("handle_message() must call one of exit_*() functions internally")

        except SystemExit:
            # exit_*() terminate through sys.exit(); let that through untouched.
            raise
        except Exception as ex:
            if 'defer-on-internal-errors' not in self.context:
                raise

            log.error('Deferring message due to internal error. %s', ex)
            log.debug('This exception happened at:', exc_info=sys.exc_info())
            self.reply("451", "4.7.1", LimitHandlerMessages.DEFER_INTERNAL_ERROR)
            self.exit_defer()

    @staticmethod
    def output_message(message):
        """ Format and output a given email.message.Message instance to stdout. """
        fp = StringIO()
        # mangle_from_=False: leave body lines starting with "From " untouched.
        generator = Generator(fp, mangle_from_=False)
        generator.flatten(message)
        sys.stdout.write(fp.getvalue())

    def exit_pass(self):
        """ Write the (possibly modified) message to stdout, then report PASS and exit. """
        self.output_message(self.message)
        super(LimitHandler, self).exit_pass()

    def get_smtp_auth_user(self, db):
        """ Returns full username used for authentication in SMTP session (lowercased), or a false value
            (None, or '' when the auth variable is set but empty) if there is no such user.
            A non-empty PPP_OVERRIDE_SENDER takes precedence; in sendmail context there is no SMTP user.
            If short mail account names are allowed and the short username used for authentication cannot be reliably
            matched to a full one, message is rejected.
            If the username is invalid, Exception is raised.
            The value is computed once per handler instance and cached.
        """
        if self.__smtp_auth_user is False:
            force_sender = os.getenv(self.ENV_OVERRIDE_SENDER)
            if force_sender == "":
                log.error("Empty variable %s. Ignoring.", self.ENV_OVERRIDE_SENDER)

            if force_sender:
                log.debug("Force sender email: %s", force_sender)
                self.__smtp_auth_user = force_sender
            elif self.is_in_sendmail_context():
                self.__smtp_auth_user = None
            else:
                self.__smtp_auth_user = os.getenv(self.ENV_SMTP_AUTH_USER)

                if self.__smtp_auth_user and '@' not in self.__smtp_auth_user:
                    # Short (local part only) name: resolve it to a full mail name when that is allowed.
                    if os.getenv(self.ENV_SHORTNAMES_ALLOWED):
                        full_mailnames = db.list_mailnames_by_local_alias(self.__smtp_auth_user)

                        if len(full_mailnames) != 1:
                            log.error("Rejecting message: mail account could not be reliably located for short mail "
                                      "account name '%s'. %d different matching accounts were found. %s",
                                      self.__smtp_auth_user, len(full_mailnames),
                                      "Mail subsystem is likely misconfigured." if len(full_mailnames) else "")
                            self.reply("554", "5.7.8",
                                       LimitHandlerMessages.REJECT_INVALID_SHORT_SMTP_AUTH % self.__smtp_auth_user)
                            self.exit_reject()

                        mailname = full_mailnames[0]
                        self.__smtp_auth_user = '%s@%s' % (mailname['mail_name'], mailname['domain_name'])

            if self.__smtp_auth_user and self.__smtp_auth_user.count('@') != 1:
                raise Exception("Malformed SMTP session auth user: %s" % self.__smtp_auth_user)

        auth_user = self.__smtp_auth_user
        # No authenticated user (sendmail context, SMTP_AUTHORIZED unset/empty): nothing to lowercase.
        return auth_user.lower() if auth_user else auth_user

    def get_vhost_id(self):
        """ Returns "PPP_SENDER_VHOST_ID" environment variable value or None. """
        return os.getenv(self.ENV_VHOSTID)

    def get_user_id(self):
        """ Returns "PPP_SENDER_USER_ID" environment variable value (as int) or None. """
        uid = os.getenv(self.ENV_USERID)
        if uid is not None:
            return int(uid)
        return None

    def _log_debug_info(self, db):
        """ Log debug info regarding various input parameters (only when context contains "debug"). """
        if "debug" not in self.context:
            return

        smtp_auth_user = self.get_smtp_auth_user(db)
        vhost_id = self.get_vhost_id()
        user_id = self.get_user_id()
        message_id = self.message_id

        log.info("[%s] message_id='%s' smtp_auth_user='%s' vhost_id='%s' user_id='%s'",
                 self.__class__.__name__, message_id, smtp_auth_user, vhost_id, user_id)

        if smtp_auth_user:
            mail_name, domain_name = smtp_auth_user.split('@')
            mailname = db.fetch_mailname_by_aliases(mail_name, domain_name)
            log.info("  mailname (%s@%s) => %s", mail_name, domain_name, mailname)

        if vhost_id:
            domain = db.fetch_domain_by_vhost_id(vhost_id)
            log.info("  domain (%s) => %s", vhost_id, domain)

        if user_id is not None:
            if db.is_user_in_whitelist(user_id):
                log.info("  sysuser (%s) => whitelisted", user_id)

            sysuser = db.fetch_sysuser(user_id)
            log.info("  sysuser (%s) => %s", user_id, sysuser)

        if message_id:
            msg_record = db.run("SELECT * FROM messages WHERE message_id = ?", message_id).fetchone()
            log.info("  message_id (%s) => %s", message_id, msg_record)

    @staticmethod
    def _get_duplicate_record_error_class():
        """ Return the DB layer's DuplicateRecordError class, or None if that module does not define it. """
        # NOTE(review): DuplicateRecordError used to be referenced without any import, so every exception
        # from update_message() became a NameError. It is assumed to live next to OutgoingMailLimitsDb;
        # confirm and replace this lookup with an explicit import.
        import plesk_outgoing_mail_db
        return getattr(plesk_outgoing_mail_db, 'DuplicateRecordError', None)

    def _check_limits_and_update_usage(self, db, message_id,
                                       limits, limit_name, positive_resolution, negative_resolution):
        """ Check current usage against supplied limits structure and decide whether to allow sending message.
            Update DB messages table with a record for the message.
            Returns True if message should be blocked, False if it should be allowed to go further.

            For every object present in limits: a limit of 0 blocks the message outright, a negative limit
            means unlimited, and a positive limit is compared with db.count_messages(..., resolution='passed')
            (per the log messages, messages passed during the last hour).
            Raises ValueError on a missing message_id, an unknown limit_name or invalid resolutions.
            Rejects the message if the DB reports its ID as a duplicate.
        """
        checked_ids = self._DEFAULT_OBJECTS_DICT.copy()
        exceeded_ids = self._DEFAULT_OBJECTS_DICT.copy()

        if not message_id:
            raise ValueError("%s header value is missing" % self.HEADER_PLESK_MESSAGE_ID)

        # Exact comparison: "in ('out_limit')" tests substrings, since the parentheses make a str, not a tuple.
        if limit_name != 'out_limit':
            raise ValueError("limit_name must be 'out_limit'")

        # The 1-tuple keeps a tuple of valid values from being consumed as several format arguments.
        if positive_resolution not in db.VALID_RESOLUTION_VALUES:
            raise ValueError("positive_resolution must be one of: %s" % (db.VALID_RESOLUTION_VALUES,))

        if negative_resolution not in db.VALID_RESOLUTION_VALUES:
            raise ValueError("negative_resolution must be one of: %s" % (db.VALID_RESOLUTION_VALUES,))

        log.debug("Checking limits usage against following data: %s", limits)

        for obj in ('whitelist', 'mail', 'domain', 'subscription'):
            if not limits[obj]:
                continue

            limit_value = limits[obj][limit_name]
            object_id = limits[obj]['_id']

            if limit_value <= 0:
                # Edge cases are checked before counting messages
                if limit_value == 0:
                    log.debug("Checking against %s %s = %s: sending mail is not allowed apriori",
                              obj, limit_name, limit_value)
                    checked_ids[obj] = exceeded_ids[obj] = object_id
                elif limit_value < 0:
                    log.debug("Checking against %s %s = %s: may send unlimited amount of mail messages",
                              obj, limit_name, limit_value)
                    checked_ids[obj] = object_id
            else:
                # We don't actually need to count messages for whitelist, but we can anyway :)
                count = db.count_messages(obj, object_id, resolution='passed')
                log.debug("Number of passed messages during last hour for %s is %s", obj, count)

                if count < limit_value:
                    log.debug("Checking against %s %s = %s: limit is not exceeded yet, "
                              "passed only %s messages during last hour",
                              obj, limit_name, limit_value, count)
                    checked_ids[obj] = object_id
                else:
                    log.debug("Checking against %s %s = %s: limit is already exceeded, "
                              "passed %s messages during last hour",
                              obj, limit_name, limit_value, count)
                    checked_ids[obj] = exceeded_ids[obj] = object_id

        # SQLite ids (rowid) start with 1 (not 0), therefore using implicit cast to boolean is OK
        block_message = any(exceeded_ids.values())
        resolution = negative_resolution if block_message else positive_resolution
        ids = exceeded_ids if block_message else checked_ids

        log.debug("Decision: message will be %s", resolution)
        log.debug("Checked the following objects: %s", [k for k, v in checked_ids.items() if v])
        log.debug("Exceeded limits for the following objects: %s", [k for k, v in exceeded_ids.items() if v])

        duplicate_record_error = self._get_duplicate_record_error_class()
        try:
            db.update_message(message_id, resolution,
                              mail_id=ids['mail'],
                              domain_id=ids['domain'],
                              subscription_id=ids['subscription'],
                              whitelist_id=ids['whitelist'],
                              unique=True)
        except Exception as ex:
            if duplicate_record_error is None or not isinstance(ex, duplicate_record_error):
                raise
            # This should not happen, since we currently generate message_id ourselves
            log.info("REJECT message as its ID is not unique: %s", ex)
            self.exit_reject()

        return block_message

    def _get_limits(self, db):
        """ Returns limits structure. Verifies that handler environment is correct and possibly rejects
            message if it is not. Returns None if message is either unrestricted outgoing or incoming.

            Source precedence: SMTP auth user, then PPP_SENDER_VHOST_ID, then PPP_SENDER_USER_ID.
        """
        smtp_auth_user = self.get_smtp_auth_user(db)
        vhost_id = self.get_vhost_id()
        user_id = self.get_user_id()

        if smtp_auth_user:
            mail_name, domain_name = smtp_auth_user.split('@')
            return self._fetch_limits(db, mail_name=mail_name, domain_name=domain_name)

        if vhost_id and user_id is not None:
            log.error("Rejecting malformed message: both %s and %s are present in environment",
                      self.ENV_VHOSTID, self.ENV_USERID)
            self.reply("554", "5.7.0", LimitHandlerMessages.REJECT_MALFORMED_BOTH_VHOSTID_USERID)
            self.exit_reject()

        if vhost_id is not None:
            if vhost_id == "":
                log.error("Rejecting forged message: %s is present in environment, but its value is empty",
                          self.ENV_VHOSTID)
                self.reply("554", "5.7.0", LimitHandlerMessages.REJECT_MALFORMED_EMPTY_VHOSTID)
                self.exit_reject()
            limits = self._fetch_limits(db, vhost_id=vhost_id)
            if not limits['domain']:
                log.error("Rejecting forged message: %s environment variable value is invalid", self.ENV_VHOSTID)
                self.reply("554", "5.7.0", LimitHandlerMessages.REJECT_MALFORMED_INVALID_VHOSTID)
                self.exit_reject()
            return limits

        if user_id is not None:
            limits = self._fetch_limits(db, sysuser_uid=user_id)
            if not limits['subscription'] and not limits['whitelist']:
                log.error("Rejecting message: system user uid='%s' is not allowed to send mail", user_id)
                try:
                    user_name = pwd.getpwuid(user_id).pw_name
                except KeyError:
                    user_name = "<uid=%d>" % user_id
                self.reply("554", "5.7.0", LimitHandlerMessages.REJECT_MAIL_NOT_ALLOWED % user_name)
                self.exit_reject()
            return limits

        return None

    def _fetch_limits(self, db, mail_name=None, domain_name=None, sysuser_uid=None, vhost_id=None):
        """ Build limits structure based on known input data.

            Whitelisted system users short-circuit with an unlimited, sendmail-allowed 'whitelist' entry.
        """
        limits = self._DEFAULT_OBJECTS_DICT.copy()
        main_domain_name = None

        if sysuser_uid is not None and db.is_user_in_whitelist(sysuser_uid):
            limits['whitelist'] = db.run("SELECT id AS _id, name AS sysuser_name FROM whitelist WHERE id = ?",
                                         sysuser_uid).fetchone()
            limits['whitelist'].update(out_limit=-1, allow_sendmail=1)
            return limits

        if vhost_id and not domain_name:
            limits['domain'] = db.fetch_domain_by_vhost_id(vhost_id)
            if limits['domain']:
                domain_name = limits['domain']['name']
                main_domain_name = limits['domain']['main_domain_name']

        if mail_name and domain_name:
            limits['mail'] = db.fetch_mailname_by_aliases(mail_name, domain_name)
            if limits['mail']:
                # Switch from the (possibly alias) names given to the canonical ones.
                mail_name = limits['mail']['mail_name']
                domain_name = limits['mail']['domain_name']

        if domain_name:
            limits['domain'] = db.fetch_domain(domain_name)
            if limits['domain']:
                main_domain_name = limits['domain']['main_domain_name']

        if sysuser_uid is not None:
            limits['subscription'] = db.fetch_sysuser(sysuser_uid)
        elif main_domain_name is not None:
            limits['subscription'] = db.fetch_subscription(main_domain_name)

        return limits

    def _should_reject_sendmail_usage(self, limits):
        """ Check whether the mail being checked was sent via sendmail and its usage is not allowed. """
        return not (limits['whitelist'] or  # whitelisted users are allowed to use sendmail
                    limits['mail'] or       # SMTP authenticated users aren't using sendmail (obviously)
                    (limits['subscription'] and limits['subscription']['allow_sendmail']))

    def _get_sender_str(self, limits):
        """ Return externally-recognizable sender string or None. """
        def get_fqdn_hostname():
            import socket
            return socket.getfqdn()

        if limits['whitelist']:
            return get_fqdn_hostname()

        if limits['domain']:
            return limits['domain']['name']

        if limits['subscription']:
            return limits['subscription']['main_domain_name']

        return None

    def _get_message_id(self):
        """ Return the X-PPP-Message-ID header value of the current message, or None if absent. """
        return self.message[self.HEADER_PLESK_MESSAGE_ID]

    def _generate_message_id(self):
        """ Replace any X-PPP-Message-ID headers with a newly generated unique ID and return it. """
        del self.message[self.HEADER_PLESK_MESSAGE_ID]
        self.message[self.HEADER_PLESK_MESSAGE_ID] = message_id = email.utils.make_msgid()
        log.debug("Generated new header %s: %s", self.HEADER_PLESK_MESSAGE_ID, message_id)
        return message_id

    def __check_and_update_sender_header(self, db, action, domain_name, real_sender_address,
                                         header, parsed_address):
        """ Check that mail address in a given sender header matches expected domain_name.
            If it doesn't, perform action (fix or verify).
            If real_sender_address is specified it will be used to fix sender,
            otherwise sender domain will be changed to domain_name.
            parsed_address is a (realname, address) pair as returned by email.utils.getaddresses().
        """
        # Explicit unpacking instead of a tuple parameter, which Python 3 no longer accepts (PEP 3113).
        realname, address = parsed_address
        # Split on the last '@' so addresses containing several '@' don't break unpacking.
        mail_name, _, sender_domain_name = address.rpartition('@')
        sender_domain = db.fetch_domain_by_alias(sender_domain_name)
        if not sender_domain or domain_name not in (sender_domain['name'], sender_domain['main_domain_name']):
            if action == 'fix':
                address = real_sender_address if real_sender_address else '%s@%s' % (mail_name, domain_name)
                log.debug("Address in header '%s: %s' was changed to '%s'", header, self.message[header], address)
                del self.message[header]
                self.message[header] = email.utils.formataddr((realname, address))
            elif action == 'verify':
                log.error("Rejecting message with invalid header '%s: %s' since it doesn't match sender domain '%s'",
                          header, self.message[header], domain_name)
                self.reply("501", "5.1.8", LimitHandlerMessages.REJECT_SENDER_HEADER_DOMAIN_MISMATCH)
                self.exit_reject()
        else:
            log.debug("Address in header '%s: %s' matches domain '%s'", header, self.message[header], domain_name)

    def __check_and_update_sender_headers(self, db, action, domain_name, real_sender_address=None):
        """ Check addresses in mail message sender headers against expected domain_name and perform action
            upon mismatch.

            A single 'Sender' address is checked in preference to a single 'From' address; several
            addresses are an RFC2822 violation and cause a reject when action is 'verify'.
        """
        def get_valid_addresses(message, header):
            """ List all parsed (realname, address) pairs under given header whose address contains '@'. """
            parsed_addresses = email.utils.getaddresses(message.get_all(header) or [])
            # A real list (not a lazy filter object) so len() below also works on Python 3.
            return [(realname, address) for realname, address in parsed_addresses if '@' in address]

        HEADER_FROM   = 'From'
        HEADER_SENDER = 'Sender'

        from_headers   = get_valid_addresses(self.message, HEADER_FROM)
        sender_headers = get_valid_addresses(self.message, HEADER_SENDER)

        if len(sender_headers) == 1:
            self.__check_and_update_sender_header(db, action, domain_name, real_sender_address,
                                                  HEADER_SENDER, sender_headers[0])
        elif len(from_headers) == 1:
            self.__check_and_update_sender_header(db, action, domain_name, real_sender_address,
                                                  HEADER_FROM, from_headers[0])
        elif not from_headers and not sender_headers:
            log.debug("Neither '%s' nor '%s' header contains mail address. Sender header cannot be checked.",
                      HEADER_FROM, HEADER_SENDER)
        else:
            log.error("%s mail message which doesn't conform to RFC2822: there are %d '%s' addresses and "
                      "%d '%s' addresses.", "Rejecting" if action == 'verify' else "Passing",
                      len(from_headers), HEADER_FROM, len(sender_headers), HEADER_SENDER)
            if action == 'verify':
                self.reply("501", "5.5.2", LimitHandlerMessages.REJECT_TOO_MANY_SENDER_ADDRESSES)
                self.exit_reject()

    def _check_and_update_sender(self, db, limits, envelope=None, headers=None):
        """ Check sender against actual sending mail address or domain.
            If there is a mismatch in envelope sender, perform envelope action ('fix' has effect only on Postfix).
            If there is a mismatch in sender mail headers, perform headers action.
            Action is either 'fix' in which case sender is fixed to our best abilities,
            or 'verify' in which case mail will be rejected upon mismatch, or None to skip checking.
            Whitelisted senders, and senders whose domain cannot be determined, are not checked at all.

            Warning: this method should not be considered a security feature against sending spam (particularly
            with 'fix' actions). Instead it is a feature to reduce chances that legitimate outgoing mail will
            be considered as spam by receiver.
        """
        KNOWN_ACTIONS = (None, 'fix', 'verify')

        if envelope not in KNOWN_ACTIONS or headers not in KNOWN_ACTIONS:
            # KNOWN_ACTIONS is a tuple; wrap it so it is formatted as a single value.
            raise ValueError("Sender check or update action should be either of %s" % (KNOWN_ACTIONS,))

        if limits['whitelist']:
            log.debug("Sender '%s' is in whitelist. No sender validity checks will be done.",
                      limits['whitelist']['sysuser_name'])
            return

        smtp_auth_user = self.get_smtp_auth_user(db)
        in_sendmail = self.is_in_sendmail_context()

        if limits['domain']:
            domain_name = limits['domain']['name']
        elif in_sendmail and limits['subscription']:
            # Under sendmail the subscription's main domain stands in for the sending domain.
            domain_name = limits['subscription']['main_domain_name']
        else:
            domain_name = None

        if not domain_name:
            # Without a reference domain every comparison below would be meaningless.
            log.debug("Sender domain is unknown. No sender validity checks will be done.")
            return

        if not in_sendmail and envelope and smtp_auth_user:
            # An envelope sender without '@' yields an empty domain and is treated as a mismatch.
            envelope_domain = db.fetch_domain_by_alias(self.from_.rpartition('@')[2])
            if not envelope_domain or domain_name != envelope_domain['name']:
                if envelope == 'fix':
                    self.change_envelope_sender(smtp_auth_user)
                    log.debug("Envelope sender was changed from '%s' to '%s'", self.from_, smtp_auth_user)
                    self.from_ = smtp_auth_user
                elif envelope == 'verify':
                    log.error("Rejecting message with invalid envelope sender '%s' which doesn't match SMTP "
                              "authentication id '%s'", self.from_, smtp_auth_user)
                    self.reply("501", "5.1.8", LimitHandlerMessages.REJECT_ENVELOPE_SENDER_SMTP_AUTH_MISMATCH)
                    self.exit_reject()

        if headers:
            self.__check_and_update_sender_headers(db, headers, domain_name, smtp_auth_user)

    def preamble(self):
        """ Called before parsing message and opening DB. Should be overridden in subclasses. """
        pass

    def handle_message(self, db, message, message_id):
        """ Called after parsing message and opening DB to process message. Should be overridden in subclasses.
            Must end by calling one of the exit_*() methods; run() treats a normal return as an error.
        """
        pass

# vim: ts=4 sts=4 sw=4 et :
