#!/usr/bin/env python3
# SPDX-License-Identifier: GPL2.0
# Copyright Thomas Gleixner <tglx@linutronix.de>

from email.utils import make_msgid, formatdate
from argparse import ArgumentParser
import subprocess
import mailbox
import email.policy
import email
import yaml
import time
import sys
import os
import re

def committer_from_git():
    '''
    Retrieve the committer name/mail from git or the environment.

    The identity is read with 'git config' (user.name / user.email). When
    that yields neither value, GIT_COMMITTER_NAME and GIT_COMMITTER_EMAIL
    from the environment are used; an unset variable counts as empty.

    Returns the identity formatted as 'Name <mail>'.
    '''
    cmd = "echo $(git config --get user.name)' <'$(git config --get user.email)'>'"
    committer = subprocess.getoutput(cmd)

    # " <>" is what the command prints when both git values are empty
    if committer == " <>":
        # Default to '' so that an unset variable cannot break the
        # string concatenation below
        name = os.environ.get("GIT_COMMITTER_NAME", "")
        mail = os.environ.get("GIT_COMMITTER_EMAIL", "")
        committer = name + " <" + mail + ">"
    return committer

# Regexes which locate the start of a patch inside a mail text.
# NOTE(review): the third pattern has no '/' in front of 'dev/null';
# a '--- /dev/null' line is still matched by the last pattern.
re_split_at_patch = [re.compile(pattern) for pattern in (
    '\ndiff --git a/.* b/.*\n',
    '\n--- a/.*\n',
    '\n--- dev/null.*\n',
    '\n--- .*/.*\n',
)]

def split_at_patch(text):
    '''
    Look for a patch embedded in text.

    The re_split_at_patch regexes are tried in list order and the first
    one which matches decides where the text is cut.

    Returns a two element list: the text in front of the patch and the
    patch itself. Without a match the list is [text, None].
    '''
    for regex in re_split_at_patch:
        found = regex.search(text)
        if found is None:
            continue
        cut = found.start()
        # The match starts at the newline in front of the patch. That
        # newline ends up in neither part.
        return [text[:cut], text[cut + 1:]]
    return [text, None]

def get_raw_mailaddr(addr):
    '''
    Extract the raw email address from a string.

    For 'Name <user@host>' the part between the first '<' and the next
    '>' is returned with surrounding whitespace removed. A string without
    '<' and a value which is not a string (e.g. None) is returned
    unchanged.
    '''
    try:
        return addr.split('<')[1].split('>')[0].strip()
    except (AttributeError, IndexError):
        # AttributeError: addr is no string, IndexError: no '<' in addr
        return addr

# Matches a run of whitespace; used to squash such runs into one blank.
# Raw string, because '\s' is not a valid escape in a plain literal.
re_compress_space = re.compile(r'\s+')

def decode_hdr(hdr):
    '''
    Decode a mail header which may contain encoded words.

    hdr is the header text. The decoded parts are joined with blanks,
    runs of whitespace are squashed into a single blank and the result
    is stripped.

    Bytes which cannot be decoded are replaced instead of raising, and a
    part which names a charset unknown to Python is decoded as UTF-8.
    '''
    parts = []
    for txt, enc in email.header.decode_header(hdr.strip()):
        # Groan .... parts can arrive as str or as bytes
        if isinstance(txt, str):
            parts.append(txt)
            continue
        try:
            parts.append(txt.decode(enc or 'ascii', errors='replace'))
        except LookupError:
            # Charset unknown to Python, e.g. 'unknown-8bit'
            parts.append(txt.decode('utf-8', errors='replace'))
    # split() without arguments drops leading/trailing whitespace and
    # treats every whitespace run as one separator
    return ' '.join(' '.join(parts).split())

def should_drop(addr, drop):
    '''
    Tell whether addr is listed in drop.

    Both sides are reduced to their raw mail address before they are
    compared.
    '''
    wanted = get_raw_mailaddr(addr)
    return any(wanted == get_raw_mailaddr(entry) for entry in drop)

