#!/usr/bin/python3

import configparser
import imaplib
import email
import io
import zipfile
import xml.etree.ElementTree
import psycopg2
import re
import datetime
import argparse

parser = argparse.ArgumentParser(description='Process DMARC records.')
parser.add_argument('-c', '--config', action='store',
                    default='', dest='config_file',
                    help='Path to config file')
parser.add_argument('-t', '--test', action='store_true',
                    default=False,
                    help='Test, but do not add records to the database')
args = parser.parse_args()

def fetch_msg(num):
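    """Return the raw RFC822 bytes of the message with the given IMAP UID."""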
    return mailbox.uid('FETCH', num, '(RFC822)')[1][0][1]

def xml_of_part(part):
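    """Unzip a MIME part and parse its first member as XML.

    Returns None when the payload is not a valid ZIP archive.
    """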
    try:
        with zipfile.ZipFile(io.BytesIO(part.get_payload(decode=True))) as zf:
            fn = zf.infolist()[0].filename
            contents = zf.read(fn).decode('utf-8')
            return xml.etree.ElementTree.fromstring(contents)
    except zipfile.BadZipFile:
        return None


def xml_of(message):
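    """Return the report XML trees found in a message.

    For multipart messages every zip-typed part is tried; otherwise the
    message body itself is. Entries may be None for parts that fail to unzip.
    """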
    reports = []
    if message.is_multipart():
        for p in message.get_payload():
            if 'zip' in p.get_content_type():
                reports += [xml_of_part(p)]
    else:
        reports = [xml_of_part(message)]
    return reports

def extract_report(msg):
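    """Parse raw message bytes and return the report XML trees found inside."""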
    pmsg = email.message_from_bytes(msg)
    return xml_of(pmsg)

def maybe_strip(text):
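    """Return text stripped of whitespace, or '' if text is None or empty."""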
    if text:
        return text.strip()
    else:
        return ''

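# Map each XPath in the aggregate report XML to its destination table and column.
# Paths containing '{}' are per-record and are formatted with a 1-based record index.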
field_maps = {'./policy_published/adkim': {'pg_field_name': 'policy_published_adkim',
  'pg_table': 'reports',
  'pg_type': 'varchar'},
 './policy_published/aspf': {'pg_field_name': 'policy_published_aspf',
  'pg_table': 'reports',
  'pg_type': 'varchar'},
 './policy_published/domain': {'pg_field_name': 'policy_published_domain',
  'pg_table': 'reports',
  'pg_type': 'varchar'},
 './policy_published/p': {'pg_field_name': 'policy_published_p',
  'pg_table': 'reports',
  'pg_type': 'varchar'},
 './policy_published/pct': {'pg_field_name': 'policy_published_pct',
  'pg_table': 'reports',
  'pg_type': 'int'},
 './record[{}]/auth_results/dkim/domain': {'pg_field_name': 'auth_results_dkim_domain',
  'pg_table': 'report_items',
  'pg_type': 'varchar'},
 './record[{}]/auth_results/dkim/result': {'pg_field_name': 'auth_results_dkim_result',
  'pg_table': 'report_items',
  'pg_type': 'varchar'},
 './record[{}]/auth_results/spf/domain': {'pg_field_name': 'auth_results_spf_domain',
  'pg_table': 'report_items',
  'pg_type': 'varchar'},
 './record[{}]/auth_results/spf/result': {'pg_field_name': 'auth_results_spf_result',
  'pg_table': 'report_items',
  'pg_type': 'varchar'},
 './record[{}]/identifiers/header_from': {'pg_field_name': 'identifiers_header_from',
  'pg_table': 'report_items',
  'pg_type': 'varchar'},
 './record[{}]/row/count': {'pg_field_name': 'count',
  'pg_table': 'report_items',
  'pg_type': 'int'},
 './record[{}]/row/policy_evaluated/disposition': {'pg_field_name': 'policy_evaluated_disposition',
  'pg_table': 'report_items',
  'pg_type': 'varchar'},
 './record[{}]/row/policy_evaluated/dkim': {'pg_field_name': 'policy_evaluated_dkim',
  'pg_table': 'report_items',
  'pg_type': 'varchar'},
 './record[{}]/row/policy_evaluated/spf': {'pg_field_name': 'policy_evaluated_spf',
  'pg_table': 'report_items',
  'pg_type': 'varchar'},
 './record[{}]/row/source_ip': {'pg_field_name': 'source_ip',
  'pg_table': 'report_items',
  'pg_type': 'inet'},
 './report_metadata/date_range/begin': {'pg_field_name': 'report_metadata_date_range_begin',
  'pg_table': 'reports',
  'pg_type': 'timestamptz'},
 './report_metadata/date_range/end': {'pg_field_name': 'report_metadata_date_range_end',
  'pg_table': 'reports',
  'pg_type': 'timestamptz'},
 './report_metadata/email': {'pg_field_name': 'report_metadata_email',
  'pg_table': 'reports',
  'pg_type': 'varchar'},
 './report_metadata/org_name': {'pg_field_name': 'report_metadata_org_name',
  'pg_table': 'reports',
  'pg_type': 'varchar'},
 './report_metadata/report_id': {'pg_field_name': 'report_metadata_report_id',
  'pg_table': 'reports',
  'pg_type': 'varchar'}}



def build_insert_command(table_name, report, preamble_values=None, i=None):
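    """Build a parameterised INSERT statement and its values dict for one table.

    Values are extracted from the report XML according to field_maps; i selects
    the i-th <record> element for report_items rows, and preamble_values adds
    extra columns such as the foreign key back to the reports table.
    """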
    field_names = []
    if preamble_values:
        values = preamble_values.copy()
    else:
        values = {}
    for f in [f for f in field_maps if field_maps[f]['pg_table'] == table_name]:
        if i is not None:
            fp = f.format(i)
        else:
            fp = f
        field_names += [field_maps[f]['pg_field_name']]
        if field_maps[f]['pg_type'] == 'int':
            values[field_maps[f]['pg_field_name']] = int(report.find(fp).text)
        elif field_maps[f]['pg_type'] == 'timestamptz':
            # Report timestamps are Unix epoch seconds; store them as timezone-aware
            # UTC datetimes (datetime.utcfromtimestamp() is deprecated).
            values[field_maps[f]['pg_field_name']] = \
                datetime.datetime.fromtimestamp(int(report.find(fp).text),
                                                tz=datetime.timezone.utc)
        else:
            # varchar and inet columns both take the stripped element text.
            values[field_maps[f]['pg_field_name']] = maybe_strip(report.find(fp).text)
    insert_string = 'insert into {} ('.format(table_name)
    if preamble_values:
        insert_string += ', '.join(sorted(preamble_values.keys())) + ', '
    insert_string += ', '.join(field_names) + ') '
    insert_string += 'values ('
    if preamble_values:
        insert_string += ', '.join('%({})s'.format(fn) for fn in sorted(preamble_values.keys())) + ', '
    insert_string += ', '.join('%({})s'.format(f) for f in field_names) + ');'
    return insert_string, values


def write_report(connection, cursor, report):
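    """Insert a report row plus one report_items row per <record>, then commit.

    A report missing an expected element raises AttributeError on .text; the
    transaction is rolled back and the report is skipped.
    """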
    try:
        insert_string, values = build_insert_command('reports', report)
        # print(insert_string, values)
        cursor.execute(insert_string, values)

        # Look up the primary key of the row just inserted so each report_items
        # row can reference it.
        cursor.execute('select id, report_metadata_report_id from reports where report_metadata_report_id = %s;',
            [report.find('./report_metadata/report_id').text])
        results = cursor.fetchall()
        if len(results) != 1:
            raise RuntimeError('Could not find report record for report item')
        report_id = results[0][0]

        for i in range(1, len(report.findall('./record'))+1):
            insert_string, values = build_insert_command('report_items', report, i=i,
                                                         preamble_values={'report_id': report_id})
            # print(insert_string, values)
            cursor.execute(insert_string, values)
        connection.commit()
    except AttributeError:
        # A malformed report is missing an expected element (find() returned None);
        # roll back so partial inserts are not committed with the next report.
        connection.rollback()

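# Load configuration from the path given with --config, or fall back to the
# default system-wide and local locations.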
config = configparser.ConfigParser()
if args.config_file:
    config.read(args.config_file)
else:
    config.read(['/etc/dmarc_to_database.ini', './dmarc_to_database.ini'])

if not config.sections():
    raise RuntimeError('Could not find configuration file')

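# Connect to PostgreSQL using the [database] section of the configuration.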
conn = psycopg2.connect(host=config['database']['server'],
                        database=config['database']['database'], 
                        user=config['database']['username'], 
                        password=config['database']['password']) 

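# Find the end date of the most recent stored report so only newer mail is fetched.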
cur = conn.cursor()
cur.execute('select max(report_metadata_date_range_end) from reports')
results = cur.fetchall()
most_recent_date = results[0][0]

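# Connect to the IMAP server, upgrade to TLS, and open the inbox read-only.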
mailbox = imaplib.IMAP4(host=config['imap']['server'],
                        port=config['imap']['port'])
mailbox.starttls()
mailbox.login(config['imap']['username'], config['imap']['password'])
mailbox.select('INBOX', readonly=True)


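# Search from two days before the most recent stored report to allow for overlap,
# or fetch everything if the database is empty.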
if most_recent_date:
    mails_from = "SINCE " + (most_recent_date - datetime.timedelta(days=2)).strftime("%d-%b-%Y")
else:
    mails_from = "ALL"
resp, nums = mailbox.uid('SEARCH', None, mails_from)


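# Fetch each matching message and flatten the extracted reports into one list,
# dropping attachments that did not contain valid zipped XML.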
dmarc_reports = [report for report_set in [extract_report(fetch_msg(n)) for n in nums[0].split()]
                for report in report_set
                if report]

mailbox.close()
mailbox.logout()

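# Insert only reports whose report_id is not already in the database;
# with --test just print which reports would be written.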
for report in dmarc_reports:
    cur.execute('select id, report_metadata_report_id from reports where report_metadata_report_id = %s;', 
        [report.find('./report_metadata/report_id').text])
    results = cur.fetchall()
    if not results:
        print('write', report.find('./report_metadata/report_id').text)
        if not args.test:
            write_report(conn, cur, report)

conn.close()