patchwork-bot: add proper typing hints

It helps to catch easy bugs when writing in PyCharm and other IDEs.

Signed-off-by: Konstantin Ryabitsev <konstantin@linuxfoundation.org>
diff --git a/git-patchwork-bot.py b/git-patchwork-bot.py
index a9bc771..6adef71 100755
--- a/git-patchwork-bot.py
+++ b/git-patchwork-bot.py
@@ -39,6 +39,7 @@
 from requests.packages.urllib3.util.retry import Retry
 
 from string import Template
+from typing import Optional, Tuple, Union, Dict, List, Set
 
 # Send all email 8-bit, this is not 1999
 from email import charset
@@ -66,7 +67,15 @@
 
 
 class Restmaker:
-    def __init__(self, server):
+    server: str
+    url: str
+    series_url: str
+    patches_url: str
+    projects_url: str
+    session: requests.Session
+    _patches: Dict[int, Optional[dict]]
+
+    def __init__(self, server: str) -> None:
         self.server = server
         self.url = '/'.join((server.rstrip('/'), 'api', REST_API_VERSION))
 
@@ -91,7 +100,7 @@
             headers['Authorization'] = f'Token {apitoken}'
         self.session.headers.update(headers)
 
-    def get_unpaginated(self, url, params):
+    def get_unpaginated(self, url: str, params: list) -> List[dict]:
         # Caller should catch RequestException
         page = 0
         results = list()
@@ -113,7 +122,7 @@
 
         return results
 
-    def get_cover(self, cover_id):
+    def get_cover(self, cover_id: int) -> dict:
         try:
             logger.debug('Grabbing cover %d', cover_id)
             url = '/'.join((self.covers_url, str(cover_id), ''))
@@ -123,9 +132,9 @@
             return rsp.json()
         except requests.exceptions.RequestException as ex:
             logger.info('REST error: %s', ex)
-            return None
+            raise KeyError('Not able to get cover %s', cover_id)
 
-    def get_patch(self, patch_id):
+    def get_patch(self, patch_id: int) -> dict:
         if patch_id not in self._patches:
             try:
                 logger.debug('Grabbing patch %d', patch_id)
@@ -137,23 +146,23 @@
             except requests.exceptions.RequestException as ex:
                 logger.info('REST error: %s', ex)
                 self._patches[patch_id] = None
+                raise KeyError('Not able to get patch_id %s', patch_id)
 
         return self._patches[patch_id]
 
-    def get_series(self, series_id):
+    def get_series(self, series_id: int) -> dict:
         try:
             logger.debug('Grabbing series %d', series_id)
             url = '/'.join((self.series_url, str(series_id), ''))
             logger.debug('url=%s', url)
             rsp = self.session.get(url, stream=False)
             rsp.raise_for_status()
+            return rsp.json()
         except requests.exceptions.RequestException as ex:
             logger.info('REST error: %s', ex)
-            return None
+            raise KeyError('Not able to get series %s', series_id)
 
-        return rsp.json()
-
-    def get_patches_list(self, params, unpaginated=True):
+    def get_patches_list(self, params: list, unpaginated: bool = True) -> List[dict]:
         try:
             if unpaginated:
                 return self.get_unpaginated(self.patches_url, params)
@@ -163,10 +172,9 @@
                 return rsp.json()
         except requests.exceptions.RequestException as ex:
             logger.info('REST error: %s', ex)
+            return list()
 
-        return None
-
-    def get_series_list(self, params, unpaginated=True):
+    def get_series_list(self, params: list, unpaginated: bool = True) -> List[dict]:
         try:
             if unpaginated:
                 return self.get_unpaginated(self.series_url, params)
@@ -176,18 +184,17 @@
                 return rsp.json()
         except requests.exceptions.RequestException as ex:
             logger.info('REST error: %s', ex)
+            return list()
 
-        return None
-
-    def get_projects_list(self, params):
+    def get_projects_list(self, params: list) -> list:
         try:
             return self.get_unpaginated(self.projects_url, params)
         except requests.exceptions.RequestException as ex:
             logger.info('REST error: %s', ex)
+            return list()
 
-        return None
-
-    def update_patch(self, patch_id, state=None, archived=False, commit_ref=None):
+    def update_patch(self, patch_id: int, state: Optional[str] = None, archived: bool = False,
+                     commit_ref: Optional[str] = None) -> list:
         # Clear it out of the cache
         if patch_id in self._patches:
             del self._patches[patch_id]
@@ -211,12 +218,12 @@
             rsp.raise_for_status()
         except requests.exceptions.RequestException as ex:
             logger.info('REST error: %s', ex)
-            return None
+            raise RuntimeError('Unable to update patch %s', patch_id)
 
         return rsp.json()
 
 
