"""Tests for ctypes foreign functions that call COM methods.

Covers prototypes created with ``WINFUNCTYPE(...)(vtbl_index, name,
paramflags, iid)`` and the ``_ctypes.CopyComPointer`` helper, exercised
against the ShellLink COM class through its ``IPersist`` interface.
"""

import ctypes
import gc
import sys
import unittest
from ctypes import POINTER, byref, c_void_p
from ctypes.wintypes import BYTE, DWORD, WORD

if sys.platform != "win32":
    raise unittest.SkipTest("Windows-specific test")


# These names are only available on Windows builds, so they are imported
# after the platform guard above.
from _ctypes import COMError, CopyComPointer
from ctypes import HRESULT


COINIT_APARTMENTTHREADED = 0x2
CLSCTX_SERVER = 5  # CLSCTX_INPROC_SERVER (1) | CLSCTX_LOCAL_SERVER (4)
S_OK = 0
OUT = 2  # ctypes paramflags direction flag for an output parameter
TRUE = 1
E_NOINTERFACE = -2147467262  # 0x80004002 as a signed 32-bit HRESULT

# `oledll` functions return an HRESULT and raise OSError on failure codes.
# Bound before the helpers below for readability; they look it up at call
# time.
ole32 = ctypes.oledll.ole32


class GUID(ctypes.Structure):
    # https://learn.microsoft.com/en-us/windows/win32/api/guiddef/ns-guiddef-guid
    _fields_ = [
        ("Data1", DWORD),
        ("Data2", WORD),
        ("Data3", WORD),
        ("Data4", BYTE * 8),
    ]


def create_proto_com_method(name, index, restype, *argtypes):
    """Return a factory that builds a Python method calling a COM vtable slot.

    *index* is the slot in the interface's virtual function table, and
    *restype*/*argtypes* describe the method signature (excluding the
    interface pointer itself).  The returned ``make_method(*args)`` forwards
    *args* (optional ``paramflags`` and ``iid``) to the ``WINFUNCTYPE``
    prototype.  It yields a plain function that, when stored as a class
    attribute, passes the instance (the interface pointer) as the first
    argument of the foreign call.
    """
    proto = ctypes.WINFUNCTYPE(restype, *argtypes)

    def make_method(*args):
        foreign_func = proto(index, name, *args)

        def call(self, *args, **kwargs):
            return foreign_func(self, *args, **kwargs)

        return call

    return make_method


def create_guid(name):
    """Parse a ``{xxxxxxxx-...}`` string into a new :class:`GUID`."""
    guid = GUID()
    # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-clsidfromstring
    ole32.CLSIDFromString(name, byref(guid))
    return guid


def is_equal_guid(guid1, guid2):
    """Return the result of ``IsEqualGUID`` (1 when equal, 0 otherwise)."""
    # https://learn.microsoft.com/en-us/windows/win32/api/objbase/nf-objbase-isequalguid
    return ole32.IsEqualGUID(byref(guid1), byref(guid2))


IID_IUnknown = create_guid("{00000000-0000-0000-C000-000000000046}")
IID_IStream = create_guid("{0000000C-0000-0000-C000-000000000046}")
IID_IPersist = create_guid("{0000010C-0000-0000-C000-000000000046}")
CLSID_ShellLink = create_guid("{00021401-0000-0000-C000-000000000046}")

# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-queryinterface(refiid_void)
proto_query_interface = create_proto_com_method(
    "QueryInterface", 0, HRESULT, POINTER(GUID), POINTER(c_void_p)
)
# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-addref
proto_add_ref = create_proto_com_method("AddRef", 1, ctypes.c_long)
# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-release
proto_release = create_proto_com_method("Release", 2, ctypes.c_long)
# https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-ipersist-getclassid
# IPersist derives from IUnknown (slots 0-2), so GetClassID is slot 3.
proto_get_class_id = create_proto_com_method(
    "GetClassID", 3, HRESULT, POINTER(GUID)
)


def create_shelllink_persist(typ):
    """Create a ShellLink object and return its IPersist pointer as *typ*.

    *typ* is a ``c_void_p`` subclass that receives the interface pointer.
    The returned pointer holds one reference (tests release it to 0).
    """
    ppst = typ()
    # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
    ole32.CoCreateInstance(
        byref(CLSID_ShellLink),
        None,
        CLSCTX_SERVER,
        byref(IID_IPersist),
        byref(ppst),
    )
    return ppst


