diff --git a/tools/fixup-linkage/CMakeLists.txt b/tools/fixup-linkage/CMakeLists.txt index 1213bdec451..3d21fc12d14 100644 --- a/tools/fixup-linkage/CMakeLists.txt +++ b/tools/fixup-linkage/CMakeLists.txt @@ -6,6 +6,14 @@ # the terms of the Apache License 2.0 which accompanies this distribution. # # ============================================================================ # +set(LLVM_LINK_COMPONENTS + Core + IRReader + Support +) + add_llvm_executable(fixup-linkage fixup-linkage.cpp) +llvm_update_compile_flags(fixup-linkage) + install(TARGETS fixup-linkage DESTINATION bin) diff --git a/tools/fixup-linkage/fixup-linkage.cpp b/tools/fixup-linkage/fixup-linkage.cpp index 44a5069e9ea..a7ff07f9912 100644 --- a/tools/fixup-linkage/fixup-linkage.cpp +++ b/tools/fixup-linkage/fixup-linkage.cpp @@ -6,14 +6,30 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ -/// The fixup-linkage tool is used to rewrite the LLVM IR produced by clang for -/// the classical compute code such that it can be linked correctly with the -/// LLVM IR that is generated for the quantum code. This avoids linker errors -/// such as "duplicate symbol definition". +/// The fixup-linkage tool processes the LLVM IR produced by clang for the +/// classical compute code. For each __qpu__ kernel function, it replaces the +/// function body with a stub containing 'unreachable'. This: +/// 1. Avoids compiling kernel bodies that reference quantum-only types +/// (qvector) +/// 2. Keeps a valid function address for __cudaq_registerLinkableKernel +/// 3. Uses linkonce_odr linkage so the MLIR-generated version overrides the +/// stub. The actual kernel implementations are provided by the quantum code +/// path. 
+ +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" #include #include #include +#include int main(int argc, char *argv[]) { if (argc != 4) { @@ -25,7 +41,7 @@ int main(int argc, char *argv[]) { // mangled_name_map in the quake file. Add these names to `funcs`. std::ifstream modFile(argv[1]); std::string line; - std::vector funcs; + std::set funcs; { std::regex mapRegex{"quake\\.mangled_name_map = [{]"}; std::regex stringRegex{"\"(.*?)\""}; @@ -40,7 +56,7 @@ int main(int argc, char *argv[]) { std::sregex_iterator(names.begin(), names.end(), stringRegex); for (std::sregex_iterator i = namesBegin; i != rgxEnd; ++i) { auto s = i->str(); - funcs.push_back(s.substr(1, s.size() - 2)); + funcs.insert(s.substr(1, s.size() - 2)); } } modFile.close(); @@ -50,64 +66,57 @@ int main(int argc, char *argv[]) { } } - // 2. Scan the LLVM file looking for the mangled kernel names. Where these - // kernels are defined, they have their linkage modified to `linkonce_odr` if - // that is not already the linkage. This change will prevent the duplicate - // symbols defined error from the linker. - std::ifstream llFile(argv[2]); - std::ofstream outFile(argv[3]); - std::regex filterRegex("^define "); - auto rgxEnd = std::sregex_iterator(); - auto computeCutPosition = - [&](const std::string &matchStr) -> std::pair { - std::regex rex("^" + matchStr + " "); - auto iter = std::sregex_iterator(line.begin(), line.end(), rex); - if (iter == rgxEnd) - return {false, 0}; - return {true, matchStr.size()}; - }; - while (std::getline(llFile, line)) { - auto iter = std::sregex_iterator(line.begin(), line.end(), filterRegex); - if (iter == rgxEnd) { - outFile << line << std::endl; + // 2. Parse the LLVM IR file using LLVM APIs. 
+ llvm::LLVMContext context; + llvm::SMDiagnostic err; + std::unique_ptr module = + llvm::parseIRFile(argv[2], err, context); + + if (!module) { + err.print(argv[0], llvm::errs()); + return 1; + } + + // 3. For each kernel function, replace its body with a stub containing + // 'unreachable'. This avoids compiling the original body (which may + // reference quantum-only types like qvector) while keeping the function + // as a definition with a valid address. The address is needed because + // __cudaq_registerLinkableKernel takes a pointer to the C++ function. + // The actual implementation is provided by the MLIR/quantum code path. + for (llvm::Function &func : *module) { + if (func.isDeclaration()) continue; - } - if (line.find(" linkonce_odr ") != std::string::npos || - line.find(" weak dso_local ") != std::string::npos) { - outFile << line << std::endl; + + // Check if this function is one of our kernels. + std::string funcName = func.getName().str(); + if (!funcs.contains(funcName)) continue; - } - // At this point, `line` starts with define but does not contain - // linkonce_odr. So it is a candidate for being rewritten. - bool replaced = false; - for (auto fn : funcs) { - // Check if this is defining one of our kernels. - auto pos = line.find(fn); - if (pos == std::string::npos) - continue; - auto pair = computeCutPosition("define internal"); - if (!pair.first) - pair = computeCutPosition("define dso_local"); - // On macOS, clang emits some functions with weak linkage. - if (!pair.first) - pair = computeCutPosition("define weak"); - if (!pair.first) - pair = computeCutPosition("define"); - if (!pair.first) { - // This is a hard error because the line must have a define. 
- std::cerr << "internal error: line no longer matches.\n"; - return 1; - } - pos = pair.second; - outFile << "define linkonce_odr dso_preemptable" << line.substr(pos) - << std::endl; - replaced = true; - break; - } - if (!replaced) - outFile << line << std::endl; + + // Delete all existing basic blocks. + func.deleteBody(); + + // Create a new entry block with just 'unreachable'. + // This provides a valid function address while ensuring the classical + // body is never executed (the runtime redirects to the MLIR version). + llvm::BasicBlock *entryBB = + llvm::BasicBlock::Create(context, "entry", &func); + new llvm::UnreachableInst(context, entryBB); + + // Change to linkonce_odr with dso_preemptable. + func.setLinkage(llvm::GlobalValue::LinkOnceODRLinkage); + func.setDSOLocal(false); } - llFile.close(); + + // 4. Write the modified module to the output file. + std::error_code ec; + llvm::raw_fd_ostream outFile(argv[3], ec, llvm::sys::fs::OF_Text); + if (ec) { + std::cerr << "Error opening output file: " << ec.message() << "\n"; + return 1; + } + + module->print(outFile, nullptr); outFile.close(); + return 0; }