# blob: af927eb978950713333539a6d7aad5b04e4064b5 [file] [log] [blame]
##
# TPM Interposer implementation
# Copyright (C) 2025 by James.Bottomley@HansenPartnership.com
#
# SPDX-License-Identifier: GPL-2.0
##
import socket
import struct
import hmac
import hashlib
import os
import base64
import time
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.primitives.serialization import PublicFormat
from Crypto.Cipher import AES
# Host used both for the listening socket and for the connection to the TPM
HOST = "localhost"
# Port this interposer listens on for client connections
PORT = 2325
# Port of the TPM that non-intercepted commands are forwarded to
TPMPORT = 2321
##
# use a single key for substitution
##
# Fresh P-256 key pair generated at start-up; its public point replaces the
# 'unique' field of every public area returned by ReadPublic, and its private
# half is used to recover the session salt in session.fromcommand()
privkey = ec.generate_private_key(ec.SECP256R1())
pubkey = privkey.public_key()
pn = pubkey.public_numbers()
#print(f"X = {pn.x.to_bytes(32, 'big').hex()}")
#print(f"Y = {pn.y.to_bytes(32, 'big').hex()}")
#print(f"DER={pubkey.public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo).hex()}")
# Substitute public point: two size-prefixed 32-byte big-endian coordinates
new_unique = struct.pack(">H32sH32s", 32, pn.x.to_bytes(32, 'big'),
32, pn.y.to_bytes(32, 'big'))
##
# TPM definitions
##
# Algorithm identifiers (the only key type / name algorithm handled here)
TPM_ALG_ECC = 0x23
TPM_ALG_SHA256 = 0x0b
# Command codes the interposer inspects
TPM_CC_CreatePrimary = 0x131
TPM_CC_Startup = 0x144
TPM_CC_Create = 0x153
TPM_CC_Load = 0x157
TPM_CC_Unseal = 0x15e
TPM_CC_ContextLoad = 0x161
TPM_CC_ContextSave = 0x162
TPM_CC_FlushContext = 0x165
TPM_CC_ReadPublic = 0x173
TPM_CC_StartAuthSession = 0x176
TPM_CC_PolicyPCR = 0x17f
TPM_CC_PolicyGetDigest = 0x189
# Session types as they appear in StartAuthSession
TPM_SE_HMAC = 0x00
TPM_SE_POLICY = 0x01
# Command/response tags
TPM_ST_NO_SESSIONS = 0x8001
TPM_ST_SESSIONS = 0x8002
# Reserved handles
TPM_RH_NULL = 0x40000007
TPM_RS_PW = 0x40000009
# Handle type bits OR'ed into the number of a substituted session
TPM_CONTEXT_HANDLE_HMAC = 0x02000000
TPM_CONTEXT_HANDLE_POLICY = 0x03000000
# Session attribute bits
TPMA_SESSION_CONTINUESESSION = 0x01
TPMA_SESSION_DECRYPT = 0x20
TPMA_SESSION_ENCRYPT = 0x40
# Response code returned synthetically for ContextSave of an unknown
# substituted session handle
TPM_RC_REFERENCE_H0 = 0x910
# commands for TPM emulators
TPM_SESSION_END = 20
# Highest session number handed out; session.__init__ counts down from here
SESSION_CONTEXT_MAGIC = 0xffff
# Placeholder context blob returned for ContextSave of a substituted session
CONTEXT_MAGIC = b'CONTEXTMAGICSTRING'
# handle -> tpmt_public recorded (with substituted key) on ReadPublic
primaries = dict()
# session number (handle with the type byte masked off) -> session object
sessions = dict()
# session that authorized the intercepted Create currently in flight
create_session = None
# tpm2b scanned out of the intercepted Create command; returned verbatim for
# PolicyGetDigest
save_policy = None
# tpm2b private area taken from the reply to the intercepted Create
save_create = None
# handle under which the saved private area was loaded (0 = none yet)
save_load = 0
# size-prefixed secret captured from the intercepted Create command; used as
# the payload of the fake Unseal reply
save_sensitive = None
# name (algorithm id + SHA-256 digest) computed from the public area of the
# most recent Load
save_name = None
##
# Abbreviated KDFe knowing we take in 32 byte points and only
# need a 32 byte key
##
def KDFe(z, str, k1, k2):
    """Derive a 32-byte key with a single round of SHA-256 (abbreviated KDFe).

    The digest covers, in order: a big-endian 32-bit counter fixed at 1, the
    shared secret ``z``, the label ``str`` and the 32-byte big-endian x
    coordinates of the public keys ``k1`` and ``k2``.

    Note: the parameter name ``str`` shadows the builtin; it is kept so the
    signature stays unchanged.
    """
    x_first = k1.public_numbers().x.to_bytes(32, 'big')
    x_second = k2.public_numbers().x.to_bytes(32, 'big')
    hasher = hashes.Hash(hashes.SHA256())
    # only one iteration is needed because a single SHA-256 block of output
    # (32 bytes) is all that is ever requested
    for piece in (struct.pack(">I", 1), z, str, x_first, x_second):
        hasher.update(piece)
    return hasher.finalize()
