# blob: 425847112314ae50b19cef32c894435c7a35e824 [file] [log] [blame]
from argparse import ArgumentParser
from argparse import FileType
import os
import sys
import tpm2
from tpm2 import ProtocolError
import unittest
import logging
import tss2
# Auth values shared by the tests below: pwd1 is the initial key password,
# pwd2 the value it is changed to in test_session_consumption.
pwd1, pwd2 = "wibble", "newpassword"
class SessionTest(unittest.TestCase):
    """Session and transient-object handle tests driven through tss2.Client.

    Every test starts with a freshly opened client in ``self.c`` (see
    setUp); tearDown closes whichever client is in ``self.c`` when the
    test ends.  Several tests replace ``self.c`` part-way through.
    """

    def setUp(self):
        """Open the client used by the test."""
        self.c = tss2.Client()

    def tearDown(self):
        """Close the client currently held in ``self.c``."""
        self.c.close()

    def open_until(self, func, limit=10):
        """Open handles until failure and return the ones we got.

        func  -- zero-argument callable returning one new handle per call.
        limit -- maximum number of calls to make (default 10, the value
                 that used to be hard-coded).

        A tss2.tpm_error whose rc is tpm2.TPM2_RC_SESSION_MEMORY or
        tpm2.TPM2_RC_OBJECT_MEMORY ends the loop quietly; a tpm_error with
        any other rc is re-raised.
        """
        handles = []
        try:
            for _ in range(limit):
                handle = func()
                print("Handle is %08x" % handle)
                handles.append(handle)
        except tss2.tpm_error as e:
            if e.rc not in (tpm2.TPM2_RC_SESSION_MEMORY,
                            tpm2.TPM2_RC_OBJECT_MEMORY):
                # bare raise so the original traceback is preserved
                raise
        return handles

    def open_handles(self):
        """Start HMAC sessions bound to the SRK until no more can be opened."""
        def func():
            return self.c.start_session(tpm2.TPM2_SE_HMAC, self.c.SRK)
        return self.open_until(func)

    def open_transients(self):
        """Create and load RSA keys under the SRK until no more can be loaded."""
        def func():
            k = self.c.create_rsa(self.c.SRK, None)
            return self.c.load(self.c.SRK, k.outPrivate, k.outPublic, None)
        return self.open_until(func)

    # Fill the client with transients and sessions, reopen the client and
    # check that the same number of each can be opened again.
    def test_handle_clearing(self):
        t1 = self.open_transients()
        h1 = self.open_handles()
        print("Opened {} transients and {} handles".format(len(t1), len(h1)))
        self.c.close()
        self.c = tss2.Client()
        h2 = self.open_handles()
        t2 = self.open_transients()
        print("Opened {} transients and {} handles".format(len(t2), len(h2)))
        self.assertEqual(len(h1), len(h2))
        self.assertEqual(len(t1), len(t2))

    # A flushed transient must no longer be usable, and flushing it must
    # make room for exactly one new transient.
    def test_transients(self):
        k = self.open_transients()
        self.c.flush_context(k[0])
        # k[1] was not flushed, so this must still work
        self.c.change_auth(self.c.SRK, k[1], None, pwd1)
        with self.assertRaises(tss2.tpm_error) as cm:
            self.c.change_auth(self.c.SRK, k[0], None, pwd1)
        print("Expected failure {}".format(cm.exception))
        reopened = self.open_transients()
        self.assertEqual(len(reopened), 1)

    def test_handle_flush_on_space_close(self):
        i = self.open_handles()
        print("Ran out of handles at %d" % len(i))
        self.c.close()
        self.c = tss2.Client()
        # closing and reopening a space session should clear out our handles
        j = self.open_handles()
        print("Ran out of handles at %d" % len(j))
        self.assertNotEqual(len(i), 0)
        self.assertEqual(len(i), len(j))

    # Flushing two sessions must make room for exactly two new ones.
    def test_flush(self):
        i = self.open_handles()
        print("opened %d handles" % len(i))
        self.c.flush_context(i[0])
        self.c.flush_context(i[1])
        i = self.open_handles()
        self.assertEqual(len(i), 2)

    def test_session_consumption(self):
        self.c.read_public(self.c.SRK)
        # authorization hmac session
        hmac = self.c.start_session(tpm2.TPM2_SE_HMAC)
        # parameter encryption session
        enc = self.c.start_session(tpm2.TPM2_SE_HMAC, self.c.SRK)
        # fill all remaining handles (the handles themselves are not needed)
        self.open_handles()
        # create rsa key continuing both hmac and encryption sessions
        self.c.create_rsa(self.c.SRK, pwd1, hmac, 1, enc, 1)
        # should be no handles left
        i = self.open_handles()
        self.assertEqual(len(i), 0)
        # now create rsa key continuing hmac and consuming encryption
        k = self.c.create_rsa(self.c.SRK, pwd1, hmac, 1, enc, 0)
        # now should be one handle remaining
        i = self.open_handles()
        self.assertEqual(len(i), 1)
        self.c.flush_context(i[0])
        # check the hmac continuation actually works
        k = self.c.load(self.c.SRK, k.outPrivate, k.outPublic, None)
        print("Loaded key at handle %x" % k)
        # and finally verify with an authenticated encrypted operation
        # consuming both handles
        enc = self.c.start_session(tpm2.TPM2_SE_HMAC, k)
        self.c.change_auth(self.c.SRK, k, pwd1, pwd2, hmac, 0, enc, 0)
        i = self.open_handles()
        self.assertEqual(len(i), 2)

    def test_space_exhaustion(self):
        first = self.c  # client opened by setUp; replaced below
        clients = []
        handles = []
        try:
            # usually 3 max unsaved and 64 max contexts, so 23*3 = 69 should
            # mean that the first 4 are evicted
            for _ in range(23):
                self.c = tss2.Client()
                clients.append(self.c)
                handles.append(self.open_handles())
            # try to use handle by creating an RSA key this should fail
            # because the session was evicted
            self.c = clients[0]
            with self.assertRaises(tss2.tpm_error) as cm:
                # hmac only
                self.c.create_rsa(self.c.SRK, pwd1, handles[0][0], 1)
            print("Expected Session Failure: {}".format(cm.exception))
            # pick the latest session and handle and it should succeed
            self.c = clients[22]
            self.c.create_rsa(self.c.SRK, pwd1,
                              handles[22][0], 1, handles[22][1], 1)
        finally:
            # Close every client this test touched except the one left in
            # self.c, which tearDown closes.
            for client in [first] + clients:
                if client is not self.c:
                    client.close()

    # Saving the context of a session handle through the client must fail.
    def test_disallow_save_context(self):
        h = self.open_handles()
        with self.assertRaises(tss2.tpm_error):
            self.c.context_save(h[0])

    # Hold one session open in a second client while 256 sessions are
    # started and flushed in self.c, then flush the held session.
    # NOTE(review): presumably this targets the TPM's context gap
    # handling with the oldest session outstanding -- confirm.
    def test_gap_error_first(self):
        c = tss2.Client()
        try:
            h = c.start_session(tpm2.TPM2_SE_HMAC)
            print("Handle %08x" % h)
            for _ in range(256):
                j = self.c.start_session(tpm2.TPM2_SE_HMAC)
                print("Flush Handle %08x" % j)
                self.c.flush_context(j)
            c.flush_context(h)
        finally:
            c.close()

    # Same as test_gap_error_first, but self.c is first filled with
    # sessions and only its last one is flushed, so the 256 start/flush
    # cycles run with the other sessions still open.
    def test_gap_error_last(self):
        c = tss2.Client()
        try:
            h = c.start_session(tpm2.TPM2_SE_HMAC)
            print("Handle %08x" % h)
            t = self.open_handles()
            self.c.flush_context(t[-1])
            for _ in range(256):
                j = self.c.start_session(tpm2.TPM2_SE_HMAC)
                print("Flush Handle %08x" % j)
                self.c.flush_context(j)
            c.flush_context(h)
        finally:
            c.close()
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()