EternalWindows
LSP / LSPサンプル(複数エントリ)

2つ以上のエントリの上にLSPをインストールしている場合、 LSPの実装は前節のコードよりも複雑になります。 たとえば、TCPとUDPという2つのベースプロトコル上にLSPをインストールした場合、 LSPはTCPとUDPのどちらの要求も検出することになりますが、 これにより実際にどちらの要求が送られたのかを判断する処理が必要となります。 問題を分かりやすくするため、次のようなコードを実行するWinsockアプリケーションを例に挙げます。

WSAStartup(MAKEWORD(2, 2), &wsaData);

soc = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); // WSPStartupとWSPSocketが呼ばれる

soc2 = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); // WSPSocketが呼ばれる

Winsockアプリケーションは、最初にWSAStartupを呼び出すことになります。 この関数の呼び出しによって、LSPの何らかの関数が呼ばれることはありません。 次に、Winsockアプリケーションは、socketを呼び出しますが、 これによりLSPが呼び出し側プロセスにロードされます。 具体的には、システムに存在するエントリのWSAPROTOCOL_INFOWのメンバがsocketの引数と一致した場合、 そのエントリのDLLがロードされます。 そして、DLLがエクスポートしているWSPStartupが呼ばれ、 ここで0を返した場合にWSPSocketが呼ばれ、 WSPSocketが制御を返した場合にsocketも制御を返すことになります。 ただし、2回目のsocketの呼び出しでは、WSPStartupが呼ばれるとは限りません。 上記コードのsocketは、1回目の呼び出しと引数が同一であり、 この場合は指定したプロトコルの初期化は既に完了しているものと解釈されるため、 WSPSocketのみが呼ばれることになります。

それでは、2回目のsocket呼び出しを1回目と異なる引数で実行した場合は、どうなるのでしょうか。 答えは、WSPStartupとWSPSocketが呼ばれることになります。

WSAStartup(MAKEWORD(2, 2), &wsaData);

soc = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); // WSPStartupとWSPSocketが呼ばれる

soc2 = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); // WSPStartupとWSPSocketが呼ばれる

このように、プロトコルが異なるのにも関わらず同じ関数が呼ばれる場合、 LSP内では現在呼び出されている関数がどのプロトコルによるものかを判断する仕組みが必要になります。 理由は非常に単純で、あるプロトコルの処理を行う場合に、もう片方のプロトコルのデータを使用するわけにはいかないからです。 LSP内で扱われるデータは、それがプロトコル固有のものなのか、あるいは共通して使用できるものかを適切に見極め、 プロトコル固有のデータは構造体などで管理しておくのがよいと思われます。

今回のLSPは、複数のエントリの上にインストールされている場合でも正常に動作するように実装されています。 defファイルについては、前節と同様です。

#include <ws2spi.h>
#include <shlobj.h>

#pragma comment (lib, "ws2_32.lib")

GUID g_guidProvider = {0x024801fd, 0x3797, 0x4dd0, {0x92, 0x38, 0x7c, 0xbe, 0x1e, 0xa7, 0xe9, 0x3d}};

struct SOCKETCONTEXT {
	SOCKET        soc;
	SOCKETCONTEXT *lpNext;
};
typedef struct SOCKETCONTEXT SOCKETCONTEXT;
typedef struct SOCKETCONTEXT *LPSOCKETCONTEXT;

struct LSPCONTEXT {
	WSAPROTOCOL_INFOW entryInfo;
	WSAPROTOCOL_INFOW nextEntryInfo;
	WSPDATA           wspData;
	WSPPROC_TABLE     procTable;
	HMODULE           hmod;
	LPSOCKETCONTEXT   lpSockHeader;
};
typedef struct LSPCONTEXT LSPCONTEXT;
typedef struct LSPCONTEXT *LPLSPCONTEXT;

