/* Copyright (c) 2005-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "oneapi/tbb/detail/_config.h" #include "oneapi/tbb/detail/_assert.h" #include "../tbb/assert_impl.h" #if !__TBB_WIN8UI_SUPPORT && defined(_WIN32) #ifndef _CRT_SECURE_NO_DEPRECATE #define _CRT_SECURE_NO_DEPRECATE 1 #endif // no standard-conforming implementation of snprintf prior to VS 2015 #if !defined(_MSC_VER) || _MSC_VER>=1900 #define LOG_PRINT(s, n, format, ...) snprintf(s, n, format, __VA_ARGS__) #else #define LOG_PRINT(s, n, format, ...) _snprintf_s(s, n, _TRUNCATE, format, __VA_ARGS__) #endif #include #include #include #include #include "function_replacement.h" // The information about a standard memory allocation function for the replacement log struct FunctionInfo { const char* funcName; const char* dllName; }; // Namespace that processes and manages the output of records to the Log journal // that will be provided to user by TBB_malloc_replacement_log() namespace Log { // Value of RECORDS_COUNT is set due to the fact that we maximally // scan 8 modules, and in every module we can swap 6 opcodes. (rounded to 8) static const unsigned RECORDS_COUNT = 8 * 8; static const unsigned RECORD_LENGTH = MAX_PATH; // Need to add 1 to count of records, because last record must be always NULL static char *records[RECORDS_COUNT + 1]; static bool replacement_status = true; // Internal counter that contains number of next string for record static unsigned record_number = 0; // Function that writes info about (not)found opcodes to the Log journal // functionInfo - information about a standard memory allocation function for the replacement log // opcodeString - string, that contain byte code of this function // status - information about function replacement status static void record(FunctionInfo functionInfo, const char * opcodeString, bool status) { __TBB_ASSERT(functionInfo.dllName, "Empty DLL name value"); __TBB_ASSERT(functionInfo.funcName, "Empty function name value"); __TBB_ASSERT(opcodeString, "Empty opcode"); __TBB_ASSERT(record_number <= RECORDS_COUNT, "Incorrect record number"); //If some replacement failed -> set status to false replacement_status &= status; // If we reach the end of the log, write this message to the last line if (record_number == RECORDS_COUNT) { // %s - workaround to fix empty variable argument parsing behavior in GCC LOG_PRINT(records[RECORDS_COUNT - 1], RECORD_LENGTH, "%s", "Log was truncated."); return; } char* entry = (char*)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, RECORD_LENGTH); __TBB_ASSERT(entry, "Invalid memory was returned"); LOG_PRINT(entry, RECORD_LENGTH, "%s: %s (%s), byte pattern: <%s>", status ? "Success" : "Fail", functionInfo.funcName, functionInfo.dllName, opcodeString); records[record_number++] = entry; } }; inline UINT_PTR Ptr2Addrint(LPVOID ptr) { Int2Ptr i2p; i2p.lpv = ptr; return i2p.uip; } inline LPVOID Addrint2Ptr(UINT_PTR ptr) { Int2Ptr i2p; i2p.uip = ptr; return i2p.lpv; } // Is the distance between addr1 and addr2 smaller than dist inline bool IsInDistance(UINT_PTR addr1, UINT_PTR addr2, __int64 dist) { __int64 diff = addr1>addr2 ? addr1-addr2 : addr2-addr1; return diff= m_allocSize) { // Found a free region, try to allocate a page in this region void *newPage = VirtualAlloc(newAddr, m_allocSize, MEM_COMMIT|MEM_RESERVE, PAGE_READWRITE); if (!newPage) break; // Add the new page to the pages database MemoryBuffer *pBuff = new (m_lastBuffer) MemoryBuffer(newPage, m_allocSize); ++m_lastBuffer; return pBuff; } } // Failed to find a buffer in the distance return 0; } public: MemoryProvider() { SYSTEM_INFO sysInfo; GetSystemInfo(&sysInfo); m_allocSize = sysInfo.dwAllocationGranularity; m_lastBuffer = &m_pages[0]; } // We can't free the pages in the destructor because the trampolines // are using these memory locations and a replaced function might be called // after the destructor was called. ~MemoryProvider() { } // Return a memory location in distance less than 2^31 from input address UINT_PTR GetLocation(UINT_PTR addr) { MemoryBuffer *pBuff = m_pages; for (; pBuffm_next, addr, MAX_DISTANCE); ++pBuff) { if (pBuff->m_next < pBuff->m_base + pBuff->m_size) { UINT_PTR loc = pBuff->m_next; pBuff->m_next += MAX_PROBE_SIZE; return loc; } } pBuff = CreateBuffer(addr); if(!pBuff) return 0; UINT_PTR loc = pBuff->m_next; pBuff->m_next += MAX_PROBE_SIZE; return loc; } private: MemoryBuffer m_pages[MAX_NUM_BUFFERS]; MemoryBuffer *m_lastBuffer; DWORD m_allocSize; }; static MemoryProvider memProvider; // Compare opcodes from dictionary (str1) and opcodes from code (str2) // str1 might contain '*' to mask addresses // RETURN: 0 if opcodes did not match, 1 on success size_t compareStrings( const char *str1, const char *str2 ) { for (size_t i=0; str1[i]!=0; i++){ if( str1[i]!='*' && str1[i]!='#' && str1[i]!=str2[i] ) return 0; } return 1; } // Check function prologue with known prologues from the dictionary // opcodes - dictionary // inpAddr - pointer to function prologue // Dictionary contains opcodes for several full asm instructions // + one opcode byte for the next asm instruction for safe address processing // RETURN: 1 + the index of the matched pattern, or 0 if no match found. static UINT CheckOpcodes( const char ** opcodes, void *inpAddr, bool abortOnError, const FunctionInfo* functionInfo = NULL) { static size_t opcodesStringsCount = 0; static size_t maxOpcodesLength = 0; static size_t opcodes_pointer = (size_t)opcodes; char opcodeString[2*MAX_PATTERN_SIZE+1]; size_t i; size_t result = 0; // Get the values for static variables // max length and number of patterns if( !opcodesStringsCount || opcodes_pointer != (size_t)opcodes ){ while( *(opcodes + opcodesStringsCount)!= NULL ){ if( (i=strlen(*(opcodes + opcodesStringsCount))) > maxOpcodesLength ) maxOpcodesLength = i; opcodesStringsCount++; } opcodes_pointer = (size_t)opcodes; __TBB_ASSERT( maxOpcodesLength/2 <= MAX_PATTERN_SIZE, "Pattern exceeded the limit of 28 opcodes/56 symbols" ); } // Translate prologue opcodes to string format to compare for( i=0; i= SIZE_OF_RELJUMP, "Incorrect bytecode pattern?" ); UINT_PTR trampAddr = memProvider.GetLocation(srcAddr); if (!trampAddr) return 0; *storedAddr = Addrint2Ptr(trampAddr); // Set 'executable' flag for original instructions in the new place DWORD pageFlags = PAGE_EXECUTE_READWRITE; if (!VirtualProtect(*storedAddr, MAX_PROBE_SIZE, pageFlags, &pageFlags)) return 0; // Copy original instructions to the new place memcpy(*storedAddr, codePtr, bytesToMove); offset = srcAddr - trampAddr; offset32 = (UINT)(offset & 0xFFFFFFFF); CorrectOffset( trampAddr, pattern, offset32 ); // Set jump to the code after replacement offset32 -= SIZE_OF_RELJUMP; *(UCHAR*)(trampAddr+bytesToMove) = 0xE9; memcpy((UCHAR*)(trampAddr+bytesToMove+1), &offset32, sizeof(offset32)); } // The following will work correctly even if srcAddr>tgtAddr, as long as // address difference is less than 2^31, which is guaranteed by IsInDistance. offset = tgtAddr - srcAddr - SIZE_OF_RELJUMP; offset32 = (UINT)(offset & 0xFFFFFFFF); // Insert the jump to the new code *codePtr = 0xE9; memcpy(codePtr+1, &offset32, sizeof(offset32)); // Fill the rest with NOPs to correctly see disassembler of old code in debugger. for( unsigned i=SIZE_OF_RELJUMP; i= SIZE_OF_INDJUMP, "Incorrect bytecode pattern?" ); UINT_PTR trampAddr = memProvider.GetLocation(srcAddr); if (!trampAddr) return 0; *storedAddr = Addrint2Ptr(trampAddr); // Set 'executable' flag for original instructions in the new place DWORD pageFlags = PAGE_EXECUTE_READWRITE; if (!VirtualProtect(*storedAddr, MAX_PROBE_SIZE, pageFlags, &pageFlags)) return 0; // Copy original instructions to the new place memcpy(*storedAddr, codePtr, bytesToMove); offset = srcAddr - trampAddr; offset32 = (UINT)(offset & 0xFFFFFFFF); CorrectOffset( trampAddr, pattern, offset32 ); // Set jump to the code after replacement. It is within the distance of relative jump! offset32 -= SIZE_OF_RELJUMP; *(UCHAR*)(trampAddr+bytesToMove) = 0xE9; memcpy((UCHAR*)(trampAddr+bytesToMove+1), &offset32, sizeof(offset32)); } // Fill the buffer offset = location - srcAddr - SIZE_OF_INDJUMP; offset32 = (UINT)(offset & 0xFFFFFFFF); *(codePtr) = 0xFF; *(codePtr+1) = 0x25; memcpy(codePtr+2, &offset32, sizeof(offset32)); // Fill the rest with NOPs to correctly see disassembler of old code in debugger. for( unsigned i=SIZE_OF_INDJUMP; i 0, "abortOnError ignored in CheckOpcodes?" ); pattern = opcodes[opcodeIdx-1]; // -1 compensates for +1 in CheckOpcodes } } probeSize = InsertTrampoline32(inpAddr, targetAddr, pattern, origFunc); if (!probeSize) probeSize = InsertTrampoline64(inpAddr, targetAddr, pattern, origFunc); // Restore original protection VirtualProtect(inpAddr, MAX_PROBE_SIZE, origProt, &origProt); if (!probeSize) return FALSE; FlushInstructionCache(GetCurrentProcess(), inpAddr, probeSize); FlushInstructionCache(GetCurrentProcess(), origFunc, probeSize); return TRUE; } // Routine to replace the functions // TODO: replace opcodesNumber with opcodes and opcodes number to check if we replace right code. FRR_TYPE ReplaceFunctionA(const char *dllName, const char *funcName, FUNCPTR newFunc, const char ** opcodes, FUNCPTR* origFunc) { // Cache the results of the last search for the module // Assume that there was no DLL unload between static char cachedName[MAX_PATH+1]; static HMODULE cachedHM = 0; if (!dllName || !*dllName) return FRR_NODLL; if (!cachedHM || strncmp(dllName, cachedName, MAX_PATH) != 0) { // Find the module handle for the input dll HMODULE hModule = GetModuleHandleA(dllName); if (hModule == 0) { // Couldn't find the module with the input name cachedHM = 0; return FRR_NODLL; } cachedHM = hModule; strncpy(cachedName, dllName, MAX_PATH); } FARPROC inpFunc = GetProcAddress(cachedHM, funcName); if (inpFunc == 0) { // Function was not found return FRR_NOFUNC; } if (!InsertTrampoline((void*)inpFunc, (void*)newFunc, opcodes, (void**)origFunc)){ // Failed to insert the trampoline to the target address return FRR_FAILED; } return FRR_OK; } FRR_TYPE ReplaceFunctionW(const wchar_t *dllName, const char *funcName, FUNCPTR newFunc, const char ** opcodes, FUNCPTR* origFunc) { // Cache the results of the last search for the module // Assume that there was no DLL unload between static wchar_t cachedName[MAX_PATH+1]; static HMODULE cachedHM = 0; if (!dllName || !*dllName) return FRR_NODLL; if (!cachedHM || wcsncmp(dllName, cachedName, MAX_PATH) != 0) { // Find the module handle for the input dll HMODULE hModule = GetModuleHandleW(dllName); if (hModule == 0) { // Couldn't find the module with the input name cachedHM = 0; return FRR_NODLL; } cachedHM = hModule; wcsncpy(cachedName, dllName, MAX_PATH); } FARPROC inpFunc = GetProcAddress(cachedHM, funcName); if (inpFunc == 0) { // Function was not found return FRR_NOFUNC; } if (!InsertTrampoline((void*)inpFunc, (void*)newFunc, opcodes, (void**)origFunc)){ // Failed to insert the trampoline to the target address return FRR_FAILED; } return FRR_OK; } bool IsPrologueKnown(const char* dllName, const char *funcName, const char **opcodes, HMODULE module) { FARPROC inpFunc = GetProcAddress(module, funcName); FunctionInfo functionInfo = { funcName, dllName }; if (!inpFunc) { Log::record(functionInfo, "unknown", /*status*/ false); return false; } return CheckOpcodes( opcodes, (void*)inpFunc, /*abortOnError=*/false, &functionInfo) != 0; } // Public Windows API extern "C" __declspec(dllexport) int TBB_malloc_replacement_log(char *** function_replacement_log_ptr) { if (function_replacement_log_ptr != NULL) { *function_replacement_log_ptr = Log::records; } // If we have no logs -> return false status return Log::replacement_status && Log::records[0] != NULL ? 0 : -1; } #endif /* !__TBB_WIN8UI_SUPPORT && defined(_WIN32) */