summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/clang/AST/OpenMPClause.h24
-rw-r--r--include/clang/AST/RecursiveASTVisitor.h1
-rw-r--r--lib/AST/OpenMPClause.cpp9
-rw-r--r--lib/AST/StmtProfile.cpp1
-rw-r--r--lib/Sema/SemaOpenMP.cpp105
-rw-r--r--lib/Serialization/ASTReader.cpp1
-rw-r--r--lib/Serialization/ASTWriter.cpp1
-rw-r--r--test/OpenMP/parallel_master_taskloop_codegen.cpp9
8 files changed, 129 insertions, 22 deletions
diff --git a/include/clang/AST/OpenMPClause.h b/include/clang/AST/OpenMPClause.h
index 6c504c7701..db780f7ed3 100644
--- a/include/clang/AST/OpenMPClause.h
+++ b/include/clang/AST/OpenMPClause.h
@@ -5268,7 +5268,7 @@ public:
/// \endcode
/// In this example directive '#pragma omp taskloop' has clause 'grainsize'
/// with single expression '4'.
-class OMPGrainsizeClause : public OMPClause {
+class OMPGrainsizeClause : public OMPClause, public OMPClauseWithPreInit {
friend class OMPClauseReader;
/// Location of '('.
@@ -5284,16 +5284,23 @@ public:
/// Build 'grainsize' clause.
///
/// \param Size Expression associated with this clause.
+ /// \param HelperSize Helper grainsize for the construct.
+ /// \param CaptureRegion Innermost OpenMP region where expressions in this
+ /// clause must be captured.
/// \param StartLoc Starting location of the clause.
/// \param EndLoc Ending location of the clause.
- OMPGrainsizeClause(Expr *Size, SourceLocation StartLoc,
+ OMPGrainsizeClause(Expr *Size, Stmt *HelperSize,
+ OpenMPDirectiveKind CaptureRegion, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation EndLoc)
- : OMPClause(OMPC_grainsize, StartLoc, EndLoc), LParenLoc(LParenLoc),
- Grainsize(Size) {}
+ : OMPClause(OMPC_grainsize, StartLoc, EndLoc), OMPClauseWithPreInit(this),
+ LParenLoc(LParenLoc), Grainsize(Size) {
+ setPreInitStmt(HelperSize, CaptureRegion);
+ }
/// Build an empty clause.
explicit OMPGrainsizeClause()
- : OMPClause(OMPC_grainsize, SourceLocation(), SourceLocation()) {}
+ : OMPClause(OMPC_grainsize, SourceLocation(), SourceLocation()),
+ OMPClauseWithPreInit(this) {}
/// Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
@@ -5310,11 +5317,10 @@ public:
return const_child_range(&Grainsize, &Grainsize + 1);
}
- child_range used_children() {
- return child_range(child_iterator(), child_iterator());
- }
+ child_range used_children();
const_child_range used_children() const {
- return const_child_range(const_child_iterator(), const_child_iterator());
+ auto Children = const_cast<OMPGrainsizeClause *>(this)->used_children();
+ return const_child_range(Children.begin(), Children.end());
}
static bool classof(const OMPClause *T) {
diff --git a/include/clang/AST/RecursiveASTVisitor.h b/include/clang/AST/RecursiveASTVisitor.h
index cfeaec46c7..ba5232a2f4 100644
--- a/include/clang/AST/RecursiveASTVisitor.h
+++ b/include/clang/AST/RecursiveASTVisitor.h
@@ -3275,6 +3275,7 @@ bool RecursiveASTVisitor<Derived>::VisitOMPPriorityClause(
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPGrainsizeClause(
OMPGrainsizeClause *C) {
+ TRY_TO(VisitOMPClauseWithPreInit(C));
TRY_TO(TraverseStmt(C->getGrainsize()));
return true;
}
diff --git a/lib/AST/OpenMPClause.cpp b/lib/AST/OpenMPClause.cpp
index bfe272b1d9..b97607f8c6 100644
--- a/lib/AST/OpenMPClause.cpp
+++ b/lib/AST/OpenMPClause.cpp
@@ -84,6 +84,8 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) {
return static_cast<const OMPThreadLimitClause *>(C);
case OMPC_device:
return static_cast<const OMPDeviceClause *>(C);
+ case OMPC_grainsize:
+ return static_cast<const OMPGrainsizeClause *>(C);
case OMPC_default:
case OMPC_proc_bind:
case OMPC_final:
@@ -113,7 +115,6 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) {
case OMPC_simd:
case OMPC_map:
case OMPC_priority:
- case OMPC_grainsize:
case OMPC_nogroup:
case OMPC_num_tasks:
case OMPC_hint:
@@ -234,6 +235,12 @@ OMPClause::child_range OMPIfClause::used_children() {
return child_range(&Condition, &Condition + 1);
}
+OMPClause::child_range OMPGrainsizeClause::used_children() {
+ if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt()))
+ return child_range(C, C + 1);
+ return child_range(&Grainsize, &Grainsize + 1);
+}
+
OMPOrderedClause *OMPOrderedClause::Create(const ASTContext &C, Expr *Num,
unsigned NumLoops,
SourceLocation StartLoc,
diff --git a/lib/AST/StmtProfile.cpp b/lib/AST/StmtProfile.cpp
index efc64af22a..6c65f8a1d0 100644
--- a/lib/AST/StmtProfile.cpp
+++ b/lib/AST/StmtProfile.cpp
@@ -740,6 +740,7 @@ void OMPClauseProfiler::VisitOMPPriorityClause(const OMPPriorityClause *C) {
Profiler->VisitStmt(C->getPriority());
}
void OMPClauseProfiler::VisitOMPGrainsizeClause(const OMPGrainsizeClause *C) {
+ VistOMPClauseWithPreInit(C);
if (C->getGrainsize())
Profiler->VisitStmt(C->getGrainsize());
}
diff --git a/lib/Sema/SemaOpenMP.cpp b/lib/Sema/SemaOpenMP.cpp
index f717dc7f3d..94cd80a4b0 100644
--- a/lib/Sema/SemaOpenMP.cpp
+++ b/lib/Sema/SemaOpenMP.cpp
@@ -4590,12 +4590,16 @@ StmtResult Sema::ActOnOpenMPExecutableDirective(
continue;
case OMPC_schedule:
break;
+ case OMPC_grainsize:
+ // Do not analyze if no parent parallel directive.
+ if (isOpenMPParallelDirective(DSAStack->getCurrentDirective()))
+ break;
+ continue;
case OMPC_ordered:
case OMPC_device:
case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
- case OMPC_grainsize:
case OMPC_num_tasks:
case OMPC_hint:
case OMPC_collapse:
@@ -10773,6 +10777,74 @@ static OpenMPDirectiveKind getOpenMPCaptureRegionForClause(
llvm_unreachable("Unknown OpenMP directive");
}
break;
+ case OMPC_grainsize:
+ switch (DKind) {
+ case OMPD_task:
+ case OMPD_taskloop:
+ case OMPD_taskloop_simd:
+ case OMPD_master_taskloop:
+ break;
+ case OMPD_parallel_master_taskloop:
+ CaptureRegion = OMPD_parallel;
+ break;
+ case OMPD_target_update:
+ case OMPD_target_enter_data:
+ case OMPD_target_exit_data:
+ case OMPD_target:
+ case OMPD_target_simd:
+ case OMPD_target_teams:
+ case OMPD_target_parallel:
+ case OMPD_target_teams_distribute:
+ case OMPD_target_teams_distribute_simd:
+ case OMPD_target_parallel_for:
+ case OMPD_target_parallel_for_simd:
+ case OMPD_target_teams_distribute_parallel_for:
+ case OMPD_target_teams_distribute_parallel_for_simd:
+ case OMPD_target_data:
+ case OMPD_teams_distribute_parallel_for:
+ case OMPD_teams_distribute_parallel_for_simd:
+ case OMPD_teams:
+ case OMPD_teams_distribute:
+ case OMPD_teams_distribute_simd:
+ case OMPD_distribute_parallel_for:
+ case OMPD_distribute_parallel_for_simd:
+ case OMPD_cancel:
+ case OMPD_parallel:
+ case OMPD_parallel_sections:
+ case OMPD_parallel_for:
+ case OMPD_parallel_for_simd:
+ case OMPD_threadprivate:
+ case OMPD_allocate:
+ case OMPD_taskyield:
+ case OMPD_barrier:
+ case OMPD_taskwait:
+ case OMPD_cancellation_point:
+ case OMPD_flush:
+ case OMPD_declare_reduction:
+ case OMPD_declare_mapper:
+ case OMPD_declare_simd:
+ case OMPD_declare_variant:
+ case OMPD_declare_target:
+ case OMPD_end_declare_target:
+ case OMPD_simd:
+ case OMPD_for:
+ case OMPD_for_simd:
+ case OMPD_sections:
+ case OMPD_section:
+ case OMPD_single:
+ case OMPD_master:
+ case OMPD_critical:
+ case OMPD_taskgroup:
+ case OMPD_distribute:
+ case OMPD_ordered:
+ case OMPD_atomic:
+ case OMPD_distribute_simd:
+ case OMPD_requires:
+ llvm_unreachable("Unexpected OpenMP directive with grainsize-clause");
+ case OMPD_unknown:
+ llvm_unreachable("Unknown OpenMP directive");
+ }
+ break;
case OMPC_firstprivate:
case OMPC_lastprivate:
case OMPC_reduction:
@@ -10808,7 +10880,6 @@ static OpenMPDirectiveKind getOpenMPCaptureRegionForClause(
case OMPC_simd:
case OMPC_map:
case OMPC_priority:
- case OMPC_grainsize:
case OMPC_nogroup:
case OMPC_num_tasks:
case OMPC_hint:
@@ -10926,9 +10997,12 @@ ExprResult Sema::PerformOpenMPImplicitIntegerConversion(SourceLocation Loc,
return PerformContextualImplicitConversion(Loc, Op, ConvertDiagnoser);
}
-static bool isNonNegativeIntegerValue(Expr *&ValExpr, Sema &SemaRef,
- OpenMPClauseKind CKind,
- bool StrictlyPositive) {
+static bool
+isNonNegativeIntegerValue(Expr *&ValExpr, Sema &SemaRef, OpenMPClauseKind CKind,
+ bool StrictlyPositive, bool BuildCapture = false,
+ OpenMPDirectiveKind DKind = OMPD_unknown,
+ OpenMPDirectiveKind *CaptureRegion = nullptr,
+ Stmt **HelperValStmt = nullptr) {
if (!ValExpr->isTypeDependent() && !ValExpr->isValueDependent() &&
!ValExpr->isInstantiationDependent()) {
SourceLocation Loc = ValExpr->getExprLoc();
@@ -10949,6 +11023,16 @@ static bool isNonNegativeIntegerValue(Expr *&ValExpr, Sema &SemaRef,
<< ValExpr->getSourceRange();
return false;
}
+ if (!BuildCapture)
+ return true;
+ *CaptureRegion = getOpenMPCaptureRegionForClause(DKind, CKind);
+ if (*CaptureRegion != OMPD_unknown &&
+ !SemaRef.CurContext->isDependentContext()) {
+ ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
+ llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
+ ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
+ *HelperValStmt = buildPreInits(SemaRef.Context, Captures);
+ }
}
return true;
}
@@ -15847,15 +15931,20 @@ OMPClause *Sema::ActOnOpenMPGrainsizeClause(Expr *Grainsize,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
Expr *ValExpr = Grainsize;
+ Stmt *HelperValStmt = nullptr;
+ OpenMPDirectiveKind CaptureRegion = OMPD_unknown;
// OpenMP [2.9.2, taskloop Constrcut]
// The parameter of the grainsize clause must be a positive integer
// expression.
- if (!isNonNegativeIntegerValue(ValExpr, *this, OMPC_grainsize,
- /*StrictlyPositive=*/true))
+ if (!isNonNegativeIntegerValue(
+ ValExpr, *this, OMPC_grainsize,
+ /*StrictlyPositive=*/true, /*BuildCapture=*/true,
+ DSAStack->getCurrentDirective(), &CaptureRegion, &HelperValStmt))
return nullptr;
- return new (Context) OMPGrainsizeClause(ValExpr, StartLoc, LParenLoc, EndLoc);
+ return new (Context) OMPGrainsizeClause(ValExpr, HelperValStmt, CaptureRegion,
+ StartLoc, LParenLoc, EndLoc);
}
OMPClause *Sema::ActOnOpenMPNumTasksClause(Expr *NumTasks,
diff --git a/lib/Serialization/ASTReader.cpp b/lib/Serialization/ASTReader.cpp
index 55f2be3e10..0a7958f2c2 100644
--- a/lib/Serialization/ASTReader.cpp
+++ b/lib/Serialization/ASTReader.cpp
@@ -12934,6 +12934,7 @@ void OMPClauseReader::VisitOMPPriorityClause(OMPPriorityClause *C) {
}
void OMPClauseReader::VisitOMPGrainsizeClause(OMPGrainsizeClause *C) {
+ VisitOMPClauseWithPreInit(C);
C->setGrainsize(Record.readSubExpr());
C->setLParenLoc(Record.readSourceLocation());
}
diff --git a/lib/Serialization/ASTWriter.cpp b/lib/Serialization/ASTWriter.cpp
index df89e44680..57c9242504 100644
--- a/lib/Serialization/ASTWriter.cpp
+++ b/lib/Serialization/ASTWriter.cpp
@@ -6938,6 +6938,7 @@ void OMPClauseWriter::VisitOMPPriorityClause(OMPPriorityClause *C) {
}
void OMPClauseWriter::VisitOMPGrainsizeClause(OMPGrainsizeClause *C) {
+ VisitOMPClauseWithPreInit(C);
Record.AddStmt(C->getGrainsize());
Record.AddSourceLocation(C->getLParenLoc());
}
diff --git a/test/OpenMP/parallel_master_taskloop_codegen.cpp b/test/OpenMP/parallel_master_taskloop_codegen.cpp
index ab15c4884b..2a2f4eb598 100644
--- a/test/OpenMP/parallel_master_taskloop_codegen.cpp
+++ b/test/OpenMP/parallel_master_taskloop_codegen.cpp
@@ -14,7 +14,7 @@
int main(int argc, char **argv) {
// CHECK: [[GTID:%.+]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* [[DEFLOC:@.+]])
// CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEFLOC]], i32 0, void (i32*, i32*, ...)* bitcast (void (i32*, i32*)* [[OMP_OUTLINED1:@.+]] to void (i32*, i32*, ...)*))
-// CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEFLOC]], i32 0, void (i32*, i32*, ...)* bitcast (void (i32*, i32*)* [[OMP_OUTLINED2:@.+]] to void (i32*, i32*, ...)*))
+// CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEFLOC]], i32 1, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i64)* [[OMP_OUTLINED2:@.+]] to void (i32*, i32*, ...)*), i64 [[GRAINSIZE:%.+]])
// CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEFLOC]], i32 3, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i32*, i8***, i64)* [[OMP_OUTLINED3:@.+]] to void (i32*, i32*, ...)*), i32* [[ARGC:%.+]], i8*** [[ARGV:%.+]], i64 [[COND:%.+]])
// CHECK: call void @__kmpc_serialized_parallel(%struct.ident_t* [[DEFLOC]], i32 [[GTID]])
// CHECK: call void [[OMP_OUTLINED3]](i32* %{{.+}}, i32* %{{.+}}, i32* [[ARGC]], i8*** [[ARGV]], i64 [[COND]])
@@ -77,7 +77,7 @@ int main(int argc, char **argv) {
#pragma omp parallel master taskloop priority(4)
for (int i = 0; i < 10; ++i)
;
-// CHECK: define internal void [[OMP_OUTLINED2]](i32* noalias %{{.+}}, i32* noalias %{{.+}})
+// CHECK: define internal void [[OMP_OUTLINED2]](i32* noalias %{{.+}}, i32* noalias %{{.+}}, i64 %{{.+}})
// CHECK: [[RES:%.+]] = call {{.*}}i32 @__kmpc_master(%struct.ident_t* [[DEFLOC]], i32 [[GTID:%.+]])
// CHECK-NEXT: [[IS_MASTER:%.+]] = icmp ne i32 [[RES]], 0
// CHECK-NEXT: br i1 [[IS_MASTER]], label {{%?}}[[THEN:.+]], label {{%?}}[[EXIT:.+]]
@@ -92,7 +92,8 @@ int main(int argc, char **argv) {
// CHECK: [[ST:%.+]] = getelementptr inbounds [[TD_TY]], [[TD_TY]]* [[TASK_DATA]], i32 0, i32 7
// CHECK: store i64 1, i64* [[ST]],
// CHECK: [[ST_VAL:%.+]] = load i64, i64* [[ST]],
-// CHECK: call void @__kmpc_taskloop(%struct.ident_t* [[DEFLOC]], i32 [[GTID]], i8* [[TASKV]], i32 1, i64* [[DOWN]], i64* [[UP]], i64 [[ST_VAL]], i32 1, i32 1, i64 4, i8* null)
+// CHECK: [[GRAINSIZE:%.+]] = zext i32 %{{.+}} to i64
+// CHECK: call void @__kmpc_taskloop(%struct.ident_t* [[DEFLOC]], i32 [[GTID]], i8* [[TASKV]], i32 1, i64* [[DOWN]], i64* [[UP]], i64 [[ST_VAL]], i32 1, i32 1, i64 [[GRAINSIZE]], i8* null)
// CHECK-NEXT: call {{.*}}void @__kmpc_end_master(%struct.ident_t* [[DEFLOC]], i32 [[GTID]])
// CHECK-NEXT: br label {{%?}}[[EXIT]]
// CHECK: [[EXIT]]
@@ -128,7 +129,7 @@ int main(int argc, char **argv) {
// CHECK: br label %
// CHECK: ret i32 0
-#pragma omp parallel master taskloop nogroup grainsize(4)
+#pragma omp parallel master taskloop nogroup grainsize(argc)
for (int i = 0; i < 10; ++i)
;
// CHECK: define internal void [[OMP_OUTLINED3]](i32* noalias %{{.+}}, i32* noalias %{{.+}}, i32* dereferenceable(4) %{{.+}}, i8*** dereferenceable(8) %{{.+}}, i64 %{{.+}})