def decode_addrs(hdr, drop=()):
    '''
    Decode the mail addresses of a header and handle encodings.

    hdr:  Header text with comma separated addresses. May be empty/None.
    drop: Addresses which must not end up in the result.

    Returns a mailaddrs() store with the decoded addresses.

    NOTE(review): The header is split at every comma, so a quoted display
    name which contains a comma is torn into two entries.
    '''
    addrs = mailaddrs()
    if not hdr:
        return addrs
    for part in re_compress_space.sub(' ', hdr).split(','):
        addr = decode_hdr(part)
        # A trailing or doubled comma leaves an empty part behind, which
        # must not be stored as an address
        if not addr:
            continue
        if not should_drop(addr, drop):
            addrs.add(addr)
    return addrs

def decode_from(msg):
    '''
    Return the decoded author of a mail.

    The From header of msg is converted to a string, handed to
    decode_addrs() and the first address of the result is returned.
    '''
    sender = str(msg['From'])
    return decode_addrs(sender).get_first()

# Matches text made up only of letters, digits, '_', '-' and blanks.
# Raw string, because '\-' is not a valid escape in a plain literal.
re_noquote = re.compile(r'[a-zA-Z0-9_\- ]+')

def encode_hdr(txt):
    '''
    Return txt as a str for use as a mail header value.

    ASCII text is passed through an ASCII encode/decode round trip, any
    other text through a UTF-8 round trip. In both cases a str is
    returned.
    '''
    try:
        return txt.encode('ascii').decode()
    except UnicodeEncodeError:
        # Not plain ASCII. The result has to be returned here as well,
        # otherwise the caller gets None.
        return txt.encode('UTF-8').decode()

def encode_from(fulladdr):
    '''
    Encode a 'Name <addr>' string for use in a mail header.

    An ASCII name which contains characters outside of re_noquote is put
    into double quotes (embedded double quotes are removed). A name with
    non-ASCII characters is encoded via email.header.Header.

    A value without '<', or one which is not a string, is returned
    unchanged.
    '''
    try:
        name, addr = fulladdr.split('<', 1)
    except (AttributeError, ValueError):
        # Not a string or no '<' in it: nothing to encode
        return fulladdr
    name = name.strip()

    try:
        name.encode('ascii')
    except UnicodeEncodeError:
        name = email.header.Header(name).encode()
    else:
        # An empty name is left alone, quoting it would yield '""'
        if name and not re_noquote.fullmatch(name):
            name = '"%s"' % name.replace('"', '')

    return name + ' <' + addr

class mailaddrs(object):
    '''
    Storage for mail addresses with no duplicates allowed.

    Entries are keyed by their raw mail address:
      self.order - raw addresses in the order in which they were added
      self.addrs - raw address -> longest full variant added so far
    '''
    def __init__(self, addrlist=()):
        '''
        addrlist: Optional iterable of addresses to start with. Empty
                  entries are skipped.
        '''
        self.addrs = {}
        self.order = []
        self.append_list(addrlist)

    def add(self, addr_new):
        '''
        Add an address. When the raw address is already stored, the
        longer variant is kept, which prefers 'Name <addr>' over 'addr'.
        '''
        raw = get_raw_mailaddr(addr_new)
        if raw not in self.addrs:
            self.order.append(raw)
            self.addrs[raw] = addr_new
        elif len(self.addrs[raw]) < len(addr_new):
            # Prefer the full Name <addr> variant
            self.addrs[raw] = addr_new

    def values(self):
        '''Return the stored addresses as a new list in insertion order.'''
        return [self.addrs[raw] for raw in self.order]

    def extend(self, addrs):
        '''Add all addresses of another store which offers values().'''
        for addr in addrs.values():
            self.add(addr)

    def append_list(self, addrlist):
        '''Add all non-empty entries of addrlist.'''
        for addr in addrlist:
            if len(addr):
                self.add(addr)

    def get_first(self):
        '''Return the first added address or None when the store is empty.'''
        if not self.order:
            return None
        return self.addrs[self.order[0]]

    def get_last(self):
        '''Return the last added address or None when the store is empty.'''
        if not self.order:
            return None
        return self.addrs[self.order[-1]]

    def remove(self, addr):
        '''Remove the entry which has the same raw address as addr, if any.'''
        raw = get_raw_mailaddr(addr)
        if raw in self.addrs:
            del self.addrs[raw]
            self.order.remove(raw)

    def contains(self, addrlist):
        '''Return True when at least one address of addrlist is stored.'''
        return any(get_raw_mailaddr(addr) in self.addrs for addr in addrlist)