def KDFa(key, label, u, v):
    """Derive 32 bytes with a single HMAC-SHA256 round (abbreviated KDFa).

    The HMAC, keyed with ``key``, covers in order: a big-endian 32-bit
    counter fixed at 1, ``label``, ``u``, ``v`` and the big-endian 32-bit
    value 256.
    """
    message = b"".join((
        struct.pack(">I", 1),
        label,
        u,
        v,
        struct.pack(">I", 256),
    ))
    return hmac.new(key, message, hashlib.sha256).digest()
class session(object):
    """State kept for one authorization session answered by the interposer.

    Every instance registers itself in the module-level ``sessions`` dict
    under a free session number.  The instance tracks the caller and TPM
    nonces, the attribute byte and HMAC last supplied by the caller, and the
    session key derived when the session was started.
    """

    def __init__(self):
        # take the highest unused number at or below SESSION_CONTEXT_MAGIC
        number = SESSION_CONTEXT_MAGIC
        while number in sessions:
            number -= 1
        self.session = number
        self.saved = False
        sessions[number] = self

    @classmethod
    def updateone(cls, buf):
        """Consume one session entry from the front of an authorization area.

        An entry is: 4-byte handle, size-prefixed nonce, one attribute byte,
        size-prefixed HMAC.  If the handle (type byte masked off) belongs to
        a known session, that session is refreshed with the entry's values.

        Returns ``(remaining_bytes, session_or_None)``.
        """
        (handle, nonce_len) = struct.unpack(">IH", buf[:6])
        attr_at = 6 + nonce_len
        nonce = buf[6:attr_at]
        attributes = buf[attr_at:attr_at + 1]
        (mac_len,) = struct.unpack(">H", buf[attr_at + 1:attr_at + 3])
        mac_at = attr_at + 3
        mac_end = mac_at + mac_len
        found = sessions.get(handle & 0x00ffffff)
        if found is not None:
            found.update(nonce, attributes, buf[mac_at:mac_end])
        return (buf[mac_end:], found)

    def update(self, nonce, attributes, hmac):
        """Record the caller nonce, attribute byte and HMAC of a command."""
        self.nonceCaller = nonce
        self.attributes = attributes
        self.hmac = hmac

    @classmethod
    def updatearea(cls, handles, cmd):
        """Process the authorization area of a command with ``handles`` handles.

        Computes the command parameter hash, refreshes every known session
        named in the area, checks the HMAC of the first known session (a
        mismatch is only reported, not enforced) and decrypts the first
        parameter of ``cmd`` in place when a session requests it.

        Returns ``None`` when the command carries no sessions or when the
        last entry of the area is not a known session; a single ``session``
        when exactly one known session was seen; otherwise the list of known
        sessions in the order they appeared.
        """
        if struct.unpack(">H", cmd[:2])[0] != TPM_ST_SESSIONS:
            return None
        handles_end = 10 + handles * 4
        (size,) = struct.unpack(">I", cmd[handles_end:handles_end + 4])
        area_end = handles_end + 4 + size
        buf = cmd[handles_end + 4:area_end]

        # command parameter hash: command code, one name per handle, then
        # everything that follows the authorization area
        hasher = hashes.Hash(hashes.SHA256())
        hasher.update(cmd[6:10])
        for pos in range(10, handles_end, 4):
            raw = cmd[pos:pos + 4]
            (h,) = struct.unpack(">I", raw)
            if (h & 0xff000000) in (0x0000000, 0x02000000, 0x03000000, 0x40000000):
                # for these handle types the handle itself is hashed
                hasher.update(raw)
            elif h in primaries:
                hasher.update(primaries[h].name())
            elif h == save_load:
                hasher.update(save_name)
            else:
                print(f"ERROR: missing primary handle {hex(h)}")
        hasher.update(cmd[area_end:])
        phash = hasher.finalize()

        found_list = []
        while True:
            (buf, last) = session.updateone(buf)
            if last is not None:
                found_list.append(last)
            if len(buf) == 0:
                break
        if last is None:
            return None
        # the last entry was known, so found_list is non-empty; the first
        # known session is always the one whose HMAC is checked
        found = found_list[0]
        several = len(found_list) > 1

        signer = hmac.new(found.sessionKey, phash, hashlib.sha256)
        signer.update(found.nonceCaller)
        signer.update(found.nonceTPM)
        if several:
            signer.update(found_list[1].nonceTPM)
        signer.update(found.attributes)
        mac = signer.digest()
        if mac != found.hmac:
            print(f"HMAC MISMATCH {mac.hex()} != {found.hmac.hex()}")

        if several:
            found_list[1].decrypt(handles, cmd)
            return found_list
        found.decrypt(handles, cmd)
        return found

    def decrypt(self, handles, cmd):
        """Decrypt the first size-prefixed command parameter in place.

        Does nothing unless the decrypt attribute bit is set.  The AES-128
        CFB key and IV are the two halves of
        ``KDFa(sessionKey, 'CFB\\0', nonceCaller, nonceTPM)``.
        """
        if (self.attributes[0] & TPMA_SESSION_DECRYPT) != TPMA_SESSION_DECRYPT:
            return
        pos = 10 + handles * 4
        (size,) = struct.unpack(">I", cmd[pos:pos + 4])
        pos += size + 4
        parameter = tpm2b.scan(cmd[pos:])
        keyiv = KDFa(self.sessionKey, b'CFB\x00', self.nonceCaller, self.nonceTPM)
        cipher = AES.new(keyiv[0:16], AES.MODE_CFB, iv = keyiv[16:32], segment_size=128)
        # skip the two size bytes; only the payload is encrypted
        start = pos + 2
        cmd[start:start + len(parameter.buf)] = cipher.decrypt(parameter.buf)

    def encrypt(self, reply, offset):
        """Encrypt the first size-prefixed response parameter in place.

        Does nothing unless the encrypt attribute bit is set.  ``offset`` is
        the number of bytes taken by response handles.  The key and IV come
        from ``KDFa(sessionKey, 'CFB\\0', nonceTPM, nonceCaller)``.
        """
        if (self.attributes[0] & TPMA_SESSION_ENCRYPT) != TPMA_SESSION_ENCRYPT:
            return
        keyiv = KDFa(self.sessionKey, b'CFB\x00', self.nonceTPM, self.nonceCaller)
        cipher = AES.new(keyiv[0:16], AES.MODE_CFB, iv = keyiv[16:32], segment_size=128)
        plain = tpm2b.scan(reply[14 + offset:]).buf
        start = 16 + offset
        reply[start:start + len(plain)] = cipher.encrypt(plain)

    def _response_auth(self, rhash):
        """Return this session's response entry: nonce, attributes, HMAC."""
        signer = hmac.new(self.sessionKey, rhash, hashlib.sha256)
        for piece in (self.nonceTPM, self.nonceCaller, self.attributes):
            signer.update(piece)
        mac = bytes(tpm2b(signer.digest()))
        return bytes(tpm2b(self.nonceTPM)) + self.attributes + mac

    def addresponse(self, ordinal, reply, handles = 0, encryptsession = None):
        """Replace the authorization area of ``reply`` with a freshly signed one.

        ``reply`` (a bytearray holding header, handles, parameter size and
        parameters) is modified in place: a new TPM nonce is generated, the
        first parameter is encrypted by ``encryptsession`` if given (else by
        this session), everything after the parameters is replaced by this
        session's response entry, followed by the entry of ``encryptsession``
        if given, and the length field of the header is rewritten.
        """
        offset = 4 * handles
        (response, ) = struct.unpack('>I', reply[6:10])
        # new nonceTPM for reply
        self.nonceTPM = os.urandom(self.noncesize)
        # an error response has no parameters; the entry follows the header
        paramsize = 0
        responseOffset = 10
        if response == 0:
            (paramsize, ) = struct.unpack('>I', reply[10 + offset:14 + offset])
            if encryptsession is None:
                self.encrypt(reply, offset)
            else:
                encryptsession.nonceTPM = os.urandom(encryptsession.noncesize)
                encryptsession.encrypt(reply, offset)
            responseOffset = 14 + paramsize + offset
        # rphash: response code, command code, parameters
        hasher = hashes.Hash(hashes.SHA256())
        hasher.update(struct.pack(">II", response, ordinal))
        if paramsize != 0:
            hasher.update(reply[14 + offset:14 + offset + paramsize])
        rhash = hasher.finalize()
        reply[responseOffset:] = self._response_auth(rhash)
        if encryptsession is not None:
            reply.extend(encryptsession._response_auth(rhash))
        reply[2:6] = struct.pack('>I', len(reply))

    @classmethod
    def fromcommand(cls, buf):
        """Create a session from a StartAuthSession command without sessions.

        Reads the caller nonce and the encrypted salt that follow the two
        command handles, tags the session number with the handle type for
        the requested session type, recovers the salt by ECDH with the
        substituted private key and derives the session key.
        """
        s = cls()
        nonce = tpm2b.scan(buf[18:])
        s.noncesize = len(nonce.buf)
        s.nonceCaller = nonce.buf
        s.nonceTPM = os.urandom(s.noncesize)
        secret_at = 18 + len(nonce)
        secret = tpm2b.scan(buf[secret_at:])
        setype = buf[secret_at + len(secret)]
        handle_bits = {
            TPM_SE_HMAC: TPM_CONTEXT_HANDLE_HMAC,
            TPM_SE_POLICY: TPM_CONTEXT_HANDLE_POLICY,
        }
        # other session types leave the number untagged
        s.session |= handle_bits.get(setype, 0)
        # the secret holds two size-prefixed 32-byte coordinates
        point = secret.buf
        x = int.from_bytes(point[2:2+32], 'big')
        y = int.from_bytes(point[2+2+32:2+2+32+32], 'big')
        k = ec.EllipticCurvePublicNumbers(x, y, ec.SECP256R1()).public_key()
        shared_key = privkey.exchange(ec.ECDH(), k)
        s.salt = KDFe(shared_key, b'SECRET\x00', k, pubkey)
        s.sessionKey = KDFa(s.salt, b'ATH\x00', s.nonceTPM, s.nonceCaller)
        return s

    def reply(self):
        """Return the StartAuthSession reply body: handle and TPM nonce."""
        return struct.pack('>I', self.session) + bytes(tpm2b(self.nonceTPM))
