Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tools/fixup-linkage/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

# LLVM components needed by the fixup-linkage tool. fixup-linkage.cpp includes
# headers from llvm/IR, llvm/IRReader and llvm/Support.
# NOTE(review): presumably read by add_llvm_executable below to choose the LLVM
# libraries to link -- confirm against LLVM's CMake helpers.
set(LLVM_LINK_COMPONENTS
Core
IRReader
Support
)

# Build the fixup-linkage executable from its single source file.
add_llvm_executable(fixup-linkage fixup-linkage.cpp)

# NOTE(review): presumably applies LLVM's standard compile flags to the
# target -- confirm against LLVM's CMake helpers.
llvm_update_compile_flags(fixup-linkage)

# Install the built tool under `bin`.
install(TARGETS fixup-linkage DESTINATION bin)
131 changes: 70 additions & 61 deletions tools/fixup-linkage/fixup-linkage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,30 @@
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

/// The fixup-linkage tool is used to rewrite the LLVM IR produced by clang for
/// the classical compute code such that it can be linked correctly with the
/// LLVM IR that is generated for the quantum code. This avoids linker errors
/// such as "duplicate symbol definition".
/// The fixup-linkage tool processes the LLVM IR produced by clang for the
/// classical compute code. For each __qpu__ kernel function, it replaces the
/// function body with a stub containing 'unreachable'. This:
/// 1. Avoids compiling kernel bodies that reference quantum-only types
/// (qvector)
/// 2. Keeps a valid function address for __cudaq_registerLinkableKernel
/// 3. Uses linkonce_odr linkage so the MLIR-generated version overrides the
/// stub.
/// The actual kernel implementations are provided by the quantum code path.

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

#include <fstream>
#include <iostream>
#include <regex>
#include <set>

int main(int argc, char *argv[]) {
if (argc != 4) {
Expand All @@ -25,7 +41,7 @@ int main(int argc, char *argv[]) {
// mangled_name_map in the quake file. Add these names to `funcs`.
std::ifstream modFile(argv[1]);
std::string line;
std::vector<std::string> funcs;
std::set<std::string> funcs;
{
std::regex mapRegex{"quake\\.mangled_name_map = [{]"};
std::regex stringRegex{"\"(.*?)\""};
Expand All @@ -40,7 +56,7 @@ int main(int argc, char *argv[]) {
std::sregex_iterator(names.begin(), names.end(), stringRegex);
for (std::sregex_iterator i = namesBegin; i != rgxEnd; ++i) {
auto s = i->str();
funcs.push_back(s.substr(1, s.size() - 2));
funcs.insert(s.substr(1, s.size() - 2));
}
}
modFile.close();
Expand All @@ -50,64 +66,57 @@ int main(int argc, char *argv[]) {
}
}

// 2. Scan the LLVM file looking for the mangled kernel names. Wherever one of
// these kernels is defined, change its linkage to `linkonce_odr` unless it
// already has that linkage. This prevents the linker's "duplicate symbol
// definition" error.
std::ifstream llFile(argv[2]);
std::ofstream outFile(argv[3]);
std::regex filterRegex("^define ");
auto rgxEnd = std::sregex_iterator();
auto computeCutPosition =
[&](const std::string &matchStr) -> std::pair<bool, std::size_t> {
std::regex rex("^" + matchStr + " ");
auto iter = std::sregex_iterator(line.begin(), line.end(), rex);
if (iter == rgxEnd)
return {false, 0};
return {true, matchStr.size()};
};
while (std::getline(llFile, line)) {
auto iter = std::sregex_iterator(line.begin(), line.end(), filterRegex);
if (iter == rgxEnd) {
outFile << line << std::endl;
// 2. Parse the LLVM IR file using LLVM APIs.
llvm::LLVMContext context;
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> module =
llvm::parseIRFile(argv[2], err, context);

if (!module) {
err.print(argv[0], llvm::errs());
return 1;
}

// 3. For each kernel function, replace its body with a stub containing
// 'unreachable'. This avoids compiling the original body (which may
// reference quantum-only types like qvector) while keeping the function
// as a definition with a valid address. The address is needed because
// __cudaq_registerLinkableKernel takes a pointer to the C++ function.
// The actual implementation is provided by the MLIR/quantum code path.
for (llvm::Function &func : *module) {
if (func.isDeclaration())
continue;
}
if (line.find(" linkonce_odr ") != std::string::npos ||
line.find(" weak dso_local ") != std::string::npos) {
outFile << line << std::endl;

// Check if this function is one of our kernels.
std::string funcName = func.getName().str();
if (!funcs.contains(funcName))
continue;
}
// At this point, `line` starts with define but does not contain
// linkonce_odr. So it is a candidate for being rewritten.
bool replaced = false;
for (auto fn : funcs) {
// Check if this is defining one of our kernels.
auto pos = line.find(fn);
if (pos == std::string::npos)
continue;
auto pair = computeCutPosition("define internal");
if (!pair.first)
pair = computeCutPosition("define dso_local");
// On macOS, clang emits some functions with weak linkage.
if (!pair.first)
pair = computeCutPosition("define weak");
if (!pair.first)
pair = computeCutPosition("define");
if (!pair.first) {
// This is a hard error because the line must have a define.
std::cerr << "internal error: line no longer matches.\n";
return 1;
}
pos = pair.second;
outFile << "define linkonce_odr dso_preemptable" << line.substr(pos)
<< std::endl;
replaced = true;
break;
}
if (!replaced)
outFile << line << std::endl;

// Delete all existing basic blocks.
func.deleteBody();

// Create a new entry block with just 'unreachable'.
// This provides a valid function address while ensuring the classical
// body is never executed (the runtime redirects to the MLIR version).
llvm::BasicBlock *entryBB =
llvm::BasicBlock::Create(context, "entry", &func);
new llvm::UnreachableInst(context, entryBB);

// Change to linkonce_odr with dso_preemptable.
func.setLinkage(llvm::GlobalValue::LinkOnceODRLinkage);
func.setDSOLocal(false);
}
llFile.close();

// 4. Write the modified module to the output file.
std::error_code ec;
llvm::raw_fd_ostream outFile(argv[3], ec, llvm::sys::fs::OF_Text);
if (ec) {
std::cerr << "Error opening output file: " << ec.message() << "\n";
return 1;
}

module->print(outFile, nullptr);
outFile.close();

return 0;
}
Loading