class tagstore(object):
    '''
    Keeps links and other entries in the order in which they were added
    and stores each of them only once.
    '''
    def __init__(self):
        self.links = list()

    def add(self, link):
        # An entry which is already stored is ignored
        if link in self.links:
            return
        self.links.append(link)

    def values(self):
        # Hands out the internal list, not a copy
        return self.links

# The tags which are recognized in a mail text. The order of this list is
# the order in which quilter.create_patch() writes the tags into a patch.
# compile_tags() creates one regex per entry and patchmsg.scan_tag() maps
# a matching regex back to its tag through the list position, so the
# position of an entry matters.
tag_order = [
    'Fixes',
    'Reported-and-tested-by',
    'Reported-by',
    'Suggested-by',
    'Originally-from',
    'Originally-by',
    'Signed-off-by',
    'Tested-by',
    'Reviewed-by',
    'Acked-by',
    'Cc',
    'Link',
]

# Tags whose value is not a mail address. patchmsg.add_tag() keeps them in
# a tagstore() instead of a mailaddrs() and patchmsg.sanitize_ccs() skips
# them.
tag_nomail = [
    'Fixes',
    'Link',
]

# Tags which patchmsg.update_tags() copies from another object (a reply,
# a cover letter or the quilter with the command line tags) onto a patch.
tags_reply = [
    'Tested-by',
    'Reviewed-by',
    'Acked-by',
]

# Compiled tag matchers. re_tags[i] belongs to tag_order[i].
re_tags = []

def compile_tags():
    '''
    Fill re_tags with one case insensitive regex per tag_order entry.

    The list is rebuilt in place, so calling this function again does
    not pile up duplicate entries.
    '''
    re_tags[:] = [re.compile('^%s:' % tag, re.IGNORECASE) for tag in tag_order]