WSPUPCALLTABLE g_upcallTable = {0};
LPLSPCONTEXT   g_lpLspContext = NULL;
HANDLE         g_hheap = NULL;
int            g_nContextCount = 0;
int            g_nStartupCount = 0;

SOCKET WSPAPI WSPSocket(int af, int type, int protocol, LPWSAPROTOCOL_INFOW lpProtocolInfo, GROUP g, DWORD dwFlags, LPINT lpErrno);
int WSPAPI WSPConnect(SOCKET s, const struct sockaddr *name, int namelen, LPWSABUF lpCallerData, LPWSABUF lpCalleeData, LPQOS lpSQOS, LPQOS lpGQOS, LPINT lpErrno);
SOCKET WSPAPI WSPAccept(SOCKET s, struct sockaddr *addr, LPINT addrlen, LPCONDITIONPROC lpfnCondition, DWORD dwCallbackData, LPINT lpErrno);
int WSPAPI WSPSend(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount, LPDWORD lpNumberOfBytesSent, DWORD dwFlags, LPWSAOVERLAPPED lpOverlapped, LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine, LPWSATHREADID lpThreadId, LPINT lpErrno);
int WSPAPI WSPRecv(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount, LPDWORD lpNumberOfBytesRecvd, LPDWORD lpFlags, LPWSAOVERLAPPED lpOverlapped, LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine, LPWSATHREADID lpThreadId, LPINT lpErrno);
int WSPAPI WSPCleanup(LPINT lpErrno);

BOOL CreateLspContext(GUID guidProvider);
LPLSPCONTEXT GetLspContextFromSocket(SOCKET soc);
LPLSPCONTEXT GetLspContextFromId(DWORD dwEntryId);
void WriteLogFile(LPTSTR lpszData);

int WSPAPI WSPStartup(WORD wVersionRequested, LPWSPDATA lpWSPData, LPWSAPROTOCOL_INFOW lpProtocolInfo, WSPUPCALLTABLE UpcallTable, LPWSPPROC_TABLE lpProcTable)
{
	int          nError;
	WCHAR        szDllPath[256];
	WCHAR        szDllPathEnv[256];
	DWORD        dwSize;
	WSPDATA      wspData;
	LPWSPSTARTUP lpfnWSPStartup;
	LPLSPCONTEXT lpLspContext;

	WriteLogFile(TEXT("WSAStartup"));

	if (g_nStartupCount == 0) {
		g_hheap = HeapCreate(0, 4096, 0);
		if (g_hheap == NULL)
			return WSAEPROVIDERFAILEDINIT;
		
		if (!CreateLspContext(g_guidProvider))
			return WSAEPROVIDERFAILEDINIT;
		
		CopyMemory(&g_upcallTable, &UpcallTable, sizeof(WSPUPCALLTABLE));
	}
	
	g_nStartupCount++;
	
	lpLspContext = GetLspContextFromId(lpProtocolInfo->dwCatalogEntryId);
	if (lpLspContext == NULL)
		return WSAEPROVIDERFAILEDINIT;

	dwSize = sizeof(szDllPathEnv) / sizeof(WCHAR);
	WSCGetProviderPath(&lpLspContext->nextEntryInfo.ProviderId, szDllPathEnv, (LPINT)&dwSize, &nError);

	dwSize = sizeof(szDllPath) / sizeof(WCHAR);
	ExpandEnvironmentStringsW(szDllPathEnv, szDllPath, dwSize);

	lpLspContext->hmod = LoadLibraryW(szDllPath);
	if (lpLspContext->hmod == NULL)
		return WSAEPROVIDERFAILEDINIT;

	lpfnWSPStartup = (LPWSPSTARTUP)GetProcAddress(lpLspContext->hmod, "WSPStartup");
	if (lpfnWSPStartup == NULL) {
		FreeLibrary(lpLspContext->hmod);
		return WSAEPROVIDERFAILEDINIT;
	}

	if (lpfnWSPStartup(wVersionRequested, &wspData, &lpLspContext->nextEntryInfo, UpcallTable, &lpLspContext->procTable) != 0) {
		FreeLibrary(lpLspContext->hmod);
		return WSAEPROVIDERFAILEDINIT;
	}

	CopyMemory(lpProcTable, &lpLspContext->procTable, sizeof(WSPPROC_TABLE));
	lpProcTable->lpWSPAccept  = WSPAccept;
	lpProcTable->lpWSPCleanup = WSPCleanup;
	lpProcTable->lpWSPConnect = WSPConnect;
	lpProcTable->lpWSPRecv    = WSPRecv;
	lpProcTable->lpWSPSend    = WSPSend;
	lpProcTable->lpWSPSocket  = WSPSocket;
	
	CopyMemory(lpWSPData, &wspData, sizeof(WSPDATA));

	return NO_ERROR;
}

