Add session gapping and exhaustion tests, context save/load and RC_FMT1 normalization
diff --git a/tpm2.py b/tpm2.py
index ac420e2..3a2d8fa 100644
--- a/tpm2.py
+++ b/tpm2.py
@@ -16,6 +16,8 @@
TPM2_CC_CREATE = 0x0153
TPM2_CC_LOAD = 0x0157
TPM2_CC_UNSEAL = 0x015E
+TPM2_CC_CONTEXT_LOAD = 0x0161 # used by tss2.Client.context_load
+TPM2_CC_CONTEXT_SAVE = 0x0162 # used by tss2.Client.context_save
TPM2_CC_FLUSH_CONTEXT = 0x0165
TPM2_CC_READ_PUBLIC = 0x0173
TPM2_CC_START_AUTH_SESSION = 0x0176
@@ -57,8 +59,11 @@
TPM2_RH_LOCKOUT = 0x4000000A
TPM2_RS_PW = 0x40000009
+TPM2_RC_HANDLE = 0x0000008B # RC_FMT1 base codes carry no parameter/handle/session number; tss2.tpm_error masks those off
-TPM2_RC_SIZE = 0x000001D5
+TPM2_RC_SIZE = 0x00000095 # RC_FMT1 base code
+TPM2_RC_OBJECT_MEMORY = 0x00000902
TPM2_RC_SESSION_MEMORY = 0x00000903
+TPM2_RC_SESSION_HANDLES = 0x00000905
-TPM2_RC_AUTH_FAIL = 0x0000098E
-TPM2_RC_POLICY_FAIL = 0x0000099D
+TPM2_RC_AUTH_FAIL = 0x0000008E # RC_FMT1 base code
+TPM2_RC_POLICY_FAIL = 0x0000009D # RC_FMT1 base code
diff --git a/tpm2_sessions_smoke.py b/tpm2_sessions_smoke.py
index d233d08..4258471 100755
--- a/tpm2_sessions_smoke.py
+++ b/tpm2_sessions_smoke.py
@@ -19,18 +19,57 @@
self.c.close()
- # open handles until failure. Return the ones we got
- def open_handles(self):
+ # call func() until the TPM runs out of session or object memory
+ # (at most 10 times); other TPM errors propagate. Return the handles we got
+ def open_until(self, func):
ha = []
try:
for i in range(0, 10):
- h = self.c.start_session(tpm2.TPM2_SE_HMAC)
+ h = func()
print "Handle is %08x" % h
ha.append(h)
except tss2.tpm_error, e:
- if (e.rc != tpm2.TPM2_RC_SESSION_MEMORY):
+ if (e.rc != tpm2.TPM2_RC_SESSION_MEMORY and
+ e.rc != tpm2.TPM2_RC_OBJECT_MEMORY):
raise e
return ha
+ def open_handles(self):
+ def func():
+ return self.c.start_session(tpm2.TPM2_SE_HMAC, self.c.SRK)
+
+ return self.open_until(func)
+
+ def open_transients(self):
+ def func():
+ k = self.c.create_rsa(self.c.SRK, None)
+ return self.c.load(self.c.SRK, k.outPrivate, k.outPublic, None)
+
+ return self.open_until(func)
+
+ def test_handle_clearing(self):
+ t1 = self.open_transients()
+ h1 = self.open_handles()
+ print "Opened {} transients and {} handles".format(len(t1), len(h1))
+ self.c.close()
+ self.c = tss2.Client()
+ h2 = self.open_handles()
+ t2 = self.open_transients()
+ print "Opened {} transients and {} handles".format(len(t2), len(h2))
+ self.assertEqual(len(h1), len(h2))
+ self.assertEqual(len(t1), len(t2))
+
+ def test_transients(self):
+ k = self.open_transients()
+ self.c.flush_context(k[0])
+ self.c.change_auth(self.c.SRK, k[1], None, pwd1)
+ # k[0] has been flushed, so using it must fail
+ with self.assertRaises(tss2.tpm_error) as cm:
+ self.c.change_auth(self.c.SRK, k[0], None, pwd1)
+ print "Expected failure {}".format(cm.exception)
+ # flushing k[0] should have freed exactly one object slot
+ reopened = self.open_transients()
+ self.assertEqual(len(reopened), 1)
+
def test_handle_flush_on_space_close(self):
i = self.open_handles()
print "Ran out of handles at %d" %len(i)
@@ -78,9 +118,77 @@
self.c.change_auth(self.c.SRK, k, pwd1, pwd2, hmac, 0, enc, 0)
i = self.open_handles()
self.assertEqual(len(i), 2)
-
-
+ def close_clients(self, clients):
+ # addCleanup callback (cleanups run after tearDown): close every
+ # client except self.c, which tearDown has already closed
+ for cl in clients:
+ if cl is not self.c:
+ cl.close()
+
+ def test_space_exhaustion(self):
+ c = []
+ h = []
+ # tearDown only closes the final self.c; close the setUp client and
+ # the extra ones too so their sessions do not leak into later tests
+ self.addCleanup(self.close_clients, [self.c])
+ self.addCleanup(self.close_clients, c)
+ # usually 3 max unsaved and 64 max contexts, so 23*3 = 69 should
+ # mean that the first 69 - 64 = 5 are evicted
+ for i in range(0, 23):
+ self.c = tss2.Client()
+ c.append(self.c)
+ h.append(self.open_handles())
+
+ # try to use handle by creating an RSA key: this should fail
+ # because the session was evicted
+ self.c = c[0]
+ with self.assertRaises(tss2.tpm_error) as cm:
+ # hmac only
+ self.c.create_rsa(self.c.SRK, pwd1, h[0][0], 1)
+ print "Expected Session Failure: {}".format(cm.exception)
+ # pick the latest session and handle and it should succeed
+ self.c = c[22]
+ self.c.create_rsa(self.c.SRK, pwd1, h[22][0], 1, h[22][1], 1)
+
+ def test_disallow_save_context(self):
+ h = self.open_handles()
+ # context_save on a session handle is expected to be refused
+ with self.assertRaises(tss2.tpm_error):
+ self.c.context_save(h[0])
+
+ def test_gap_error_first(self):
+ c = tss2.Client()
+ # c is not self.c, so tearDown will not close it
+ self.addCleanup(c.close)
+ h = c.start_session(tpm2.TPM2_SE_HMAC)
+ print "Handle %08x" %h
+
+ for i in range(0,256):
+ j = self.c.start_session(tpm2.TPM2_SE_HMAC)
+ print "Flush Handle %08x" %j
+ self.c.flush_context(j)
+
+ c.flush_context(h)
+
+ def test_gap_error_last(self):
+ c = tss2.Client()
+ # c is not self.c, so tearDown will not close it
+ self.addCleanup(c.close)
+ h = c.start_session(tpm2.TPM2_SE_HMAC)
+ print "Handle %08x" %h
+ t = self.open_handles()
+ self.c.flush_context(t[-1])
+
+ for i in range(0,256):
+ j = self.c.start_session(tpm2.TPM2_SE_HMAC)
+ print "Flush Handle %08x" %j
+ self.c.flush_context(j)
+
+ c.flush_context(h)
+
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/tss2.py b/tss2.py
index 67f5a4f..10a3289 100644
--- a/tss2.py
+++ b/tss2.py
@@ -64,7 +64,16 @@
class TPM2B_PRIVATE(ctypes.Structure):
_fields_ = [("size", ctypes.c_uint16),
("buffer", ctypes.c_uint8 * 1166)]
-
+
+class TPM2B_CONTEXT_DATA(ctypes.Structure): # type of TPMS_CONTEXT.contextBlob
+ _fields_ = [("size", ctypes.c_uint16),
+ ("buffer", ctypes.c_uint8 * 2048)] # NOTE(review): must be no smaller than the TSS's own buffer for this type, or TSS_Execute can write past it; confirm against the TSS headers
+
+class TPMS_CONTEXT(ctypes.Structure): # output of Client.context_save; passed by pointer to the TSS, so field order/types must mirror its C struct exactly
+ _fields_ = [("sequence", ctypes.c_uint64),
+ ("savedHandle", ctypes.c_uint32),
+ ("hierarchy", ctypes.c_uint32),
+ ("contextBlob", TPM2B_CONTEXT_DATA)]
class StartAuthSession_In(ctypes.Structure):
_fields_ = [("tpmKey", ctypes.c_uint32),
@@ -125,9 +134,19 @@
class ObjectChangeAuth_Out(ctypes.Structure):
_fields_ = [("outPrivate", TPM2B_PRIVATE)]
+class ContextSave_In(ctypes.Structure):
+ _fields_ = [("saveHandle", ctypes.c_uint32)]
+
class tpm_error(Exception):
def __init__(self, rc):
+ # code exactly as returned; __str__ decodes this one
+ self.raw_rc = rc
+ # RC_FMT1 codes also carry a parameter/handle/session number (bit 6,
+ # bits 8-11); mask it off so self.rc is the bare base code. Codes with
+ # bits above 0xfff set are outside the TPM RC space and left unchanged.
+ if (rc & ~0xfff) == 0 and (rc & 0x80) == 0x80: # RC_FMT1
+ rc = rc & 0xbf
self.rc = rc
def __str__(self):
@@ -136,7 +151,7 @@
num = ctypes.c_char_p()
lib.TSS_ResponseCode_toString(ctypes.byref(msg),
ctypes.byref(submsg),
- ctypes.byref(num), self.rc);
+ ctypes.byref(num), self.raw_rc); # decode the code as returned, before the RC_FMT1 masking applied to self.rc
return "%s%s%s" % (msg.value, submsg.value, num.value)
class Client:
@@ -150,12 +165,8 @@
# we need the public area in the context for salted parameter encryption
self.read_public(self.SRK)
- print self.ctx
-
def close(self):
lib.TSS_Delete(self.ctx)
- print "closed"
- print self.ctx
def TSS_Execute(self, out, inp, extra, ordinal, *sessions):
rc = lib.TSS_Execute(self.ctx, out, inp,
@@ -206,7 +217,6 @@
inp.parentHandle = parent
if (auth != None):
lenauth = len(auth)
- print "AUTh len is %d" %lenauth
inp.inSensitive.sensitive.userAuth.b[0:lenauth] = bytearray(auth)
inp.inSensitive.sensitive.userAuth.s = lenauth
inp.inPublic.publicArea.Type = tpm2.TPM2_ALG_RSA
@@ -258,3 +268,28 @@
tpm2.TPM2_RH_NULL, None, 0)
return out.outPrivate
+
+ def context_save(self, handle):
+ """Run TPM2_ContextSave on handle and return the filled-in TPMS_CONTEXT."""
+ inp = ContextSave_In()
+ out = TPMS_CONTEXT()
+
+ inp.saveHandle = handle
+
+ self.TSS_Execute(ctypes.byref(out), ctypes.byref(inp), None,
+ tpm2.TPM2_CC_CONTEXT_SAVE,
+ tpm2.TPM2_RH_NULL, None, 0)
+
+ return out
+
+ def context_load(self, inp):
+ """Run TPM2_ContextLoad on inp, a TPMS_CONTEXT as returned by
+ context_save, and return the handle it was loaded at."""
+ # TPM2_ContextLoad's only output is the loaded handle, a single uint32
+ out = ctypes.c_uint32()
+
+ self.TSS_Execute(ctypes.byref(out), ctypes.byref(inp), None,
+ tpm2.TPM2_CC_CONTEXT_LOAD,
+ tpm2.TPM2_RH_NULL, None, 0)
+
+ return out.value