diff --git a/src/dsql/StmtNodes.cpp b/src/dsql/StmtNodes.cpp index 5b99cb1560..4eab1c3675 100644 --- a/src/dsql/StmtNodes.cpp +++ b/src/dsql/StmtNodes.cpp @@ -5329,7 +5329,8 @@ StmtNode* MergeNode::dsqlPass(DsqlCompilerScratch* dsqlScratch) store->dsqlFields = notMatched->fields; store->dsqlValues = notMatched->values; - thisIf->trueAction = store = store->internalDsqlPass(dsqlScratch, false)->as(); + bool needSavePoint; // unused + thisIf->trueAction = store = store->internalDsqlPass(dsqlScratch, false, needSavePoint)->as(); fb_assert(store); if (notMatched->condition) @@ -6436,7 +6437,7 @@ DmlNode* StoreNode::parse(thread_db* tdbb, MemoryPool& pool, CompilerScratch* cs return node; } -StmtNode* StoreNode::internalDsqlPass(DsqlCompilerScratch* dsqlScratch, bool updateOrInsert) +StmtNode* StoreNode::internalDsqlPass(DsqlCompilerScratch* dsqlScratch, bool updateOrInsert, bool& needSavePoint) { thread_db* tdbb = JRD_get_thread_data(); // necessary? DsqlContextStack::AutoRestore autoContext(*dsqlScratch->context); @@ -6460,9 +6461,13 @@ StmtNode* StoreNode::internalDsqlPass(DsqlCompilerScratch* dsqlScratch, bool upd RseNode* rse = PASS1_rse(dsqlScratch, selExpr, false); node->dsqlRse = rse; values = rse->dsqlSelectList; + needSavePoint = false; } else + { values = doDsqlPass(dsqlScratch, dsqlValues, false); + needSavePoint = SubSelectFinder::find(values); + } // Process relation @@ -6598,7 +6603,14 @@ StmtNode* StoreNode::internalDsqlPass(DsqlCompilerScratch* dsqlScratch, bool upd StmtNode* StoreNode::dsqlPass(DsqlCompilerScratch* dsqlScratch) { - return SavepointEncloseNode::make(getPool(), dsqlScratch, internalDsqlPass(dsqlScratch, false)); + bool needSavePoint; + StmtNode* node = SavepointEncloseNode::make(getPool(), dsqlScratch, + internalDsqlPass(dsqlScratch, false, needSavePoint)); + + if (!needSavePoint || node->is()) + return node; + + return FB_NEW SavepointEncloseNode(getPool(), node); } string StoreNode::internalPrint(NodePrinter& printer) const @@ -7979,13 +7991,15 @@ StmtNode* UpdateOrInsertNode::dsqlPass(DsqlCompilerScratch* dsqlScratch) const MetaName& relation_name = relation->as()->dsqlName; MetaName base_name = relation_name; + bool needSavePoint; + // Build the INSERT node. StoreNode* insert = FB_NEW_POOL(pool) StoreNode(pool); insert->dsqlRelation = relation; insert->dsqlFields = fields; insert->dsqlValues = values; insert->dsqlReturning = returning; - insert = insert->internalDsqlPass(dsqlScratch, true)->as(); + insert = insert->internalDsqlPass(dsqlScratch, true, needSavePoint)->as(); fb_assert(insert); dsql_ctx* context = insert->dsqlRelation->dsqlContext; @@ -8165,7 +8179,11 @@ StmtNode* UpdateOrInsertNode::dsqlPass(DsqlCompilerScratch* dsqlScratch) if (!returning) dsqlScratch->getStatement()->setType(DsqlCompiledStatement::TYPE_INSERT); - return SavepointEncloseNode::make(getPool(), dsqlScratch, list); + StmtNode* ret = SavepointEncloseNode::make(getPool(), dsqlScratch, list); + if (!needSavePoint || ret->is()) + return ret; + + return FB_NEW SavepointEncloseNode(getPool(), ret); } string UpdateOrInsertNode::internalPrint(NodePrinter& printer) const diff --git a/src/dsql/StmtNodes.h b/src/dsql/StmtNodes.h index a0e608d5e3..4d5da2bf4a 100644 --- a/src/dsql/StmtNodes.h +++ b/src/dsql/StmtNodes.h @@ -1244,7 +1244,7 @@ public: static DmlNode* parse(thread_db* tdbb, MemoryPool& pool, CompilerScratch* csb, const UCHAR blrOp); virtual Firebird::string internalPrint(NodePrinter& printer) const; - StmtNode* internalDsqlPass(DsqlCompilerScratch* dsqlScratch, bool updateOrInsert); + StmtNode* internalDsqlPass(DsqlCompilerScratch* dsqlScratch, bool updateOrInsert, bool& needSavePoint); virtual StmtNode* dsqlPass(DsqlCompilerScratch* dsqlScratch); virtual void genBlr(DsqlCompilerScratch* dsqlScratch); virtual StoreNode* pass1(thread_db* tdbb, CompilerScratch* csb);