class tpm2b(object):
    """A byte string carried with a 2-byte big-endian length prefix.

    ``buf`` holds the payload only; ``bytes()`` yields prefix plus payload
    and ``len()`` counts both.
    """

    def __init__(self, buf):
        self.buf = buf

    def fmt(self):
        """Return the struct format string for the prefixed encoding."""
        return ">H%us" % len(self.buf)

    def __bytes__(self):
        payload = self.buf
        return struct.pack(self.fmt(), len(payload), payload)

    def __len__(self):
        # two bytes of length prefix plus the payload
        return len(self.buf) + 2

    @classmethod
    def scan(cls, b):
        """Parse one prefixed byte string from the start of ``b``.

        Trailing bytes are ignored.  If ``b`` holds fewer payload bytes than
        the prefix announces, the payload is whatever is available.
        """
        (size,) = struct.unpack_from(">H", b)
        return cls(b[2:size + 2])
class tpmt_public(object):
    """Public area of a key, kept as a header plus three opaque byte fields.

    ``authPolicy`` is stored together with its own 2-byte size prefix when
    the object comes from ``from2b``; ``parameters`` and ``unique`` are raw
    bytes.  Serialising simply concatenates header and fields.
    """

    def __init__(self, Type, nameAlg, objectAttributes,
                 authPolicy=bytes(), parameters=bytes(),
                 unique=bytes()):
        self.Type = Type
        self.nameAlg = nameAlg
        self.objectAttributes = objectAttributes
        self.authPolicy = authPolicy
        self.parameters = parameters
        self.unique = unique

    def fmt(self):
        """Return the struct format string of the serialised public area."""
        field_sizes = (len(self.authPolicy), len(self.parameters), len(self.unique))
        return '>HHI%us%us%us' % field_sizes

    def __bytes__(self):
        header = struct.pack('>HHI', self.Type, self.nameAlg, self.objectAttributes)
        return header + b''.join((self.authPolicy, self.parameters, self.unique))

    def __len__(self):
        # 8 header bytes (two 16-bit fields and one 32-bit field) plus fields
        return 8 + len(self.authPolicy) + len(self.parameters) + len(self.unique)

    def name(self):
        """Return nameAlg (2 bytes) followed by SHA-256 of the public area."""
        hasher = hashes.Hash(hashes.SHA256())
        hasher.update(bytes(self))
        return struct.pack(">H", self.nameAlg) + hasher.finalize()

    @classmethod
    def from2b(cls, buf):
        """Parse a size-prefixed public area from the start of ``buf``.

        Only ECC keys with a SHA-256 name algorithm are accepted (enforced
        with assertions); the parameter area is taken to be 12 bytes long and
        everything after it, up to the announced size, is ``unique``.
        """
        (total, Type, nameAlg, objectAttributes) = struct.unpack('>HHHI', buf[:10])
        # we only process ECC at the moment (parameters size depends on this)
        assert(Type == TPM_ALG_ECC)
        # and only sha256 name algorithm
        assert(nameAlg == TPM_ALG_SHA256)
        body = buf[10:total + 2]
        (policy_size,) = struct.unpack('>H', body[:2])
        # the policy keeps its size prefix so it serialises back unchanged
        policy_end = policy_size + 2
        params_end = policy_end + 12
        return cls(Type, nameAlg, objectAttributes,
                   body[:policy_end],
                   body[policy_end:params_end],
                   body[params_end:])