class patchmsg(object):
    '''
    Internal representation of a patch message.

    Attributes set up by __init__():
      args             - the options object handed in by the caller
      tags             - tag name -> mailaddrs() or tagstore()
      date             - Date header, or the current time if missing
      subject          - cleaned up Subject header
      fname            - file name stem derived from the subject
      references       - message ids of the References header (list)
      inreplyto        - message ids of the In-Reply-To header (list)
      msgid            - Message-ID header or None
      author           - decoded From header, replaced by a From: line
                         in the mail text if there is one
      listmatch        - True if To/Cc contain one of args.list_addrs
      text             - mail text in front of the first tag, without
                         quoted lines
      posttags         - non-tag lines which follow the first tag
      patch            - the patch part of the mail or None
      has_tags         - True once add_tag() stored a tag
      has_from_in_text - True if the mail text contains a From: line
    '''

    re_rmsquare = re.compile(r'^\[.*?\]')
    re_tofname = re.compile('[^0-9a-zA-Z]+')

    def __init__(self, msg, args):
        self.args = args
        self.tags = {}

        self.date = msg.get('Date', formatdate())
        self.decode_subject(msg)
        self.get_references(msg)
        self.get_addrs(msg)
        self.split_at_patch(msg)
        self.scan_tags()

    def decode_subject(self, msg):
        '''
        Set self.subject and self.fname from the Subject header of msg.

        Leading '[...]' groups are removed, whitespace is compressed and
        the first character behind the last colon is uppercased. A
        missing Subject results in an empty subject.

        NOTE(review): encoded words in the Subject are not decoded here.
        '''
        # str() copes with header objects, a missing header becomes ''
        subj = str(msg.get('Subject') or '')
        # Unfold the header and compress whitespace
        subj = ' '.join(subj.split())

        # Remove starting brackets stuff
        while subj.startswith('['):
            stripped = self.re_rmsquare.sub('', subj, count=1)
            if stripped == subj:
                # No closing bracket. Stop, or this loops forever.
                break
            subj = stripped.lstrip()

        # Try to automatically uppercase the start of the sentence
        # This assumes that the prefix(es) end with a colon and the
        # sentence afterwards does not contain one.
        spos = subj.rfind(':') + 1
        if 0 < spos < len(subj):
            prefix = subj[:spos]
            tail = subj[spos:].strip()
            if tail:
                subj = prefix + ' ' + tail[0].capitalize() + tail[1:]

        self.subject = subj.strip()
        # Create a filename just in case. Fall back to a fixed name when
        # nothing usable is left of the subject.
        self.fname = self.re_tofname.sub('_', self.subject) or 'no_subject'

    def get_references(self, msg):
        '''Store References, In-Reply-To (as lists) and Message-ID of msg.'''
        self.references = msg.get('References', '').split()
        self.inreplyto = msg.get('In-Reply-To', '').split()
        self.msgid = msg.get('Message-ID', None)

    def get_addrs(self, msg):
        '''
        Set self.author and self.listmatch from the headers of msg and
        store the To/Cc addresses as Cc tags if args.collectcc is set.
        '''
        self.author = decode_from(msg)

        # Get To and Cc
        ccs = mailaddrs()
        for hdr in ('To', 'Cc'):
            val = msg[hdr]
            # A missing header is None. It must not be turned into the
            # address 'None' by str().
            if val is not None:
                ccs.extend(decode_addrs(str(val), drop=self.args.nocc_addrs))

        # Check whether the cc list contains a list address match
        self.listmatch = ccs.contains(self.args.list_addrs)
        # Remove the matching lists from Cc
        for addr in self.args.list_addrs:
            ccs.remove(addr)

        if self.args.collectcc:
            self.tags['Cc'] = ccs

    def split_at_patch(self, msg):
        '''
        Collect the text/plain parts of msg in self.text and store an
        embedded patch in self.patch.

        A part which cannot be decoded is reported on stderr and stops
        the processing of the remaining parts.
        '''
        self.text = ''
        self.patch = None

        # Handle multipart messages gracefully
        # i.e. extract attached patches as well
        for part in msg.walk():
            if part.is_multipart():
                continue
            # Refuse to handle anything else than plain text
            if part.get_content_type() != 'text/plain':
                continue
            # Get the text with decoding (could be base64 encoded)
            payload = part.get_payload(decode=True)
            if not payload:
                continue

            try:
                charset = part.get_content_charset('ascii')
                text, patch = split_at_patch(payload.decode(charset))
            except (LookupError, ValueError):
                # LookupError: unknown charset, ValueError covers
                # undecodable bytes
                err = 'Failed to decode body of Message-ID: %s\n' %self.msgid
                sys.stderr.write(err)
                return

            self.text += text
            if patch:
                self.patch = patch

    def add_tag(self, key, addr, force=False):
        '''
        Store addr under the tag key.

        Link tags which start with args.linkbase are dropped if
        args.droplinks is set, unless force is True. Cc tags are dropped
        if args.dropccs is set, unless the address is listed in
        args.keepcc_addrs.
        '''
        # Drop old links based on linkbase
        if key == 'Link' and not force and self.args.droplinks:
            if addr.startswith(self.args.linkbase):
                return

        # Drop all ccs if requested
        if key == 'Cc' and self.args.dropccs:
            # Check whether is must be kept, e.g. stable@vger...
            if get_raw_mailaddr(addr) not in self.args.keepcc_addrs:
                return

        if key not in tag_nomail:
            ts = self.tags.get(key, mailaddrs())
        else:
            ts = self.tags.get(key, tagstore())

        ts.add(addr)
        self.tags[key] = ts
        self.has_tags = True

    def scan_tag(self, txt):
        '''
        Check whether txt is a tag line and store the tag if so.

        Reported-and-tested-by is stored as Reported-by plus Tested-by.
        Returns True if txt was a tag line, False otherwise.
        '''
        txt = txt.strip()
        # re_tags[i] is the matcher for tag_order[i]
        for tag, regex in zip(tag_order, re_tags):
            m = regex.search(txt)
            if not m:
                continue
            # Everything behind the colon is the value. Do not assume
            # that a blank follows the colon.
            addr = txt[m.end():].strip()
            if tag == 'Reported-and-tested-by':
                self.add_tag('Reported-by', addr)
                tag = 'Tested-by'
            self.add_tag(tag, addr)
            return True
        return False

    def scan_tags(self):
        '''
        Split self.text into the text in front of the first tag
        (self.text), the tags (self.tags) and the non-tag lines after
        the first tag (self.posttags). Quoted lines are dropped.
        '''
        lines = self.text.splitlines()
        self.text = ''
        self.posttags = ''
        self.has_tags = False
        self.has_from_in_text = False
        seen_tag = False
        for line in lines:
            if self.scan_tag(line):
                seen_tag = True
                continue

            # Drop quoted text, except when it's '>From'
            if line.startswith('>'):
                if 'From' not in line:
                    continue
                # Remove the quote
                line = line[1:]

            # Respect a From: in the body and update author
            if line.startswith('From:'):
                # Split only at the first colon and drop the blank
                # which follows it
                self.author = line.split(':', 1)[1].strip()
                self.has_from_in_text = True
            if seen_tag:
                self.posttags += line + '\n'
            else:
                self.text += line + '\n'

    def update_tags(self, pmsg):
        '''Copy the tags listed in tags_reply from pmsg.tags to this patch.'''
        for tag in tags_reply:
            rtags = pmsg.tags.get(tag, mailaddrs())
            for addr in rtags.values():
                self.add_tag(tag, addr)

    def sanitize_ccs(self):
        '''Remove every Cc address which is also stored under another mail tag.'''
        ccs = self.tags.get('Cc', mailaddrs())
        # values() hands out a new list, so removing entries is safe here
        for cc in ccs.values():
            for tag in tag_order:
                if tag == 'Cc' or tag in tag_nomail:
                    continue
                if self.tags.get(tag, mailaddrs()).contains([cc]):
                    ccs.remove(cc)
                    break