SOCKET WSPAPI WSPSocket(int af, int type, int protocol, LPWSAPROTOCOL_INFOW lpProtocolInfo, GROUP g, DWORD dwFlags, LPINT lpErrno)
{
	SOCKET          soc;
	SOCKET          socModify;
	LPLSPCONTEXT    lpLspContext;
	LPSOCKETCONTEXT lp;

	WriteLogFile(TEXT("socket"));
		
	lpLspContext = GetLspContextFromId(lpProtocolInfo->dwCatalogEntryId);
	if (lpLspContext == NULL)
		return INVALID_SOCKET;

	soc = lpLspContext->procTable.lpWSPSocket(af, type, protocol, lpProtocolInfo, g, dwFlags, lpErrno);
	if (soc != INVALID_SOCKET) {
		socModify = g_upcallTable.lpWPUModifyIFSHandle(lpProtocolInfo->dwCatalogEntryId, soc, lpErrno);
		if (soc != socModify)
			soc = INVALID_SOCKET;
	}

	if (soc != INVALID_SOCKET) {
		if (lpLspContext->lpSockHeader == NULL) {
			lpLspContext->lpSockHeader = (LPSOCKETCONTEXT)HeapAlloc(g_hheap, HEAP_ZERO_MEMORY, sizeof(SOCKETCONTEXT));
			lpLspContext->lpSockHeader->soc = soc;
		}
		else {
			lp = lpLspContext->lpSockHeader;
			for (; lp->lpNext != NULL; )
				lp = lp->lpNext;
			lp->lpNext = (LPSOCKETCONTEXT)HeapAlloc(g_hheap, HEAP_ZERO_MEMORY, sizeof(SOCKETCONTEXT));
			lp->lpNext->soc = soc;
		}
	}

	return soc;
}

int WSPAPI WSPConnect(SOCKET s, const struct sockaddr *name, int namelen, LPWSABUF lpCallerData, LPWSABUF lpCalleeData, LPQOS lpSQOS, LPQOS lpGQOS, LPINT lpErrno)
{
	LPLSPCONTEXT lpLspContext;

	WriteLogFile(TEXT("connect"));

	lpLspContext = GetLspContextFromSocket(s);
	if (lpLspContext == NULL) {
		*lpErrno = WSAENOTSOCK;
		return SOCKET_ERROR;
	}

	return lpLspContext->procTable.lpWSPConnect(s, name, namelen, lpCallerData, lpCalleeData, lpSQOS, lpGQOS, lpErrno);
}