def _split_sessions(found):
    """Normalise the result of session.updatearea() to (auth, encrypt).

    updatearea() returns a single session when one known session was seen
    and a list when several were; in the list case the second entry is the
    one used for parameter encryption.
    """
    if isinstance(found, list):
        return (found[0], found[1])
    return (found, None)


def preprocess(command):
    """Inspect a command before it reaches the TPM and possibly answer it.

    ``command`` is a bytearray holding the complete command; it may be
    modified in place (authorization area stripped or replaced, first
    parameter decrypted).

    Returns one of:
      * ``None``   - forward the (possibly modified) command to the TPM;
      * ``bytes``  - response body; the caller prepends a success header;
      * ``(header, body)`` - a complete response supplied by the interposer.
    """
    global primaries, sessions, save_create, save_load
    global create_session, save_policy, save_sensitive
    (tag, length, ordinal) = struct.unpack('>HII', command[:10])
    if ordinal == TPM_CC_StartAuthSession:
        (tpmkey, bind) = struct.unpack('>II', command[10:18])
        print(f"Start Auth Session tpmkey={hex(tpmkey)}, bind={hex(bind)}, tag={hex(tag)}")
        if tpmkey not in primaries:
            print("TPM key is not substituted")
            return None
        needHmac = session.updatearea(2, command)
        if needHmac is not None:
            # strip the authorization area so the parameters sit at the
            # offsets session.fromcommand() expects
            (authorizationsize,) = struct.unpack('>I', command[18:22])
            del command[18:18 + authorizationsize + 4]
        s = session.fromcommand(command)
        print(f"Session salt encryption key is substituted, {hex(s.session)}")
        reply = s.reply()
        if needHmac is None:
            return reply
        # updatearea() may hand back one session or a list of them; the
        # original code assumed a single session here
        (auth_session, encrypt_session) = _split_sessions(needHmac)
        reply = bytearray(reply)
        reply[0:0] = struct.pack('>HII', TPM_ST_SESSIONS, 0, 0)
        # insert parameterSize after handle
        reply[14:14] = struct.pack('>I', len(reply[14:]))
        auth_session.addresponse(TPM_CC_StartAuthSession, reply, 1,
                                 encrypt_session)
        return (bytes(reply[:10]), bytes(reply[10:]))
    elif ordinal == TPM_CC_ContextSave:
        (h,) = struct.unpack('>I', command[10:14])
        if (h & 0x00ffffff) not in sessions:
            if (h & 0x00fffff0) == 0xfff0:
                # special synthetic return for the kernel resource
                # manager (real TPM would give the wrong error)
                return (struct.pack('>HII', TPM_ST_NO_SESSIONS, 10, TPM_RC_REFERENCE_H0), b'')
            return None
        sessions[h & 0x00ffffff].saved = True
        # fake context: sequence number, saved handle, hierarchy, blob
        reply = struct.pack('>QII', 12345, h, TPM_RH_NULL)
        return reply + bytes(tpm2b(CONTEXT_MAGIC))
    elif ordinal == TPM_CC_ContextLoad:
        (seq, h, hier) = struct.unpack('>QII', command[10:26])
        hb = h & 0x00ffffff
        if hb not in sessions:
            return None
        sessions[hb].saved = False
        return struct.pack('>I', h)
    elif ordinal == TPM_CC_FlushContext:
        (h,) = struct.unpack('>I', command[10:14])
        h &= 0x00ffffff
        if h in sessions:
            # a context-saved HMAC session survives the flush; policy
            # sessions and unsaved sessions are dropped
            if (sessions[h].session & 0xff000000 == 0x03000000) or not sessions[h].saved:
                del sessions[h]
            return b''
    elif ordinal == TPM_CC_Create:
        (h,) = struct.unpack('>I', command[10:14])
        if len(sessions) == 0:
            return None
        if session.updatearea(1, command) is None:
            return None
        print(f"Create intercepted, parent {hex(h)}")
        (authorizationsize, ah) = struct.unpack('>II', command[14:22])
        ah &= 0x00ffffff
        # first parameter: size-prefixed auth followed by size-prefixed secret
        parameter = tpm2b.scan(command[18 + authorizationsize:])
        auth = tpm2b.scan(parameter.buf)
        secret = tpm2b.scan(parameter.buf[len(auth.buf) + 2:])
        save_sensitive = bytes(secret)
        if auth.buf == b'':
            print("authorization is empty")
        else:
            print(f"auth = {auth.buf.hex()}")
        #print(f"secret= {base64.standard_b64encode(secret.buf)}")
        print(f"secret={base64.b64encode(secret.buf)}")
        save_policy = tpm2b.scan(command[18 + authorizationsize + len(parameter) + 10:])
        # replace session with TPM_RS_PW
        (l,) = struct.unpack('>I', command[2:6])
        newauth = struct.pack('>IHcH', TPM_RS_PW, 0, b'\x00', 0)
        command[18:18 + authorizationsize] = newauth
        command[14:18] = struct.pack('>I', len(newauth))
        l = l - authorizationsize + len(newauth)
        command[2:6] = struct.pack('>I', l)
        # remembered so postprocess() can sign the TPM's reply
        create_session = sessions[ah]
    elif ordinal == TPM_CC_Startup:
        primaries = dict()
    elif ordinal == TPM_CC_PolicyPCR:
        if len(sessions) == 0:
            return None
        return b''
    elif ordinal == TPM_CC_PolicyGetDigest:
        if save_policy is None:
            return None
        return bytes(save_policy)
    elif ordinal == TPM_CC_Unseal:
        (h,) = struct.unpack('>I', command[10:14])
        if h != save_load:
            return None
        ss = session.updatearea(1, command)
        print(f"Intercepting Unseal to fake response")
        if ss is None:
            return None
        # a command authorized by a single session yields a session object,
        # not a list; indexing it unconditionally raised TypeError
        (auth_session, encrypt_session) = _split_sessions(ss)
        reply = bytearray()
        reply.extend(struct.pack('>HII', TPM_ST_SESSIONS, 0, 0))
        reply.extend(struct.pack('>I', len(save_sensitive)))
        reply.extend(save_sensitive)
        auth_session.addresponse(TPM_CC_Unseal, reply, 0, encrypt_session)
        return (reply[:10], reply[10:])
    return None