-def get_patchwork_patches_by_project_hash(rm, project, pwhash):
+def get_patchwork_patches_by_project_hash(rm: Restmaker, project: int, pwhash: str) -> List[int]:
     logger.debug('Looking up %s', pwhash)
     params = [
         ('project', project),
@@ -226,12 +233,12 @@
     patches = rm.get_patches_list(params)
     if not patches:
         logger.debug('No match for hash=%s', pwhash)
-        return None
+        return list()
 
     return [patch['id'] for patch in patches]
 
 
-def get_patchwork_pull_requests_by_project(rm, project, fromstate):
+def get_patchwork_pull_requests_by_project(rm: Restmaker, project: int, fromstate: List[str]) -> Set[Tuple]:
     params = [
         ('project', project),
         ('archived', 'false'),
@@ -256,16 +263,15 @@
                 pull_refname = 'master'
 
             prs.add((pull_host, pull_refname, patch_id))
-
     return prs
 
 
-def project_by_name(pname):
+def project_by_name(pname: str) -> Tuple:
     global _project_cache
     global _server_cache
 
     if not pname:
-        return None
+        raise KeyError('Must specify project name')
 
     if pname not in _project_cache:
         # Find patchwork definition containing this project
@@ -302,27 +308,27 @@
                 break
         if not found:
             logger.info('Could not find project matching %s on server %s', pname, server)
-            return None
+            raise KeyError(f'No match for project {pname} on server {server}')
 
     return _project_cache[pname]
 
 
-def db_save_meta(c):
+def db_save_meta(c: sqlite3.Cursor) -> None:
     c.execute('DELETE FROM meta')
     c.execute('''INSERT INTO meta VALUES(?)''', (DB_VERSION,))
 
 
-def db_save_repo_heads(c, heads):
+def db_save_repo_heads(c: sqlite3.Cursor, heads: list) -> None:
     c.execute('DELETE FROM heads')
     for refname, commit_id in heads:
         c.execute('''INSERT INTO heads VALUES(?,?)''', (refname, commit_id))
 
 
-def db_get_repo_heads(c):
+def db_get_repo_heads(c: sqlite3.Cursor) -> List[Tuple]:
     return c.execute('SELECT refname, commit_id FROM heads').fetchall()
 
 
-def db_init_common_sqlite_db(c):
+def db_init_common_sqlite_db(c: sqlite3.Cursor) -> None:
     c.execute('''
         CREATE TABLE meta (
             version INTEGER
@@ -330,7 +336,7 @@
     db_save_meta(c)
 
 
-def db_init_cache_sqlite_db(c):
+def db_init_cache_sqlite_db(c: sqlite3.Cursor) -> None:
     logger.info('Initializing new sqlite3 db with metadata version %s', DB_VERSION)
     db_init_common_sqlite_db(c)
     c.execute('''
@@ -343,7 +349,7 @@
     c.execute('''CREATE UNIQUE INDEX idx_rev ON revs(rev)''')
 
 
-def db_init_pw_sqlite_db(c):
+def db_init_pw_sqlite_db(c: sqlite3.Cursor) -> None:
     logger.info('Initializing new sqlite3 db with metadata version %s', DB_VERSION)
     db_init_common_sqlite_db(c)
     c.execute('''
@@ -353,7 +359,7 @@
         )''')
 
 
-def git_get_command_lines(gitdir, args):
+def git_get_command_lines(gitdir: str, args: List[str]) -> list:
     out = git_run_command(gitdir, args)
     lines = list()
     if out:
@@ -365,7 +371,7 @@
     return lines
 
 
-def git_run_command(gitdir, args, stdin=None):
+def git_run_command(gitdir: str, args: List[str], stdin: Optional[str] = None) -> str:
     args = ['git', '--no-pager', '--git-dir', gitdir] + args
 
     logger.debug('Running %s' % ' '.join(args))
@@ -383,7 +389,7 @@
     return output
 
 
-def git_get_repo_heads(gitdir, branch, ancestry=None):
+def git_get_repo_heads(gitdir: str, branch: str, ancestry: Optional[str] = None) -> List[Tuple[str, str]]:
     refs = list()
     lines = git_get_command_lines(gitdir, ['show-ref', branch])
     if ancestry is None:
@@ -397,7 +403,8 @@
     return refs
 
 
-def git_get_new_revs(gitdir, db_heads, git_heads, committers, merges=False):
+def git_get_new_revs(gitdir: str, db_heads: List[Tuple[str, str]], git_heads: List[Tuple[str, str]],
+                     committers: List[str], merges: bool = False) -> Dict[str, list]:
     newrevs = dict()
     if committers:
         logger.debug('filtering by committers=%s', committers)
@@ -453,12 +460,12 @@
     return newrevs
 
 
-def git_get_rev_diff(gitdir, rev):
+def git_get_rev_diff(gitdir: str, rev: str) -> str:
     args = ['diff', '%s~..%s' % (rev, rev)]
     return git_run_command(gitdir, args)
 
 
-def git_get_patch_id(diff):
+def git_get_patch_id(diff: str) -> Optional[str]:
     args = ['patch-id', '--stable']
     out = git_run_command('', args, stdin=diff)
     logger.debug('out=%s', out)
@@ -467,7 +474,7 @@
     return out.split()[0]
 
 
-def get_patchwork_hash(diff):
+def get_patchwork_hash(diff: str) -> str:
     """Generate a hash from a diff. Lifted verbatim from patchwork."""
 
     # normalise spaces
@@ -515,13 +522,14 @@
     return hashed.hexdigest()
 
 
-def listify(obj):
+def listify(obj: Union[str, list, None]) -> list:
     if isinstance(obj, list):
         return list(obj)
     return [obj]
 
 
-def send_summary(serieslist, committers, to_state, refname, pname, rs, hs):
+def send_summary(serieslist: List[dict], committers: Dict[int, str], to_state: str, refname: str, pname: str,
+                 rs: Dict[str, str], hs: Dict[str, str]) -> str:
     logger.info('Preparing summary')
     # we send summaries by project, so the project name is going to be all the same
 
@@ -610,7 +618,7 @@
     return str(msg['Message-Id'])
 
 
-def get_tweaks(pconfig, hconfig):
+def get_tweaks(pconfig: Dict[str, str], hconfig: Dict[str, str]) -> Dict[str, str]:
     fields = ['from', 'summaryto', 'onlyto', 'neverto', 'onlyifcc', 'neverifcc',
               'alwayscc', 'alwaysbcc', 'cclist', 'ccall']
     bubbled = dict()
@@ -623,7 +631,9 @@
     return bubbled
 
 
-def notify_submitters(serieslist, committers, refname, revs, pname, rs, hs):
+def notify_submitters(serieslist: List[dict], committers: Dict[int, str], refname: str,
+                      revs: Dict[int, str], pname: str, rs: Dict[str, Union[str, list, dict]],
+                      hs: Dict[str, Union[str, list, dict]]) -> None:
     logger.info('Sending submitter notifications')
     project, rm, pconfig = project_by_name(pname)
 
@@ -634,18 +644,26 @@
         # else the reference is the msgid of the first patch
         patches = sdata.get('patches')
         is_pull_request = False
+        content = headers = reference = None
         if sdata.get('cover_letter'):
             reference = sdata.get('cover_letter').get('msgid')
-            fullcover = rm.get_cover(sdata.get('cover_letter').get('id'))
-            headers = {k.lower(): v for k, v in fullcover.get('headers').items()}
-            content = fullcover.get('content')
-        else:
+            try:
+                fullcover = rm.get_cover(sdata.get('cover_letter').get('id'))
+                headers = {k.lower(): v for k, v in fullcover.get('headers').items()}
+                content = fullcover.get('content')
+            except KeyError:
+                logger.debug('Unable to get cover letter, will try first patch')
+        if not reference:
             reference = patches[0].get('msgid')
-            fullpatch = rm.get_patch(patches[0].get('id'))
-            headers = {k.lower(): v for k, v in fullpatch.get('headers').items()}
-            content = fullpatch.get('content')
-            if fullpatch.get('pull_url'):
-                is_pull_request = True
+            try:
+                fullpatch = rm.get_patch(patches[0].get('id'))
+                headers = {k.lower(): v for k, v in fullpatch.get('headers').items()}
+                content = fullpatch.get('content')
+                if fullpatch.get('pull_url'):
+                    is_pull_request = True
+            except KeyError:
+                logger.debug('Unable to get first patch reference, bailing on %s', sdata.get('id'))
+                continue
 
         submitter = sdata.get('submitter')
         project = sdata.get('project')
@@ -784,7 +802,7 @@
             logger.info('------------------------------')
 
 
-def housekeeping(pname):
+def housekeeping(pname: str) -> None:
     project, rm, pconfig = project_by_name(pname)
     if 'housekeeping' not in pconfig:
         return
@@ -1030,7 +1048,7 @@
             logger.info('------------------------------')
 
 
-def pwrun(repo, rsettings):
+def pwrun(repo: str, rsettings: Dict[str, Union[str, list, dict]]) -> None:
     git_heads = git_get_repo_heads(repo, branch=rsettings.get('branch', '--heads'))
     if not git_heads:
         logger.info('Could not get the latest ref in %s', repo)
@@ -1288,7 +1306,7 @@
         dbconn.commit()
 
 
-def check_repos():
+def check_repos() -> None:
     # Use a global lock to make sure only a single process is running
     try:
         lockfh = open(os.path.join(CACHEDIR, 'patchwork-bot.global.lock'), 'w')
@@ -1310,7 +1328,7 @@
         pwrun(fullpath, settings)
 
 
-def pwhash_differ():
+def pwhash_differ() -> None:
     diff = sys.stdin.read()
     pwhash = get_patchwork_hash(diff)
     print(pwhash)