class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
    """Call COM methods through prototypes from ``create_proto_com_method``.

    Each test checks reference counts via the return values of
    ``AddRef``/``Release`` and the error type raised by a failing
    ``QueryInterface``.
    """

    def setUp(self):
        # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex
        ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)

    def tearDown(self):
        # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-couninitialize
        ole32.CoUninitialize()
        gc.collect()

    def test_without_paramflags_and_iid(self):
        # No paramflags: every argument is passed explicitly and the HRESULT
        # is returned.  No iid: a failing HRESULT raises OSError.
        class IUnknown(c_void_p):
            QueryInterface = proto_query_interface()
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            GetClassID = proto_get_class_id()

        ppst = create_shelllink_persist(IPersist)

        clsid = GUID()
        hr_getclsid = ppst.GetClassID(byref(clsid))
        self.assertEqual(S_OK, hr_getclsid)
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        self.assertEqual(2, ppst.AddRef())
        self.assertEqual(3, ppst.AddRef())

        punk = IUnknown()
        # A successful QueryInterface adds a reference (3 -> 4).
        hr_qi = ppst.QueryInterface(IID_IUnknown, punk)
        self.assertEqual(S_OK, hr_qi)
        self.assertEqual(3, punk.Release())

        with self.assertRaises(OSError) as e:
            punk.QueryInterface(IID_IStream, IUnknown())
        self.assertEqual(E_NOINTERFACE, e.exception.winerror)

        self.assertEqual(2, ppst.Release())
        self.assertEqual(1, ppst.Release())
        self.assertEqual(0, ppst.Release())

    def test_with_paramflags_and_without_iid(self):
        # With an OUT paramflag, GetClassID returns the GUID itself.
        # Without an iid, a failing HRESULT still raises OSError.
        class IUnknown(c_void_p):
            QueryInterface = proto_query_interface(None)
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            GetClassID = proto_get_class_id(((OUT, "pClassID"),))

        ppst = create_shelllink_persist(IPersist)

        clsid = ppst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        punk = IUnknown()
        hr_qi = ppst.QueryInterface(IID_IUnknown, punk)
        self.assertEqual(S_OK, hr_qi)
        self.assertEqual(1, punk.Release())

        with self.assertRaises(OSError) as e:
            ppst.QueryInterface(IID_IStream, IUnknown())
        self.assertEqual(E_NOINTERFACE, e.exception.winerror)

        self.assertEqual(0, ppst.Release())

    def test_with_paramflags_and_iid(self):
        # Passing an iid switches failure reporting from OSError to COMError.
        class IUnknown(c_void_p):
            QueryInterface = proto_query_interface(None, IID_IUnknown)
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)

        ppst = create_shelllink_persist(IPersist)

        clsid = ppst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        punk = IUnknown()
        hr_qi = ppst.QueryInterface(IID_IUnknown, punk)
        self.assertEqual(S_OK, hr_qi)
        self.assertEqual(1, punk.Release())

        with self.assertRaises(COMError) as e:
            ppst.QueryInterface(IID_IStream, IUnknown())
        self.assertEqual(E_NOINTERFACE, e.exception.hresult)

        self.assertEqual(0, ppst.Release())


class CopyComPointerTests(unittest.TestCase):
    """Reference-counting behavior of ``_ctypes.CopyComPointer(src, pdst)``.

    As the assertions below show, a non-NULL ``src`` is ``AddRef``-ed and
    stored in ``*pdst``.  The pointer previously held in ``*pdst`` is not
    released.
    """

    def setUp(self):
        ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)

        class IUnknown(c_void_p):
            QueryInterface = proto_query_interface(None, IID_IUnknown)
            AddRef = proto_add_ref()
            Release = proto_release()

        class IPersist(IUnknown):
            GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)

        self.IUnknown = IUnknown
        self.IPersist = IPersist

    def tearDown(self):
        ole32.CoUninitialize()
        gc.collect()

    def test_both_are_null(self):
        src = self.IPersist()
        dst = self.IPersist()

        hr = CopyComPointer(src, byref(dst))

        self.assertEqual(S_OK, hr)

        self.assertIsNone(src.value)
        self.assertIsNone(dst.value)

    def test_src_is_nonnull_and_dest_is_null(self):
        # The reference count of the COM pointer created by `CoCreateInstance`
        # is initially 1.
        src = create_shelllink_persist(self.IPersist)
        dst = self.IPersist()

        # `CopyComPointer` calls `AddRef` explicitly in the C implementation.
        # The refcount of `src` is incremented from 1 to 2 here.
        hr = CopyComPointer(src, byref(dst))

        self.assertEqual(S_OK, hr)
        self.assertEqual(src.value, dst.value)

        # This indicates that the refcount was 2 before the `Release` call.
        self.assertEqual(1, src.Release())

        clsid = dst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        self.assertEqual(0, dst.Release())

    def test_src_is_null_and_dest_is_nonnull(self):
        src = self.IPersist()
        dst_orig = create_shelllink_persist(self.IPersist)
        dst = self.IPersist()
        CopyComPointer(dst_orig, byref(dst))
        self.assertEqual(1, dst_orig.Release())

        clsid = dst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        # This does NOT affect the refcount of `dst_orig`.
        hr = CopyComPointer(src, byref(dst))

        self.assertEqual(S_OK, hr)
        self.assertIsNone(dst.value)

        with self.assertRaises(ValueError):
            dst.GetClassID()  # NULL COM pointer access

        # This indicates that the refcount was 1 before the `Release` call.
        self.assertEqual(0, dst_orig.Release())

    def test_both_are_nonnull(self):
        src = create_shelllink_persist(self.IPersist)
        dst_orig = create_shelllink_persist(self.IPersist)
        dst = self.IPersist()
        CopyComPointer(dst_orig, byref(dst))
        self.assertEqual(1, dst_orig.Release())

        self.assertEqual(dst.value, dst_orig.value)
        self.assertNotEqual(src.value, dst.value)

        # Overwrites `dst` without releasing the pointer it held (`dst_orig`).
        hr = CopyComPointer(src, byref(dst))

        self.assertEqual(S_OK, hr)
        self.assertEqual(src.value, dst.value)
        self.assertNotEqual(dst.value, dst_orig.value)

        self.assertEqual(1, src.Release())

        clsid = dst.GetClassID()
        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

        self.assertEqual(0, dst.Release())
        # `dst_orig` still held exactly one reference after being overwritten.
        self.assertEqual(0, dst_orig.Release())


if __name__ == '__main__':
    unittest.main()