diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp | 26 |
1 files changed, 24 insertions, 2 deletions
diff --git a/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp index 284d14b995..6c908412e0 100644 --- a/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp +++ b/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp @@ -107,6 +107,10 @@ enum OpenMPRTLFunctionNVPTX { /// Call to void __kmpc_barrier_simple_spmd(ident_t *loc, kmp_int32 /// global_tid); OMPRTL__kmpc_barrier_simple_spmd, + /// Call to int32_t __kmpc_warp_active_thread_mask(void); + OMPRTL_NVPTX__kmpc_warp_active_thread_mask, + /// Call to void __kmpc_syncwarp(int32_t Mask); + OMPRTL_NVPTX__kmpc_syncwarp, }; /// Pre(post)-action for different OpenMP constructs specialized for NVPTX. @@ -1794,6 +1798,20 @@ CGOpenMPRuntimeNVPTX::createNVPTXRuntimeFunction(unsigned Function) { ->addFnAttr(llvm::Attribute::Convergent); break; } + case OMPRTL_NVPTX__kmpc_warp_active_thread_mask: { + // Build int32_t __kmpc_warp_active_thread_mask(void); + auto *FnTy = + llvm::FunctionType::get(CGM.Int32Ty, llvm::None, /*isVarArg=*/false); + RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_warp_active_thread_mask"); + break; + } + case OMPRTL_NVPTX__kmpc_syncwarp: { + // Build void __kmpc_syncwarp(kmp_int32 Mask); + auto *FnTy = + llvm::FunctionType::get(CGM.VoidTy, CGM.Int32Ty, /*isVarArg=*/false); + RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_syncwarp"); + break; + } } return RTLFn; } @@ -2700,6 +2718,9 @@ void CGOpenMPRuntimeNVPTX::emitCriticalRegion( llvm::BasicBlock *BodyBB = CGF.createBasicBlock("omp.critical.body"); llvm::BasicBlock *ExitBB = CGF.createBasicBlock("omp.critical.exit"); + // Get the mask of active threads in the warp. + llvm::Value *Mask = CGF.EmitRuntimeCall( + createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_warp_active_thread_mask)); // Fetch team-local id of the thread. llvm::Value *ThreadID = getNVPTXThreadID(CGF); @@ -2740,8 +2761,9 @@ void CGOpenMPRuntimeNVPTX::emitCriticalRegion( // Block waits for all threads in current team to finish then increments the // counter variable and returns to the loop. CGF.EmitBlock(SyncBB); - emitBarrierCall(CGF, Loc, OMPD_unknown, /*EmitChecks=*/false, - /*ForceSimpleCall=*/true); + // Reconverge active threads in the warp. + (void)CGF.EmitRuntimeCall( + createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_syncwarp), Mask); llvm::Value *IncCounterVal = CGF.Builder.CreateNSWAdd(CounterVal, CGF.Builder.getInt32(1)); |