class quilter(object):
    '''
    The worker which turns mbox into quilt

    Attributes:
      args    - the options object handed in by the caller
      tags    - Acked-by/Tested-by/Reviewed-by from the command line
      patches - messages which contain a patch     (set by scan_mbox)
      replies - messages with tags but no patch    (set by scan_mbox)
      others  - all other messages                 (set by scan_mbox)
    '''
    def __init__(self, args):
        # Store args for further reference
        self.args = args

        # Extract the acked, tested, reviewedby tags which came from
        # the command line
        self.tags = {
            'Acked-by'    : self._cmdline_addrs(args.ackedby),
            'Tested-by'   : self._cmdline_addrs(args.testedby),
            'Reviewed-by' : self._cmdline_addrs(args.reviewedby),
        }

    @staticmethod
    def _cmdline_addrs(addrs):
        # Turn a comma separated list into a mailaddrs() store. The
        # entries are stripped so that 'a <a@b>, c <c@d>' does not leave
        # a leading blank in the second address.
        return mailaddrs([addr.strip() for addr in addrs.split(',')])

    def should_drop(self, msg):
        '''Return True if the From header contains an args.drop_from entry.'''
        mfrom = str(msg['From'])
        return any(drop in mfrom for drop in self.args.drop_from)

    def check_replies(self, pmsg, ref):
        '''
        Hand the tags of all replies to pmsg over to pmsg.

        ref: False matches pmsg.msgid against In-Reply-To of the
             replies, True against their References.

        Replies which matched are removed from self.replies.
        '''
        # Build a new list instead of removing entries from the list
        # which is iterated, as that skips the entry after each removal
        remaining = []
        for r in self.replies:
            which = r.references if ref else r.inreplyto
            if pmsg.msgid in which:
                pmsg.update_tags(r)
            else:
                remaining.append(r)
        self.replies = remaining

    def scan_mbox(self, mbox):
        '''
        Sort the messages of mbox into self.patches, self.replies and
        self.others and collect the tags for the patches.
        '''
        self.patches = []
        self.replies = []
        self.others = []

        # Initial scan to gather information
        for _, msg in mbox.iteritems():

            if self.should_drop(msg):
                continue
            if not msg.get('Message-ID') and not self.args.noid:
                sys.stderr.write('Drop Message w/o Message-ID: %s\n'
                                 %msg.get('Subject','No subject'))
                continue
            pmsg = patchmsg(msg, self.args)

            # Found a patch
            if pmsg.patch:
                # Add the command line supplied tags
                pmsg.update_tags(self)
                if pmsg.msgid and pmsg.listmatch and not self.args.nolink:
                    mid = pmsg.msgid.lstrip('<').rstrip('>')
                    pmsg.add_tag('Link', self.args.linkbase + mid, force = True)
                self.patches.append(pmsg)
            elif pmsg.has_tags:
                self.replies.append(pmsg)
            else:
                self.others.append(pmsg)

        # Find all direct replies to patches
        for pmsg in self.patches:
            self.check_replies(pmsg, False)

        # Find all replies which reference a patch
        for pmsg in self.patches:
            self.check_replies(pmsg, True)

        # Find all direct replies to others (cover letter etc)
        for pmsg in self.others:
            self.check_replies(pmsg, False)

        # Find all replies which reference others (cover letter etc)
        for pmsg in self.others:
            self.check_replies(pmsg, True)

        # Sort out others which are not referenced by patches.
        # inreplyto and references are both lists of message ids.
        cover = []
        for o in self.others:
            for p in self.patches:
                if o.msgid in p.inreplyto or o.msgid in p.references:
                    cover.append(o)
                    break

        # Update patches which reference a cover letter
        for c in cover:
            for p in self.patches:
                if c.msgid in p.inreplyto or c.msgid in p.references:
                    p.update_tags(c)

        # Sanitize Ccs:
        for p in self.patches:
            p.sanitize_ccs()

    def write_series(self, pdir):
        '''
        Give every patch a unique file name and append the names to the
        series file in pdir.
        '''
        # Prevent duplicates in the series. i.e. somebody sent a V3 in
        # reply to V2 of 5/N. So the filename generated from the subject
        # after stripping '[PATCH...]' is the same.
        fnames = set()
        for pmsg in self.patches:
            i = 0
            fname = pmsg.fname
            while fname in fnames:
                i += 1
                fname = pmsg.fname + '_%d' %i
            pmsg.fname = fname
            fnames.add(fname)

        # Write the series file
        with open(os.path.join(pdir, 'series'), 'a') as fd:
            for pmsg in self.patches:
                fname = pmsg.fname + '.patch'
                i = 0
                # If the patch file exists, create a new one
                # Duplicate subjects happen ....
                while os.path.isfile(os.path.join(pdir, fname)):
                    i += 1
                    fname = pmsg.fname + '-%d.patch' % i
                pmsg.fname = fname
                fd.write(fname + '\n')

    def needs_sob(self, pmsg):
        '''
        Return True unless the last Signed-off-by of pmsg already
        carries the mail address of args.committer.
        '''
        last = pmsg.tags.get('Signed-off-by', mailaddrs()).get_last()
        if not last:
            return True
        return get_raw_mailaddr(self.args.committer) != get_raw_mailaddr(last)

    def create_patch(self, pmsg):
        '''
        Return the patch text for pmsg: an optional From: line, the
        changelog, the tags in tag_order, the text which followed the
        tags in the mail and the patch itself.
        '''
        committer_sob = 'Signed-off-by: %s\n' %self.args.committer

        res = ''
        if not self.args.fromnodup and not pmsg.has_from_in_text:
            res += 'From: %s\n\n' %pmsg.author

        # Drop extra newlines before tags
        res += pmsg.text.rstrip() + '\n\n'

        last_tag = ''
        for tag in tag_order:
            # If SOB is empty warn
            if tag == 'Signed-off-by' and tag not in pmsg.tags:
                sys.stderr.write('Patch %s has no SOB\n' %pmsg.subject)

            # Write out the committer SOB after the last SOB except when
            # committer is the author and no other SOBs are behind his SOB
            if self.args.sob_before_cc:
                if last_tag == 'Signed-off-by' and self.needs_sob(pmsg):
                    res += committer_sob

            last_tag = tag
            if tag not in pmsg.tags:
                continue

            for val in pmsg.tags[tag].values():
                res += '%s: %s\n' %(tag, val)

        if not self.args.sob_before_cc and self.needs_sob(pmsg):
            res += committer_sob

        res += '\n'
        res += pmsg.posttags
        res += '\n'
        res += pmsg.patch
        return res

    def write_patch(self, pmsg, pdir):
        '''Write pmsg with From/Subject/Date header lines to pdir/pmsg.fname.'''
        with open(os.path.join(pdir, pmsg.fname), 'w') as fd:
            res  = 'From: %s\n' %pmsg.author
            res += 'Subject: %s\n' %pmsg.subject
            res += 'Date: %s\n\n' %pmsg.date
            res += self.create_patch(pmsg)
            fd.write(res)

    def write_patches(self, pdir):
        '''Write all patches to pdir.'''
        for pmsg in self.patches:
            self.write_patch(pmsg, pdir)

    def create_mbox(self, fpath):
        '''Store all patches as mails in the mbox file fpath.'''
        mbox = mailbox.mbox(fpath, create = True)
        for pmsg in self.patches:
            pol = email.policy.EmailPolicy(utf8=True)
            msg = email.message.EmailMessage(pol)
            msg['From'] = email.header.Header(pmsg.author)
            msg['Date'] = pmsg.date
            msg['Subject'] = encode_hdr(pmsg.subject)
            msg['MIME-Version'] = '1.0'
            msg['Content-Type'] = 'text/plain'
            msg.set_param('charset', 'utf-8', header='Content-Type')
            msg['Content-Transfer-Encoding'] = '8bit'
            msg['Content-Disposition'] = 'inline'
            msg.set_unixfrom('From mb2q ' + time.ctime(time.time()))
            msg.set_content(self.create_patch(pmsg))
            mbox.add(mailbox.mboxMessage(msg))
        mbox.close()

