Add HMAC to start auth session

Signed-off-by: James Bottomley <James.Bottomley@HansenPartnership.com>
diff --git a/tpm2-interposer.py b/tpm2-interposer.py
index d8b6c37..adee3dc 100644
--- a/tpm2-interposer.py
+++ b/tpm2-interposer.py
@@ -129,7 +129,7 @@
     def updatearea(cls, handles, cmd):
         (tag,) = struct.unpack(">H", cmd[:2])
         if tag != TPM_ST_SESSIONS:
-            return
+            return None
 
         digest = hashes.Hash(hashes.SHA256())
         digest.update(cmd[6:10])
@@ -153,7 +153,7 @@
                 break
 
         if found == None:
-            return False
+            return None
 
         h = hmac.new(found.sessionKey, phash, hashlib.sha256)
         h.update(found.nonceCaller)
@@ -166,7 +166,7 @@
 
         found.decrypt(handles, cmd)
 
-        return True
+        return found
 
     def decrypt(self, handles, cmd):
         if self.attributes[0] & TPMA_SESSION_DECRYPT != TPMA_SESSION_DECRYPT:
@@ -181,28 +181,29 @@
         o += 2
         cmd[o:o+l] = cipher.decrypt(parameter.buf)
 
-    def encrypt(self, reply):
+    def encrypt(self, reply, offset):
         if self.attributes[0] & TPMA_SESSION_ENCRYPT != TPMA_SESSION_ENCRYPT:
             return
 
         aeskeyiv = KDFa(self.sessionKey, b'CFB\x00', self.nonceTPM, self.nonceCaller)
         cipher = AES.new(aeskeyiv[0:16], AES.MODE_CFB, iv = aeskeyiv[16:32], segment_size=128)
-        param = tpm2b.scan(reply[14:])
+        param = tpm2b.scan(reply[14 + offset:])
         paramlen = len(param.buf)
         print(f"encrypting parameter of size {paramlen}")
         param = cipher.encrypt(param.buf)
         print(f"encrypted size {len(param)}")
-        reply[16:16+paramlen] = param
+        reply[16+offset:16+offset+paramlen] = param
 
-    def addresponse(self, ordinal, reply):
+    def addresponse(self, ordinal, reply, handles = 0):
+        offset = 4 * handles
         (response, ) = struct.unpack('>I', reply[6:10])
         # new nonceTPM for reply
         self.nonceTPM = os.urandom(self.noncesize)
 
         if response == 0:
-            (paramsize, ) = struct.unpack('>I', reply[10:14])
-            self.encrypt(reply)
-            responseOffset = 14 + paramsize
+            (paramsize, ) = struct.unpack('>I', reply[10 + offset:14 + offset])
+            self.encrypt(reply, offset)
+            responseOffset = 14 + paramsize + offset
         else:
             paramsize = 0
             responseOffset = 10
@@ -211,7 +212,7 @@
         # rphash: response code, command code, parameters
         digest.update(struct.pack(">II", response, ordinal))
         if paramsize != 0:
-            digest.update(reply[14:14+paramsize])
+            digest.update(reply[14+offset:14+offset+paramsize])
         rhash = digest.finalize()
         h = hmac.new(self.sessionKey, rhash, hashlib.sha256)
         h.update(self.nonceTPM)
@@ -346,13 +347,26 @@
             print("TPM key is not substituted")
             return None
         needHmac = session.updatearea(2, command)
-        if needHmac:
+        if needHmac != None:
             (authorizationsize,) = struct.unpack('>I',command[18:22])
             print(f"authorizationSize={hex(authorizationsize)}")
             del command[18:18+authorizationsize+4]
         s = session.fromcommand(command)
         print(f"Session salt encryption key is substituted, {hex(s.session)}")
-        return s.reply()
+
+        reply = s.reply();
+        if needHmac != None:
+            reply = bytearray(reply)
+            reply[0:0] = struct.pack('>HII', TPM_ST_SESSIONS, 0, 0)
+            # insert parameterSize after handle
+            reply[14:14] = struct.pack('>I', len(reply[14:]))
+            print(f"Adding HMAC for session {hex(needHmac.session)}")
+            needHmac.addresponse(TPM_CC_StartAuthSession, reply, 1)
+            header = bytes(reply[:10])
+            reply = bytes(reply[10:])
+            return (header, reply)
+        else:
+            return reply
     elif ordinal == TPM_CC_ContextSave:
         (h,) = struct.unpack('>I', command[10:14])
         h &= 0x00ffffff
@@ -380,7 +394,7 @@
         (h,) = struct.unpack('>I', command[10:14])
         if len(sessions) == 0:
             return None
-        if not session.updatearea(1, command):
+        if session.updatearea(1, command) == None:
             return None
         print(f"Create intercepted")
         (authorizationsize, ah) = struct.unpack('>II',command[14:22])