SOCKET WSPAPI WSPAccept(SOCKET s, struct sockaddr *addr, LPINT addrlen, LPCONDITIONPROC lpfnCondition, DWORD dwCallbackData, LPINT lpErrno)
{
	SOCKET          socServer;
	SOCKET          socModify;
	LPLSPCONTEXT    lpLspContext;
	LPSOCKETCONTEXT lp;

	WriteLogFile(TEXT("accept"));

	lpLspContext = GetLspContextFromSocket(s);
	if (lpLspContext == NULL) {
		*lpErrno = WSAENOTSOCK;
		return INVALID_SOCKET;
	}

	socServer = lpLspContext->procTable.lpWSPAccept(s, addr, addrlen, lpfnCondition, dwCallbackData, lpErrno);
	if (socServer != INVALID_SOCKET) {
		socModify = g_upcallTable.lpWPUModifyIFSHandle(lpLspContext->entryInfo.dwCatalogEntryId, socServer, lpErrno);
		if (socServer != socModify)
			socServer = INVALID_SOCKET;
	}

	if (socServer != INVALID_SOCKET) {
		lp = lpLspContext->lpSockHeader;
		for (; lp->lpNext != NULL;)
			lp = lp->lpNext;
		lp->lpNext = (LPSOCKETCONTEXT)HeapAlloc(g_hheap, HEAP_ZERO_MEMORY, sizeof(SOCKETCONTEXT));
		lp->lpNext->soc = socServer;
	}

	return socServer;
}

int WSPAPI WSPSend(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount, LPDWORD lpNumberOfBytesSent, DWORD dwFlags, LPWSAOVERLAPPED lpOverlapped, LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine, LPWSATHREADID lpThreadId, LPINT lpErrno)
{
	LPLSPCONTEXT lpLspContext;

	WriteLogFile(TEXT("send"));

	lpLspContext = GetLspContextFromSocket(s);
	if (lpLspContext == NULL) {
		*lpErrno = WSAENOTSOCK;
		return SOCKET_ERROR;
	}

	return lpLspContext->procTable.lpWSPSend(s, lpBuffers, dwBufferCount, lpNumberOfBytesSent, dwFlags, lpOverlapped, lpCompletionRoutine, lpThreadId, lpErrno);
}

int WSPAPI WSPRecv(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount, LPDWORD lpNumberOfBytesRecvd, LPDWORD lpFlags, LPWSAOVERLAPPED lpOverlapped, LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine, LPWSATHREADID lpThreadId, LPINT lpErrno)
{
	LPLSPCONTEXT lpLspContext;

	WriteLogFile(TEXT("recv"));

	lpLspContext = GetLspContextFromSocket(s);
	if (lpLspContext == NULL) {
		*lpErrno = WSAENOTSOCK;
		return SOCKET_ERROR;
	}

	return lpLspContext->procTable.lpWSPRecv(s, lpBuffers, dwBufferCount, lpNumberOfBytesRecvd, lpFlags, lpOverlapped, lpCompletionRoutine, lpThreadId, lpErrno);
}

int WSPAPI WSPCleanup(LPINT lpErrno)
{
	int i;

	if (--g_nStartupCount == 0) {
		for (i = 0; i < g_nContextCount; i++) {
			if (g_lpLspContext[i].hmod != NULL) {
				g_lpLspContext[i].procTable.lpWSPCleanup(lpErrno);
				FreeLibrary(g_lpLspContext[i].hmod);
				WriteLogFile(TEXT("WSACleanup"));
			}
		}
		HeapDestroy(g_hheap);
	}
	
	return 0;
}

void WSPAPI GetLspGuid(LPGUID lpGuid)
{
	CopyMemory(lpGuid, &g_guidProvider, sizeof(GUID));
}