if __name__ == '__main__':

    # Command line: the input mbox plus the options below
    parser = ArgumentParser(description='Mailbox 2 quilt converter')
    parser.add_argument('inbox', metavar='inbox', help='Input mbox file')
    parser.add_argument('-c', '--config', dest='config', default='~/.mb2q.yaml',
                        help='Config file. Default: ~/.mb2q.yaml')
    parser.add_argument('-d', '--droplinks', dest='droplinks',
                        action='store_true',
                        help='Drop existing Link:// tags in patches')
    parser.add_argument('-F', '--fromnodup', dest='fromnodup',
                        action='store_true',
                        help='Do not duplicate From into the body')
    parser.add_argument('-m', '--mboxout', dest='mboxout',
                        help='Store result in mailbox and do not quilt')
    parser.add_argument('-n', '--nolink', dest='nolink',
                        action='store_true',
                        help='Do not add Link:// tags to patches')
    parser.add_argument('-N', '--noid', dest='noid',
                        action='store_true',
                        help='Do not require Message ID. Implies --nolink.')
    parser.add_argument('-p', '--patchesdir', dest='patchesdir',
                        default=None,
                        help='''
                        Use given directory and do not automatically create
                        patches-$inbox/''')
    parser.add_argument('-a', '--ackedby', dest='ackedby', type=str,
                        default='',
                        help='''
                        Add Acked-by to all patches. Comma separated list.
                        "Joe Devel <joe@dev.el>, Dev Joe <dev@joe.org>"
                        ''')
    parser.add_argument('-r', '--reviewedby', dest='reviewedby', type=str,
                        default='',
                        help='Add Reviewed-by to all patches')
    parser.add_argument('-t', '--testedby', dest='testedby', type=str,
                        default='',
                        help='Add Tested-by to all patches')
    parser.add_argument('-C', '--CollectCC', dest='collectcc',
                        action='store_true',
                        help='Collect all CCs from mail header')
    args = parser.parse_args()

    # Load the optional config file. A missing file silently selects the
    # defaults. A file which cannot be read or parsed is reported and the
    # defaults are used as well.
    # NOTE(review): yaml.FullLoader is kept as is; the config file is
    # expected to be the user's own file.
    cfg = None
    try:
        with open(os.path.expanduser(args.config)) as cfgfile:
            cfg = yaml.load(cfgfile, Loader=yaml.FullLoader)
    except FileNotFoundError:
        pass
    except (OSError, UnicodeError, yaml.YAMLError) as ex:
        sys.stderr.write('Ignoring config file %s: %s\n' % (args.config, ex))

    # An empty file loads as None and only a mapping offers get()
    if not isinstance(cfg, dict):
        if cfg is not None:
            sys.stderr.write('Ignoring config file %s: not a mapping\n'
                             % args.config)
        cfg = {}

    # Get committer from config or GIT. Ask git only if the config does
    # not provide one.
    if 'committer' in cfg:
        args.committer = cfg['committer']
    else:
        args.committer = committer_from_git()
    # Cc's which should be dropped
    args.nocc_addrs = cfg.get('nocc_addrs', [])
    # Cc's which must not be dropped
    args.keepcc_addrs = cfg.get('keepcc_addrs', ['stable@vger.kernel.org'])
    # Mailing list addresses
    args.list_addrs = cfg.get('list_addrs', ['linux-kernel@vger.kernel.org'])
    # Drop mails which originate from a mail clients management mail
    # usually the first mail in a mbox
    args.drop_from = cfg.get('drop_from', ['<MAILER-DAEMON@'])
    # The base address for generating Link://
    args.linkbase = cfg.get('link_base', 'https://lkml.kernel.org/r/')
    # Drop all ccs?
    args.dropccs = cfg.get('dropccs', True)
    # Committer SOB before Cc and Link
    args.sob_before_cc = cfg.get('sob_before_cc', True)

    compile_tags()

    q = quilter(args)
    mbox = mailbox.mbox(args.inbox, create=False)
    q.scan_mbox(mbox)

    if not len(q.patches):
        sys.stderr.write("No patches found in mbox\n")
        sys.exit(1)

    if args.mboxout:
        q.create_mbox(args.mboxout)
    else:
        # Create the patch directory
        if args.patchesdir:
            pdir = os.path.expanduser(args.patchesdir)
        else:
            pdir = 'patches-%s' % os.path.basename(args.inbox)
        if not os.path.isdir(pdir):
            os.makedirs(pdir)
        q.write_series(pdir)
        q.write_patches(pdir)