def postprocess(command, reply):
    """Inspect, and where needed rewrite in place, the TPM's reply.

    ``command`` is the command as it was sent to the TPM and ``reply`` is a
    bytearray holding the TPM's complete response.  Nothing is returned; all
    effects are on ``reply`` and on the module-level tracking state.

    * ReadPublic: substitute the interposer's public key and fix the name.
    * Load: remember the name of the loaded object and, when the private
      area is the one captured from the intercepted Create, its handle.
    * Create: capture the private area and sign the reply with the session
      the caller originally used.
    * ContextSave / ContextLoad: follow the tracked object handle as it is
      saved and reloaded.
    """
    global save_create, save_load, save_name, create_session
    (tag, length, ordinal) = struct.unpack('>HII', command[:10])
    if ordinal == TPM_CC_CreatePrimary:
        (h,) = struct.unpack('>I', command[10:14])
        print(f"Create Primary {hex(h)}")
    elif ordinal == TPM_CC_ReadPublic:
        (h,) = struct.unpack('>I', command[10:14])
        (rcode,) = struct.unpack('>I', reply[6:10])
        if rcode != 0:
            # a failed ReadPublic carries no public area to substitute
            return
        primary = tpmt_public.from2b(reply[10:])
        #replace the public key
        primary.unique = new_unique
        reply[12:12 + len(primary)] = bytes(primary)
        #replace the name
        reply[12 + len(primary):48 + len(primary)] = bytes(tpm2b(primary.name()))
        #print(f"Read Public {hex(h)}, parameters={primary.parameters.hex()}, unique={new_unique.hex()}({len(new_unique)})")
        primaries[h] = primary
        print(f"Read Public {hex(h)} substituted")
    elif ordinal == TPM_CC_Load:
        (h, authsize) = struct.unpack('>II', command[10:18])
        private = tpm2b.scan(command[18 + authsize:])
        public = tpm2b.scan(command[18 + authsize + len(private):])
        digest = hashes.Hash(hashes.SHA256())
        digest.update(public.buf)
        save_name = b'\x00\x0B' + digest.finalize()
        if save_create is not None and private.buf == save_create.buf:
            if len(reply) < 14:
                # a failed Load returns no handle
                return
            (save_load,) = struct.unpack('>I', reply[10:14])
            print(f"Load with parent {hex(h)} intercepted")
    elif ordinal == TPM_CC_Create:
        if create_session is None:
            return
        (rcode,) = struct.unpack('>I', reply[6:10])
        if rcode != 0:
            print(f"intercepted TPM2_Create returned {hex(rcode)}")
            create_session = None
            return
        # save the private area
        save_create = tpm2b.scan(reply[14:])
        create_session.addresponse(TPM_CC_Create, reply)
        create_session = None
        return
    elif ordinal == TPM_CC_ContextSave:
        # reply: header, 8-byte sequence, then the saved handle
        if len(reply) < 22:
            return
        (h,) = struct.unpack('>I', command[10:14])
        (sh,) = struct.unpack('>I', reply[18:22])
        if h == save_load:
            save_load = sh
    elif ordinal == TPM_CC_ContextLoad:
        # This branch previously tested TPM_CC_ContextSave a second time and
        # could never run; the offsets it reads (saved handle in the command
        # context, loaded handle in the reply) are those of ContextLoad.
        if len(reply) < 14:
            return
        (sh, ) = struct.unpack('>I', command[18:22])
        (h, ) = struct.unpack('>I', reply[10:14])
        if (h & 0xff000000) == 0x80000000:
            if sh == save_load:
                save_load = h