BOOL CreateLspContext(GUID guidProvider)
{
	int                 i, j;
	int                 nError;
	int                 nTotalEntryCount;
	DWORD               dwDummyEntryId = 0;
	DWORD               dwSize;
	LPWSAPROTOCOL_INFOW lpEntryList;

	WSCEnumProtocols(NULL, NULL, &dwSize, &nError);
	lpEntryList = (LPWSAPROTOCOL_INFOW)HeapAlloc(GetProcessHeap(), 0, dwSize);
	nTotalEntryCount = WSCEnumProtocols(NULL, lpEntryList, &dwSize, &nError);

	for (i = 0; i < nTotalEntryCount; i++) {
		if (IsEqualGUID(lpEntryList[i].ProviderId, guidProvider)) {
			dwDummyEntryId = lpEntryList[i].dwCatalogEntryId;
			break;
		}
	}

	if (dwDummyEntryId == 0) {
		HeapFree(GetProcessHeap(), 0, lpEntryList);
		return FALSE;
	}
	
	for (i = 0; i < nTotalEntryCount; i++) {
		if (lpEntryList[i].ProtocolChain.ChainLen > 1 && lpEntryList[i].ProtocolChain.ChainEntries[0] == dwDummyEntryId)
			g_nContextCount++;
	}
	
	if (g_nContextCount == 0) {
		HeapFree(GetProcessHeap(), 0, lpEntryList);
		return FALSE;
	}
	
	g_lpLspContext = (LPLSPCONTEXT)HeapAlloc(g_hheap, HEAP_ZERO_MEMORY, sizeof(LSPCONTEXT) * g_nContextCount);

	for (i = 0, j = 0; i < nTotalEntryCount; i++) {
		if (lpEntryList[i].ProtocolChain.ChainLen > 1 && lpEntryList[i].ProtocolChain.ChainEntries[0] == dwDummyEntryId) {
			CopyMemory(&g_lpLspContext[j].entryInfo, &lpEntryList[i], sizeof(WSAPROTOCOL_INFOW));
			j++;
		}
	}
	
	for (i = 0; i < g_nContextCount; i++) {
		for (j = 0; j < nTotalEntryCount; j++) {
			if (lpEntryList[j].dwCatalogEntryId == g_lpLspContext[i].entryInfo.ProtocolChain.ChainEntries[1])
				CopyMemory(&g_lpLspContext[i].nextEntryInfo, &lpEntryList[j], sizeof(WSAPROTOCOL_INFOW));
		}
	}

	HeapFree(GetProcessHeap(), 0, lpEntryList);

	return TRUE;
}

LPLSPCONTEXT GetLspContextFromId(DWORD dwEntryId)
{
	int i;

	for (i = 0; i < g_nContextCount; i++) {
		if (g_lpLspContext[i].entryInfo.dwCatalogEntryId == dwEntryId)
			return &g_lpLspContext[i];
	}

	return NULL;
}

LPLSPCONTEXT GetLspContextFromSocket(SOCKET soc)
{
	int             i;
	LPSOCKETCONTEXT lp;

	for (i = 0; i < g_nContextCount; i++) {
		lp = g_lpLspContext[i].lpSockHeader;
		for (; lp != NULL; ) {
			if (lp->soc == soc)
				return &g_lpLspContext[i];
			lp = lp->lpNext;
		}
	}
	
	return NULL;
}

void WriteLogFile(LPTSTR lpszData)
{
	TCHAR  szFileName[] = TEXT("\\lsplog.txt");
	TCHAR  szModulePath[MAX_PATH];
	TCHAR  szDesktopPath[MAX_PATH];
	TCHAR  szBuf[1024];
	HANDLE hFile;
	DWORD  dwResult;	
	
	SHGetSpecialFolderPath(NULL, szDesktopPath, CSIDL_DESKTOPDIRECTORY, FALSE);
	lstrcat(szDesktopPath, szFileName);

	hFile = CreateFile(szDesktopPath, GENERIC_READ | GENERIC_WRITE, 0, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
	if (hFile == INVALID_HANDLE_VALUE)
		return;
	
	SetFilePointer(hFile, 0, NULL, FILE_END);
	
	GetModuleFileName(GetModuleHandle(NULL), szModulePath, sizeof(szModulePath) / sizeof(TCHAR));
	wsprintf(szBuf, TEXT("%s : %s\r\n"), szModulePath, lpszData);
	WriteFile(hFile, szBuf, lstrlen(szBuf) * sizeof(TCHAR), &dwResult, NULL);

	CloseHandle(hFile);
}

LSPCONTEXT構造体は、1つのプロトコルを処理する場合に必要となるメンバが格納されています。 言い換えれば、このメンバに格納される値は各プロトコルによって異なります。 entryInfoは、この構造体がどのプロトコルを表すかを格納し、 nextEntryInfoは次のエントリのプロトコル情報を格納します。 lpSockHeaderは、このプロトコルが使用する一連のソケットの先頭を指すポインタであり、 lpNextを通じて全てのソケットにアクセスできるようになります。 この結果、ソケットからLSPCONTEXT構造体を取得することが可能になります。 WSPStartupでは最初に次の処理が行われています。

if (g_nStartupCount == 0) {
	g_hheap = HeapCreate(0, 4096, 0);
	if (g_hheap == NULL)
		return WSAEPROVIDERFAILEDINIT;
	
	if (!CreateLspContext(g_guidProvider))
		return WSAEPROVIDERFAILEDINIT;
	
	CopyMemory(&g_upcallTable, &UpcallTable, sizeof(WSPUPCALLTABLE));
}

g_nStartupCountは、WSPStartupが呼ばれた回数を表しており、 これが0である場合はまだWSPStartupが呼ばれていないことを意味します。 この場合、各プロトコル共通で使用するデータを初期化することになります。 まず、独自のヒープを作成するためにHeapCreateを呼び出します。 これは、HeapDestroyによって、これまで確保したメモリをまとめて開放するようにしたいからです。 CreateLspContextは、グローバルに定義されたg_lpLspContextとg_lpLspContextを初期化します。

BOOL CreateLspContext(GUID guidProvider)
{
	int                 i, j;
	int                 nError;
	int                 nTotalEntryCount;
	DWORD               dwDummyEntryId = 0;
	DWORD               dwSize;
	LPWSAPROTOCOL_INFOW lpEntryList;

	WSCEnumProtocols(NULL, NULL, &dwSize, &nError);
	lpEntryList = (LPWSAPROTOCOL_INFOW)HeapAlloc(GetProcessHeap(), 0, dwSize);
	nTotalEntryCount = WSCEnumProtocols(NULL, lpEntryList, &dwSize, &nError);

	for (i = 0; i < nTotalEntryCount; i++) {
		if (IsEqualGUID(lpEntryList[i].ProviderId, guidProvider)) {
			dwDummyEntryId = lpEntryList[i].dwCatalogEntryId;
			break;
		}
	}

	if (dwDummyEntryId == 0) {
		HeapFree(GetProcessHeap(), 0, lpEntryList);
		return FALSE;
	}
	
	for (i = 0; i < nTotalEntryCount; i++) {
		if (lpEntryList[i].ProtocolChain.ChainLen > 1 && lpEntryList[i].ProtocolChain.ChainEntries[0] == dwDummyEntryId)
			g_nContextCount++;
	}
	
	if (g_nContextCount == 0) {
		HeapFree(GetProcessHeap(), 0, lpEntryList);
		return FALSE;
	}
	
	g_lpLspContext = (LPLSPCONTEXT)HeapAlloc(g_hheap, HEAP_ZERO_MEMORY, sizeof(LSPCONTEXT) * g_nContextCount);

	for (i = 0, j = 0; i < nTotalEntryCount; i++) {
		if (lpEntryList[i].ProtocolChain.ChainLen > 1 && lpEntryList[i].ProtocolChain.ChainEntries[0] == dwDummyEntryId) {
			CopyMemory(&g_lpLspContext[j].entryInfo, &lpEntryList[i], sizeof(WSAPROTOCOL_INFOW));
			j++;
		}
	}
	
	for (i = 0; i < g_nContextCount; i++) {
		for (j = 0; j < nTotalEntryCount; j++) {
			if (lpEntryList[j].dwCatalogEntryId == g_lpLspContext[i].entryInfo.ProtocolChain.ChainEntries[1])
				CopyMemory(&g_lpLspContext[i].nextEntryInfo, &lpEntryList[j], sizeof(WSAPROTOCOL_INFOW));
		}
	}

	HeapFree(GetProcessHeap(), 0, lpEntryList);

	return TRUE;
}

CreateLspContextの目的は、このLSPに関連する全てのエントリを取得し、 それを基にLSPCONTEXT構造体を初期化することです。 たとえば、LSPがTCPとUDPエントリの上にインストールされているならば、 2つのエントリが存在しているはずですからそれらを取得します。 これらのエントリには、ChainEntries[0]にダミーエントリのIDが格納されているという共通点があるため、 まず行うべきことはダミーエントリのIDを取得することです。 WSCEnumProtocolsで取得したエントリの中で、LSPのGUIDと一致するエントリはダミーエントリであるため、 このエントリのIDを保存します。 次に、ダミーエントリのIDと一致するエントリの数をカウントし、この数だけLSPCONTEXT構造体を確保します。 そして、先と同じループを実行し、確保したLSPCONTEXT構造体のentryInfoにエントリの内容をコピーします。 最後のループは、entryInfoの下に存在するエントリを発見し、それをnextEntryInfoにコピーする処理です。 ChainEntries[1]には下に存在するエントリのIDが格納されているため、 これと一致するエントリを探すようにしています。

各関数が適切なLSPCONTEXT構造体を取得できるように、 GetLspContextFromIdとGetLspContextFromSocketという関数が用意されています。 前者の関数はエントリIDから関連するLSPCONTEXT構造体を取得し、 後者の関数はソケットから関連するLSPCONTEXT構造体を取得します。 どちらの関数を呼び出すかは、呼び出し側の関数の情報量によるでしょう。 たとえば、WSPSocketにはWSAPROTOCOL_INFOW構造体が渡されるため、 この構造体のdwCatalogEntryIdからGetLspContextFromIdを呼び出すことができます。

SOCKET WSPAPI WSPSocket(int af, int type, int protocol, LPWSAPROTOCOL_INFOW lpProtocolInfo, GROUP g, DWORD dwFlags, LPINT lpErrno)
{
	SOCKET          soc;
	SOCKET          socModify;
	LPLSPCONTEXT    lpLspContext;
	LPSOCKETCONTEXT lp;

	WriteLogFile(TEXT("socket"));
		
	lpLspContext = GetLspContextFromId(lpProtocolInfo->dwCatalogEntryId);
	if (lpLspContext == NULL)
		return INVALID_SOCKET;

	soc = lpLspContext->procTable.lpWSPSocket(af, type, protocol, lpProtocolInfo, g, dwFlags, lpErrno);
	if (soc != INVALID_SOCKET) {
		socModify = g_upcallTable.lpWPUModifyIFSHandle(lpProtocolInfo->dwCatalogEntryId, soc, lpErrno);
		if (soc != socModify)
			soc = INVALID_SOCKET;
	}

	if (soc != INVALID_SOCKET) {
		if (lpLspContext->lpSockHeader == NULL) {
			lpLspContext->lpSockHeader = (LPSOCKETCONTEXT)HeapAlloc(g_hheap, HEAP_ZERO_MEMORY, sizeof(SOCKETCONTEXT));
			lpLspContext->lpSockHeader->soc = soc;
		}
		else {
			lp = lpLspContext->lpSockHeader;
			for (; lp->lpNext != NULL; )
				lp = lp->lpNext;
			lp->lpNext = (LPSOCKETCONTEXT)HeapAlloc(g_hheap, HEAP_ZERO_MEMORY, sizeof(SOCKETCONTEXT));
			lp->lpNext->soc = soc;
		}
	}

	return soc;
}

WSPSocketは、どのようなプロトコルを使用している場合でも呼び出される関数であり、 現在要求されているプロトコルを特定するのは非常に重要といえます。 GetLspContextFromIdを呼び出せば、適切なプロトコルの情報を格納したLSPCONTEXT構造体が返るため、 これを使用すれば間違って他のプロトコルのデータを書き換えることもなくなります。 ソケットを正常に取得できた場合は、LSPCONTEXT構造体のlpSockHeaderのためにメモリを確保し、 そこに取得したソケットを格納するようにします。 2回目以降の場合は、lpNextが新しいメモリを指すようにし、 lpNextを通じて全てのソケットにアクセスできるような仕組みを作っています。

WSPConnectやWSPAcceptではソケットが引数として渡されるため、 GetLspContextFromSocketを呼び出すことになります。

SOCKET WSPAPI WSPAccept(SOCKET s, struct sockaddr *addr, LPINT addrlen, LPCONDITIONPROC lpfnCondition, DWORD dwCallbackData, LPINT lpErrno)
{
	SOCKET          socServer;
	SOCKET          socModify;
	LPLSPCONTEXT    lpLspContext;
	LPSOCKETCONTEXT lp;

	WriteLogFile(TEXT("accept"));

	lpLspContext = GetLspContextFromSocket(s);
	if (lpLspContext == NULL) {
		*lpErrno = WSAENOTSOCK;
		return INVALID_SOCKET;
	}

	socServer = lpLspContext->procTable.lpWSPAccept(s, addr, addrlen, lpfnCondition, dwCallbackData, lpErrno);
	if (socServer != INVALID_SOCKET) {
		socModify = g_upcallTable.lpWPUModifyIFSHandle(lpLspContext->entryInfo.dwCatalogEntryId, socServer, lpErrno);
		if (socServer != socModify)
			socServer = INVALID_SOCKET;
	}

	if (socServer != INVALID_SOCKET) {
		lp = lpLspContext->lpSockHeader;
		for (; lp->lpNext != NULL;)
			lp = lp->lpNext;
		lp->lpNext = (LPSOCKETCONTEXT)HeapAlloc(g_hheap, HEAP_ZERO_MEMORY, sizeof(SOCKETCONTEXT));
		lp->lpNext->soc = socServer;
	}

	return socServer;
}

lpWSPAcceptで取得したソケットもSOCKETCONTEXT構造体のリストに追加しておくことになります。 そうすることで、このソケットが何らかの関数に渡された場合に、 GetLspContextFromSocketで関連するLSPCONTEXT構造体を取得できるようになります。 WSPAcceptやWSPConnectのようなsockaddr構造体を取得できる関数では、 この構造体をSOCKETCONTEXT構造体のメンバとして保存しておくのも面白いでしょう。

WSPCleanupの実装は、次のようになっています。

int WSPAPI WSPCleanup(LPINT lpErrno)
{
	int i;

	if (--g_nStartupCount == 0) {
		for (i = 0; i < g_nContextCount; i++) {
			if (g_lpLspContext[i].hmod != NULL) {
				g_lpLspContext[i].procTable.lpWSPCleanup(lpErrno);
				FreeLibrary(g_lpLspContext[i].hmod);
				WriteLogFile(TEXT("WSACleanup"));
			}
		}
		HeapDestroy(g_hheap);
	}
	
	return 0;
}

この関数では、下のエントリのWSPCleanupを呼び出し、FreeLibraryで下のエントリをアンロードする必要がありますが、 ここで少し問題が発生します。 それは、この関数にプロトコルを特定するための引数が存在しないという点です。 たとえば、TCPとUDPの上にLSPをインストールしている場合、 WinsockアプリケーションのWSACleanup呼び出しによって、 TCP用のWSPCleanupとUDP用のWSPCleanupが呼ばれることがありますが、 関数内ではどのLSPCONTEXT構造体を参照すればよいかが分かりません。 WSPCleanupが呼ばれる回数は、g_nStartupCountに格納されているため、 これを下げ続けて0になった場合は、今回のWSPCleanupが最後の呼び出しということになります。 つまり、全ての開放処理を行ってもよい段階であるため、 g_lpLspContextの各要素にアクセスしています。


戻る