def _recv_exact(sock, count):
    """Read exactly ``count`` bytes from ``sock``.

    A single recv() on a stream socket may return fewer bytes than asked
    for, so keep reading until the request is satisfied.  A result shorter
    than ``count`` means the peer closed the connection first.
    """
    data = bytearray()
    while len(data) < count:
        chunk = sock.recv(count - len(data))
        if not chunk:
            break
        data.extend(chunk)
    return bytes(data)


# Accept one client connection at a time; each connection carries a single
# command, which is either answered locally by preprocess() or relayed to
# the TPM on TPMPORT with its reply passed through postprocess().
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as r:
    r.bind((HOST, PORT))
    r.listen()
    while True:
        conn, addr = r.accept()
        with conn:
            #command: 4-byte emulator command, 1-byte locality, 4-byte size
            cmdbuf = _recv_exact(conn, 9)
            if len(cmdbuf) != 9:
                continue
            cmd, locality, size = struct.unpack('>IcI', cmdbuf)
            command = _recv_exact(conn, size)
            if len(command) != size:
                continue
            command = bytearray(command)
            reply = preprocess(command)
            command = bytes(command)
            # preprocess may have changed the command length
            cmdbuf = struct.pack('>IcI', cmd, locality, len(command))
            if reply is not None:
                if not isinstance(reply, bytes):
                    # (header, body) pair supplied by preprocess
                    header = reply[0]
                    reply = reply[1]
                else:
                    header = struct.pack('>HII', TPM_ST_NO_SESSIONS, 10 + len(reply), 0)
                reply = header + reply
                size = struct.pack('>I', len(reply))
                conn.sendall(size)
                conn.sendall(reply)
                ack = struct.pack('>I', 0)
                conn.sendall(ack)
                end = struct.pack('>I', TPM_SESSION_END)
                conn.sendall(end)
                continue
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect((HOST, TPMPORT))
                s.sendall(cmdbuf + command)
                #reply: 4-byte size, response, 4-byte acknowledgement
                replybuf = _recv_exact(s, 4)
                if len(replybuf) != 4:
                    continue
                size = struct.unpack('>I', replybuf)[0]
                reply = _recv_exact(s, size)
                if len(reply) != size:
                    continue
                reply = bytearray(reply)
                postprocess(command, reply)
                reply = bytes(reply)
                replybuf = struct.pack('>I', len(reply))
                conn.sendall(replybuf + reply)
                command = _recv_exact(s, 4)
                if len(command) != 4:
                    continue
                conn.sendall(command)
                cmdbuf = _recv_exact(conn, 4)
                # may send a TPM_SESSION_END or may simply close connection
                if len(cmdbuf) != 4:
                    continue
                s.sendall(cmdbuf)