From 52a4c39f41e6f35c5660d196edea6960862971c0 Mon Sep 17 00:00:00 2001 From: asfernandes Date: Fri, 27 Aug 2010 02:18:00 +0000 Subject: [PATCH] 1) Separate DsqlCompilerScratch in its own files. 2) Move BlockNode functionality to it. 3) Move some related CTE functions to it. --- builds/posix/make.shared.variables | 3 +- src/dsql/DSqlDataTypeUtil.cpp | 2 +- src/dsql/DdlNodes.epp | 67 +-- src/dsql/DdlNodes.h | 9 +- src/dsql/DsqlCompilerScratch.cpp | 761 +++++++++++++++++++++++++++++ src/dsql/DsqlCompilerScratch.h | 281 +++++++++++ src/dsql/Nodes.h | 35 +- src/dsql/StmtNodes.cpp | 319 +----------- src/dsql/StmtNodes.h | 6 +- src/dsql/dsql.h | 244 +-------- src/dsql/gen.cpp | 2 +- src/dsql/misc_func.cpp | 2 +- src/dsql/pass1.cpp | 526 +------------------- src/dsql/pass1_proto.h | 1 + 14 files changed, 1109 insertions(+), 1149 deletions(-) create mode 100644 src/dsql/DsqlCompilerScratch.cpp create mode 100644 src/dsql/DsqlCompilerScratch.h diff --git a/builds/posix/make.shared.variables b/builds/posix/make.shared.variables index 4a5ad3260f..17b44c3649 100644 --- a/builds/posix/make.shared.variables +++ b/builds/posix/make.shared.variables @@ -81,7 +81,8 @@ DSQL_ClientFiles = array.epp blob.epp \ DSQL_ServerFiles= metd.epp DSqlDataTypeUtil.cpp \ ddl.cpp dsql.cpp errd.cpp gen.cpp hsh.cpp make.cpp \ movd.cpp parse.cpp Parser.cpp pass1.cpp misc_func.cpp \ - DdlNodes.epp PackageNodes.epp AggNodes.cpp BlrWriter.cpp ExprNodes.cpp StmtNodes.cpp WinNodes.cpp + DdlNodes.epp PackageNodes.epp AggNodes.cpp BlrWriter.cpp DsqlCompilerScratch.cpp \ + ExprNodes.cpp StmtNodes.cpp WinNodes.cpp DSQL_Files = $(DSQL_ClientFiles) $(DSQL_ServerFiles) diff --git a/src/dsql/DSqlDataTypeUtil.cpp b/src/dsql/DSqlDataTypeUtil.cpp index 6be0591f6e..53c3311216 100644 --- a/src/dsql/DSqlDataTypeUtil.cpp +++ b/src/dsql/DSqlDataTypeUtil.cpp @@ -24,7 +24,7 @@ #include "firebird.h" #include "../dsql/DSqlDataTypeUtil.h" -#include "../dsql/dsql.h" +#include "../dsql/DsqlCompilerScratch.h" #include "../dsql/metd_proto.h" UCHAR Jrd::DSqlDataTypeUtil::maxBytesPerChar(UCHAR charSet) diff --git a/src/dsql/DdlNodes.epp b/src/dsql/DdlNodes.epp index 44cc0ae15c..4d2ea415fa 100644 --- a/src/dsql/DdlNodes.epp +++ b/src/dsql/DdlNodes.epp @@ -1157,8 +1157,6 @@ void CreateAlterFunctionNode::print(string& text, Array& /*nodes*/) c DdlNode* CreateAlterFunctionNode::internalDsqlPass() { - DsqlCompiledStatement* const statement = dsqlScratch->getStatement(); - statement->setBlockNode(this); dsqlScratch->flags |= (DsqlCompilerScratch::FLAG_BLOCK | DsqlCompilerScratch::FLAG_FUNCTION); const dsql_nod* variables = localDeclList; @@ -1488,7 +1486,6 @@ void CreateAlterFunctionNode::storeArgument(thread_db* tdbb, jrd_tra* transactio unsigned pos, const ParameterClause& parameter) { Attachment* attachment = transaction->getAttachment(); - DsqlCompiledStatement* statement = dsqlScratch->getStatement(); AutoCacheRequest requestHandle(tdbb, drq_s_func_args2, DYN_REQUESTS); @@ -1591,7 +1588,7 @@ void CreateAlterFunctionNode::storeArgument(thread_db* tdbb, jrd_tra* transactio dsqlScratch->getBlrData().clear(); - if (statement->getFlags() & DsqlCompiledStatement::FLAG_BLR_VERSION4) + if (dsqlScratch->isVersion4()) dsqlScratch->appendUChar(blr_version4); else dsqlScratch->appendUChar(blr_version5); @@ -1618,14 +1615,12 @@ void CreateAlterFunctionNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/) compiled = true; invalid = true; - DsqlCompiledStatement* statement = dsqlScratch->getStatement(); - if (body) { dsqlScratch->beginDebug(); dsqlScratch->getBlrData().clear(); - if (statement->getFlags() & DsqlCompiledStatement::FLAG_BLR_VERSION4) + if (dsqlScratch->isVersion4()) dsqlScratch->appendUChar(blr_version4); else dsqlScratch->appendUChar(blr_version5); @@ -1650,7 +1645,7 @@ void CreateAlterFunctionNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/) dsqlScratch->appendUChar(blr_short); dsqlScratch->appendUChar(0); - variables.add(MAKE_variable(parameter.legacyField, + dsqlScratch->variables.add(MAKE_variable(parameter.legacyField, parameter.name.c_str(), VAR_input, 0, (USHORT) (2 * i), 0)); } } @@ -1667,8 +1662,8 @@ void CreateAlterFunctionNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/) dsqlScratch->appendUChar(0); dsql_nod* const var = MAKE_variable(returnType.legacyField, "", VAR_output, 1, 0, 0); - variables.add(var); - outputVariables.add(var); + dsqlScratch->variables.add(var); + dsqlScratch->outputVariables.add(var); if (parameters.getCount() != 0) { @@ -1695,14 +1690,15 @@ void CreateAlterFunctionNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/) } } - dsql_var* const variable = (dsql_var*) outputVariables[0]->nod_arg[Dsql::e_var_variable]; - putLocalVariable(dsqlScratch, variable, 0, NULL); + dsql_var* const variable = + (dsql_var*) dsqlScratch->outputVariables[0]->nod_arg[Dsql::e_var_variable]; + dsqlScratch->putLocalVariable(variable, 0, NULL); // ASF: This is here to not change the old logic (proc_flag) // of previous calls to PASS1_node and PASS1_statement. dsqlScratch->setPsql(true); - putLocalVariables(dsqlScratch, localDeclList, 1); + dsqlScratch->putLocalVariables(localDeclList, 1); dsqlScratch->appendUChar(blr_stall); // put a label before body of procedure, @@ -1714,10 +1710,9 @@ void CreateAlterFunctionNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/) GEN_statement(dsqlScratch, PASS1_statement(dsqlScratch, body)); - statement->setType(DsqlCompiledStatement::TYPE_DDL); + dsqlScratch->getStatement()->setType(DsqlCompiledStatement::TYPE_DDL); dsqlScratch->appendUChar(blr_end); - genReturn(dsqlScratch, false); - + dsqlScratch->genReturn(false); dsqlScratch->appendUChar(blr_end); dsqlScratch->appendUChar(blr_eoc); @@ -1902,8 +1897,6 @@ void CreateAlterProcedureNode::print(string& text, Array& /*nodes*/) DdlNode* CreateAlterProcedureNode::internalDsqlPass() { - DsqlCompiledStatement* statement = dsqlScratch->getStatement(); - statement->setBlockNode(this); dsqlScratch->flags |= (DsqlCompilerScratch::FLAG_BLOCK | DsqlCompilerScratch::FLAG_PROCEDURE); const dsql_nod* variables = localDeclList; @@ -2340,11 +2333,9 @@ void CreateAlterProcedureNode::storeParameter(thread_db* tdbb, jrd_tra* transact string defaultSource = string(defaultString->str_data, defaultString->str_length); attachment->storeMetaDataBlob(tdbb, transaction, &PRM.RDB$DEFAULT_SOURCE, defaultSource); - DsqlCompiledStatement* statement = dsqlScratch->getStatement(); - dsqlScratch->getBlrData().clear(); - if (statement->getFlags() & DsqlCompiledStatement::FLAG_BLR_VERSION4) + if (dsqlScratch->isVersion4()) dsqlScratch->appendUChar(blr_version4); else dsqlScratch->appendUChar(blr_version5); @@ -2375,12 +2366,10 @@ void CreateAlterProcedureNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/ invalid = true; - DsqlCompiledStatement* statement = dsqlScratch->getStatement(); - dsqlScratch->beginDebug(); dsqlScratch->getBlrData().clear(); - if (statement->getFlags() & DsqlCompiledStatement::FLAG_BLR_VERSION4) + if (dsqlScratch->isVersion4()) dsqlScratch->appendUChar(blr_version4); else dsqlScratch->appendUChar(blr_version5); @@ -2404,7 +2393,7 @@ void CreateAlterProcedureNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/ dsqlScratch->appendUChar(blr_short); dsqlScratch->appendUChar(0); - variables.add(MAKE_variable(parameter.legacyField, + dsqlScratch->variables.add(MAKE_variable(parameter.legacyField, parameter.name.c_str(), VAR_input, 0, (USHORT) (2 * i), 0)); } } @@ -2429,8 +2418,8 @@ void CreateAlterProcedureNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/ dsql_nod* const var = MAKE_variable(parameter.legacyField, parameter.name.c_str(), VAR_output, 1, (USHORT) (2 * i), i); - variables.add(var); - outputVariables.add(var); + dsqlScratch->variables.add(var); + dsqlScratch->outputVariables.add(var); } } @@ -2463,18 +2452,20 @@ void CreateAlterProcedureNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/ } } - for (Array::const_iterator i = outputVariables.begin(); i != outputVariables.end(); ++i) + for (Array::const_iterator i = dsqlScratch->outputVariables.begin(); + i != dsqlScratch->outputVariables.end(); + ++i) { dsql_nod* parameter = *i; dsql_var* const variable = (dsql_var*) parameter->nod_arg[Dsql::e_var_variable]; - putLocalVariable(dsqlScratch, variable, 0, NULL); + dsqlScratch->putLocalVariable(variable, 0, NULL); } // ASF: This is here to not change the old logic (proc_flag) // of previous calls to PASS1_node and PASS1_statement. dsqlScratch->setPsql(true); - putLocalVariables(dsqlScratch, localDeclList, returns.getCount()); + dsqlScratch->putLocalVariables(localDeclList, returns.getCount()); dsqlScratch->appendUChar(blr_stall); // put a label before body of procedure, @@ -2486,10 +2477,9 @@ void CreateAlterProcedureNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/ GEN_statement(dsqlScratch, PASS1_statement(dsqlScratch, body)); - statement->setType(DsqlCompiledStatement::TYPE_DDL); + dsqlScratch->getStatement()->setType(DsqlCompiledStatement::TYPE_DDL); dsqlScratch->appendUChar(blr_end); - genReturn(dsqlScratch, true); - + dsqlScratch->genReturn(true); dsqlScratch->appendUChar(blr_end); dsqlScratch->appendUChar(blr_eoc); @@ -2779,8 +2769,6 @@ void CreateAlterTriggerNode::print(string& text, Array& /*nodes*/) co DdlNode* CreateAlterTriggerNode::internalDsqlPass() { - DsqlCompiledStatement* statement = dsqlScratch->getStatement(); - statement->setBlockNode(this); dsqlScratch->flags |= (DsqlCompilerScratch::FLAG_BLOCK | DsqlCompilerScratch::FLAG_TRIGGER); if (type.specified) @@ -2876,8 +2864,6 @@ void CreateAlterTriggerNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/) if (body) { - DsqlCompiledStatement* statement = dsqlScratch->getStatement(); - dsqlScratch->beginDebug(); dsqlScratch->getBlrData().clear(); @@ -2921,7 +2907,7 @@ void CreateAlterTriggerNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/) // generate the trigger blr - if (statement->getFlags() & DsqlCompiledStatement::FLAG_BLR_VERSION4) + if (dsqlScratch->isVersion4()) dsqlScratch->appendUChar(blr_version4); else dsqlScratch->appendUChar(blr_version5); @@ -2929,8 +2915,7 @@ void CreateAlterTriggerNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/) dsqlScratch->appendUChar(blr_begin); dsqlScratch->setPsql(true); - - putLocalVariables(dsqlScratch, localDeclList, 0); + dsqlScratch->putLocalVariables(localDeclList, 0); dsqlScratch->scopeLevel++; // dimitr: I see no reason to deny EXIT command in triggers, @@ -2954,7 +2939,7 @@ void CreateAlterTriggerNode::compile(thread_db* tdbb, jrd_tra* /*transaction*/) // The statement type may have been set incorrectly when parsing // the trigger actions, so reset it to reflect the fact that this // is a data definition statement; also reset the ddl node. - statement->setType(DsqlCompiledStatement::TYPE_DDL); + dsqlScratch->getStatement()->setType(DsqlCompiledStatement::TYPE_DDL); } invalid = false; diff --git a/src/dsql/DdlNodes.h b/src/dsql/DdlNodes.h index a15fab4e84..1153db0733 100644 --- a/src/dsql/DdlNodes.h +++ b/src/dsql/DdlNodes.h @@ -229,12 +229,11 @@ private: }; -class CreateAlterFunctionNode : public DdlNode, public BlockNode +class CreateAlterFunctionNode : public DdlNode { public: CreateAlterFunctionNode(MemoryPool& pool, const Firebird::MetaName& aName) : DdlNode(pool), - BlockNode(pool, false), name(pool, aName), create(true), alter(false), @@ -334,12 +333,11 @@ typedef RecreateNodefld_dtype > FB_NELEM(blr_dtypes) || !blr_dtypes[field->fld_dtype]) + { + SCHAR buffer[100]; + + sprintf(buffer, "Invalid dtype %d in BlockNode::putDtype", field->fld_dtype); + ERRD_bugcheck(buffer); + } +#endif + + if (field->fld_not_nullable) + appendUChar(blr_not_nullable); + + if (field->fld_type_of_name.hasData()) + { + if (field->fld_type_of_table) + { + if (field->fld_explicit_collation) + { + appendUChar(blr_column_name2); + appendUChar(field->fld_full_domain ? blr_domain_full : blr_domain_type_of); + appendMetaString(field->fld_type_of_table->str_data); + appendMetaString(field->fld_type_of_name.c_str()); + appendUShort(field->fld_ttype); + } + else + { + appendUChar(blr_column_name); + appendUChar(field->fld_full_domain ? blr_domain_full : blr_domain_type_of); + appendMetaString(field->fld_type_of_table->str_data); + appendMetaString(field->fld_type_of_name.c_str()); + } + } + else + { + if (field->fld_explicit_collation) + { + appendUChar(blr_domain_name2); + appendUChar(field->fld_full_domain ? blr_domain_full : blr_domain_type_of); + appendMetaString(field->fld_type_of_name.c_str()); + appendUShort(field->fld_ttype); + } + else + { + appendUChar(blr_domain_name); + appendUChar(field->fld_full_domain ? blr_domain_full : blr_domain_type_of); + appendMetaString(field->fld_type_of_name.c_str()); + } + } + + return; + } + + switch (field->fld_dtype) + { + case dtype_cstring: + case dtype_text: + case dtype_varying: + case dtype_blob: + if (!useSubType) + appendUChar(blr_dtypes[field->fld_dtype]); + else if (field->fld_dtype == dtype_varying) + { + appendUChar(blr_varying2); + appendUShort(field->fld_ttype); + } + else if (field->fld_dtype == dtype_cstring) + { + appendUChar(blr_cstring2); + appendUShort(field->fld_ttype); + } + else if (field->fld_dtype == dtype_blob) + { + appendUChar(blr_blob2); + appendUShort(field->fld_sub_type); + appendUShort(field->fld_ttype); + } + else + { + appendUChar(blr_text2); + appendUShort(field->fld_ttype); + } + + if (field->fld_dtype == dtype_varying) + appendUShort(field->fld_length - sizeof(USHORT)); + else if (field->fld_dtype != dtype_blob) + appendUShort(field->fld_length); + break; + + default: + appendUChar(blr_dtypes[field->fld_dtype]); + if (DTYPE_IS_EXACT(field->fld_dtype) || (dtype_quad == field->fld_dtype)) + appendUChar(field->fld_scale); + break; + } +} + +// Emit dyn for the local variables declared in a procedure or trigger. +void DsqlCompilerScratch::putLocalVariables(const dsql_nod* parameters, SSHORT locals) +{ + if (!parameters) + return; + + dsql_nod* const* ptr = parameters->nod_arg; + + for (const dsql_nod* const* const end = ptr + parameters->nod_count; ptr < end; ptr++) + { + dsql_nod* parameter = *ptr; + + putDebugSrcInfo(parameter->nod_line, parameter->nod_column); + + if (parameter->nod_type == Dsql::nod_def_field) + { + dsql_fld* field = (dsql_fld*) parameter->nod_arg[Dsql::e_dfl_field]; + const dsql_nod* const* rest = ptr; + + while (++rest != end) + { + if ((*rest)->nod_type == Dsql::nod_def_field) + { + const dsql_fld* rest_field = (dsql_fld*) (*rest)->nod_arg[Dsql::e_dfl_field]; + if (field->fld_name == rest_field->fld_name) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-637) << + Arg::Gds(isc_dsql_duplicate_spec) << Arg::Str(field->fld_name)); + } + } + } + + dsql_nod* varNode = MAKE_variable(field, field->fld_name.c_str(), VAR_local, 0, 0, locals); + variables.add(varNode); + + dsql_var* variable = (dsql_var*) varNode->nod_arg[Dsql::e_var_variable]; + putLocalVariable(variable, parameter, + reinterpret_cast(parameter->nod_arg[Dsql::e_dfl_collate])); + + // Some field attributes are calculated inside + // putLocalVariable(), so we reinitialize the + // descriptor + MAKE_desc_from_field(&varNode->nod_desc, field); + + ++locals; + } + else if (parameter->nod_type == Dsql::nod_cursor) + { + PASS1_statement(this, parameter); + GEN_statement(this, parameter); + } + } +} + +// Write out local variable field data type. +void DsqlCompilerScratch::putLocalVariable(dsql_var* variable, dsql_nod* hostParam, + const dsql_str* collationName) +{ + dsql_fld* field = variable->var_field; + + appendUChar(blr_dcl_variable); + appendUShort(variable->var_variable_number); + DDL_resolve_intl_type(this, field, collationName); + + //const USHORT dtype = field->fld_dtype; + + putDtype(field, true); + //field->fld_dtype = dtype; + + // Check for a default value, borrowed from define_domain + dsql_nod* node = hostParam ? hostParam->nod_arg[Dsql::e_dfl_default] : NULL; + + if (node || (!field->fld_full_domain && !field->fld_not_nullable)) + { + appendUChar(blr_assignment); + + if (node) + { + fb_assert(node->nod_type == Dsql::nod_def_default); + PsqlChanger psqlChanger(this, false); + node = PASS1_node(this, node->nod_arg[Dsql::e_dft_default]); + GEN_expr(this, node); + } + else + appendUChar(blr_null); // Initialize variable to NULL + + appendUChar(blr_variable); + appendUShort(variable->var_variable_number); + } + else + { + appendUChar(blr_init_variable); + appendUShort(variable->var_variable_number); + } + + if (variable->var_name[0]) // Not a function return value + putDebugVariable(variable->var_variable_number, variable->var_name); + + ++hiddenVarsNumber; +} + +// Try to resolve variable name against parameters and local variables. +dsql_nod* DsqlCompilerScratch::resolveVariable(const dsql_str* varName) +{ + for (dsql_nod* const* i = variables.begin(); i != variables.end(); ++i) + { + dsql_nod* varNode = *i; + fb_assert(varNode->nod_type == Dsql::nod_variable); + + if (varNode->nod_type == Dsql::nod_variable) + { + const dsql_var* variable = (dsql_var*) varNode->nod_arg[Dsql::e_var_variable]; + DEV_BLKCHK(variable, dsql_type_var); + + if (!strcmp(varName->str_data, variable->var_name)) + return varNode; + } + } + + return NULL; +} + +// Generate BLR for a return. +void DsqlCompilerScratch::genReturn(bool eosFlag) +{ + const bool hasEos = !(flags & (FLAG_TRIGGER | FLAG_FUNCTION)); + + if (hasEos && !eosFlag) + appendUChar(blr_begin); + + appendUChar(blr_send); + appendUChar(1); + appendUChar(blr_begin); + + for (Array::const_iterator i = outputVariables.begin(); i != outputVariables.end(); ++i) + { + const dsql_nod* parameter = *i; + const dsql_var* variable = (dsql_var*) parameter->nod_arg[Dsql::e_var_variable]; + appendUChar(blr_assignment); + appendUChar(blr_variable); + appendUShort(variable->var_variable_number); + appendUChar(blr_parameter2); + appendUChar(variable->var_msg_number); + appendUShort(variable->var_msg_item); + appendUShort(variable->var_msg_item + 1); + } + + if (hasEos) + { + appendUChar(blr_assignment); + appendUChar(blr_literal); + appendUChar(blr_short); + appendUChar(0); + appendUShort((eosFlag ? 0 : 1)); + appendUChar(blr_parameter); + appendUChar(1); + appendUShort(USHORT(2 * outputVariables.getCount())); + } + + appendUChar(blr_end); + + if (hasEos && !eosFlag) + { + appendUChar(blr_stall); + appendUChar(blr_end); + } +} + +void DsqlCompilerScratch::addCTEs(dsql_nod* with) +{ + DEV_BLKCHK(with, dsql_type_nod); + fb_assert(with->nod_type == Dsql::nod_with); + + if (ctes.getCount()) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // WITH clause can't be nested + Arg::Gds(isc_dsql_cte_nested_with)); + } + + if (with->nod_flags & NOD_UNION_RECURSIVE) + flags |= DsqlCompilerScratch::FLAG_RECURSIVE_CTE; + + const dsql_nod* list = with->nod_arg[0]; + const dsql_nod* const* end = list->nod_arg + list->nod_count; + + for (dsql_nod* const* cte = list->nod_arg; cte < end; cte++) + { + fb_assert((*cte)->nod_type == Dsql::nod_derived_table); + + if (with->nod_flags & NOD_UNION_RECURSIVE) + { + currCtes.push(*cte); + PsqlChanger changer(this, false); + ctes.add(pass1RecursiveCte(*cte)); + currCtes.pop(); + + // Add CTE name into CTE aliases stack. It allows later to search for + // aliases of given CTE. + const dsql_str* cteName = (dsql_str*) (*cte)->nod_arg[Dsql::e_derived_table_alias]; + addCTEAlias(cteName); + } + else + ctes.add(*cte); + } +} + +dsql_nod* DsqlCompilerScratch::findCTE(const dsql_str* name) +{ + for (size_t i = 0; i < ctes.getCount(); ++i) + { + dsql_nod* cte = ctes[i]; + const dsql_str* cteName = (dsql_str*) cte->nod_arg[Dsql::e_derived_table_alias]; + + if (name->str_length == cteName->str_length && + strncmp(name->str_data, cteName->str_data, cteName->str_length) == 0) + { + return cte; + } + } + + return NULL; +} + +void DsqlCompilerScratch::clearCTEs() +{ + flags &= ~DsqlCompilerScratch::FLAG_RECURSIVE_CTE; + ctes.clear(); + cteAliases.clear(); +} + +void DsqlCompilerScratch::checkUnusedCTEs() const +{ + for (size_t i = 0; i < ctes.getCount(); ++i) + { + const dsql_nod* cte = ctes[i]; + + if (!(cte->nod_flags & NOD_DT_CTE_USED)) + { + const dsql_str* cteName = (dsql_str*) cte->nod_arg[Dsql::e_derived_table_alias]; + + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + Arg::Gds(isc_dsql_cte_not_used) << Arg::Str(cteName->str_data)); + } + } +} + +// Process derived table which can be recursive CTE. +// If it is non-recursive return input node unchanged. +// If it is recursive return new derived table which is an union of union of anchor (non-recursive) +// queries and union of recursive queries. Check recursive queries to satisfy various criterias. +// Note that our parser is right-to-left therefore nested list linked as first node in parent list +// and second node is always query spec. +// For example, if we have 4 CTE's where first two is non-recursive and last two is recursive: +// +// list union +// [0] [1] [0] [1] +// list cte3 ===> anchor recursive +// [0] [1] [0] [1] [0] [1] +// list cte3 cte1 cte2 cte3 cte4 +// [0] [1] +// cte1 cte2 +// +// Also, we should not change layout of original parse tree to allow it to be parsed again if +// needed. Therefore recursive part is built using newly allocated list nodes. +dsql_nod* DsqlCompilerScratch::pass1RecursiveCte(dsql_nod* input) +{ + dsql_str* const cte_alias = (dsql_str*) input->nod_arg[Dsql::e_derived_table_alias]; + dsql_nod* const select_expr = input->nod_arg[Dsql::e_derived_table_rse]; + dsql_nod* query = select_expr->nod_arg[Dsql::e_sel_query_spec]; + + if (query->nod_type != Dsql::nod_list && pass1RseIsRecursive(query)) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // Recursive CTE (%s) must be an UNION + Arg::Gds(isc_dsql_cte_not_a_union) << Arg::Str(cte_alias->str_data)); + } + + // split queries list on two parts: anchor and recursive + dsql_nod* anchorRse = NULL, *recursiveRse = NULL; + dsql_nod* qry = query; + + dsql_nod* newQry = MAKE_node(Dsql::nod_list, 2); + newQry->nod_flags = query->nod_flags; + + while (true) + { + dsql_nod* rse = NULL; + + if (qry->nod_type == Dsql::nod_list) + rse = qry->nod_arg[1]; + else + rse = qry; + + dsql_nod* newRse = pass1RseIsRecursive(rse); + + if (newRse) // rse is recursive + { + if (anchorRse) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // CTE '%s' defined non-recursive member after recursive + Arg::Gds(isc_dsql_cte_nonrecurs_after_recurs) << Arg::Str(cte_alias->str_data)); + } + + if (newRse->nod_arg[Dsql::e_qry_distinct]) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // Recursive member of CTE '%s' has %s clause + Arg::Gds(isc_dsql_cte_wrong_clause) << Arg::Str(cte_alias->str_data) << + Arg::Str("DISTINCT")); + } + + if (newRse->nod_arg[Dsql::e_qry_group]) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // Recursive member of CTE '%s' has %s clause + Arg::Gds(isc_dsql_cte_wrong_clause) << Arg::Str(cte_alias->str_data) << + Arg::Str("GROUP BY")); + } + + if (newRse->nod_arg[Dsql::e_qry_having]) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // Recursive member of CTE '%s' has %s clause + Arg::Gds(isc_dsql_cte_wrong_clause) << Arg::Str(cte_alias->str_data) << + Arg::Str("HAVING")); + } + // hvlad: we need also forbid any aggregate function here + // but for now i have no idea how to do it simple + + if ((newQry->nod_type == Dsql::nod_list) && !(newQry->nod_flags & NOD_UNION_ALL)) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // Recursive members of CTE (%s) must be linked with another members via UNION ALL + Arg::Gds(isc_dsql_cte_union_all) << Arg::Str(cte_alias->str_data)); + } + + if (!recursiveRse) + recursiveRse = newQry; + + newRse->nod_flags |= NOD_SELECT_EXPR_RECURSIVE; + + if (qry->nod_type == Dsql::nod_list) + newQry->nod_arg[1] = newRse; + else + newQry->nod_arg[0] = newRse; + } + else + { + if (qry->nod_type == Dsql::nod_list) + newQry->nod_arg[1] = rse; + else + newQry->nod_arg[0] = rse; + + if (!anchorRse) + { + if (qry->nod_type == Dsql::nod_list) + anchorRse = newQry; + else + anchorRse = rse; + } + } + + if (qry->nod_type != Dsql::nod_list) + break; + + qry = qry->nod_arg[0]; + + if (qry->nod_type == Dsql::nod_list) + { + newQry->nod_arg[0] = MAKE_node(Dsql::nod_list, 2); + newQry = newQry->nod_arg[0]; + newQry->nod_flags = qry->nod_flags; + } + } + + if (!recursiveRse) + return input; + + if (!anchorRse) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // Non-recursive member is missing in CTE '%s' + Arg::Gds(isc_dsql_cte_miss_nonrecursive) << Arg::Str(cte_alias->str_data)); + } + + qry = recursiveRse; + dsql_nod* list = NULL; + + while (qry->nod_arg[0] != anchorRse) + { + list = qry; + qry = qry->nod_arg[0]; + } + + qry->nod_arg[0] = 0; + + if (list) + list->nod_arg[0] = qry->nod_arg[1]; + else + recursiveRse = qry->nod_arg[1]; + + dsql_nod* unionNode = MAKE_node(Dsql::nod_list, 2); + unionNode->nod_flags = NOD_UNION_ALL | NOD_UNION_RECURSIVE; + unionNode->nod_arg[0] = anchorRse; + unionNode->nod_arg[1] = recursiveRse; + + dsql_nod* select = MAKE_node(Dsql::nod_select_expr, Dsql::e_sel_count); + select->nod_arg[Dsql::e_sel_query_spec] = unionNode; + select->nod_arg[Dsql::e_sel_order] = select->nod_arg[Dsql::e_sel_rows] = + select->nod_arg[Dsql::e_sel_with_list] = NULL; + + dsql_nod* node = MAKE_node(Dsql::nod_derived_table, Dsql::e_derived_table_count); + dsql_str* alias = (dsql_str*) input->nod_arg[Dsql::e_derived_table_alias]; + node->nod_arg[Dsql::e_derived_table_alias] = (dsql_nod*) alias; + node->nod_arg[Dsql::e_derived_table_column_alias] = + input->nod_arg[Dsql::e_derived_table_column_alias]; + node->nod_arg[Dsql::e_derived_table_rse] = select; + node->nod_arg[Dsql::e_derived_table_context] = input->nod_arg[Dsql::e_derived_table_context]; + + return node; +} + +// Check if rse is recursive. If recursive reference is a table in the FROM list remove it. +// If recursive reference is a part of join add join boolean (returned by pass1JoinIsRecursive) +// to the WHERE clause. Punt if more than one recursive reference is found. +dsql_nod* DsqlCompilerScratch::pass1RseIsRecursive(dsql_nod* input) +{ + fb_assert(input->nod_type == Dsql::nod_query_spec); + + dsql_nod* result = MAKE_node(Dsql::nod_query_spec, Dsql::e_qry_count); + memcpy(result->nod_arg, input->nod_arg, Dsql::e_qry_count * sizeof(dsql_nod*)); + + dsql_nod* srcTables = input->nod_arg[Dsql::e_qry_from]; + dsql_nod* dstTables = MAKE_node(Dsql::nod_list, srcTables->nod_count); + result->nod_arg[Dsql::e_qry_from] = dstTables; + + dsql_nod** pDstTable = dstTables->nod_arg; + dsql_nod** pSrcTable = srcTables->nod_arg; + dsql_nod** end = srcTables->nod_arg + srcTables->nod_count; + bool found = false; + + for (dsql_nod** prev = pDstTable; pSrcTable < end; ++pSrcTable, ++pDstTable) + { + *prev++ = *pDstTable = *pSrcTable; + + switch ((*pDstTable)->nod_type) + { + case Dsql::nod_rel_proc_name: + case Dsql::nod_relation_name: + if (pass1RelProcIsRecursive(*pDstTable)) + { + if (found) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // Recursive member of CTE can't reference itself more than once + Arg::Gds(isc_dsql_cte_mult_references)); + } + found = true; + + prev--; + dstTables->nod_count--; + } + break; + + case Dsql::nod_join: + { + *pDstTable = MAKE_node(Dsql::nod_join, Dsql::e_join_count); + memcpy((*pDstTable)->nod_arg, (*pSrcTable)->nod_arg, + Dsql::e_join_count * sizeof(dsql_nod*)); + + dsql_nod* joinBool = pass1JoinIsRecursive(*pDstTable); + if (joinBool) + { + if (found) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // Recursive member of CTE can't reference itself more than once + Arg::Gds(isc_dsql_cte_mult_references)); + } + found = true; + + result->nod_arg[Dsql::e_qry_where] = + PASS1_compose(result->nod_arg[Dsql::e_qry_where], joinBool, Dsql::nod_and); + } + + break; + } + + case Dsql::nod_derived_table: + break; + + default: + fb_assert(false); + } + } + + return found ? result : NULL; +} + +// Check if table reference is recursive i.e. its name is equal to the name of current processing CTE. +bool DsqlCompilerScratch::pass1RelProcIsRecursive(dsql_nod* input) +{ + const dsql_str* relName = NULL; + const dsql_str* relAlias = NULL; + + switch (input->nod_type) + { + case Dsql::nod_rel_proc_name: + relName = (dsql_str*) input->nod_arg[Dsql::e_rpn_name]; + relAlias = (dsql_str*) input->nod_arg[Dsql::e_rpn_alias]; + break; + + case Dsql::nod_relation_name: + relName = (dsql_str*) input->nod_arg[Dsql::e_rln_name]; + relAlias = (dsql_str*) input->nod_arg[Dsql::e_rln_alias]; + break; + + default: + return false; + } + + fb_assert(currCtes.hasData()); + const dsql_nod* curr_cte = currCtes.object(); + const dsql_str* cte_name = (dsql_str*) curr_cte->nod_arg[Dsql::e_derived_table_alias]; + + const bool recursive = (cte_name->str_length == relName->str_length) && + (strncmp(relName->str_data, cte_name->str_data, cte_name->str_length) == 0); + + if (recursive) + addCTEAlias(relAlias ? relAlias : relName); + + return recursive; +} + +// Check if join have recursive members. If found remove this member from join and return its +// boolean (to be added into WHERE clause). +// We must remove member only if it is a table reference. Punt if recursive reference is found in +// outer join or more than one recursive reference is found +dsql_nod* DsqlCompilerScratch::pass1JoinIsRecursive(dsql_nod*& input) +{ + const NOD_TYPE join_type = input->nod_arg[Dsql::e_join_type]->nod_type; + bool remove = false; + + bool leftRecursive = false; + dsql_nod* leftBool = NULL; + dsql_nod** join_table = &input->nod_arg[Dsql::e_join_left_rel]; + + if ((*join_table)->nod_type == Dsql::nod_join) + { + leftBool = pass1JoinIsRecursive(*join_table); + leftRecursive = (leftBool != NULL); + } + else + { + leftBool = input->nod_arg[Dsql::e_join_boolean]; + leftRecursive = pass1RelProcIsRecursive(*join_table); + + if (leftRecursive) + remove = true; + } + + if (leftRecursive && join_type != Dsql::nod_join_inner) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // Recursive member of CTE can't be member of an outer join + Arg::Gds(isc_dsql_cte_outer_join)); + } + + bool rightRecursive = false; + dsql_nod* rightBool = NULL; + + join_table = &input->nod_arg[Dsql::e_join_rght_rel]; + + if ((*join_table)->nod_type == Dsql::nod_join) + { + rightBool = pass1JoinIsRecursive(*join_table); + rightRecursive = (rightBool != NULL); + } + else + { + rightBool = input->nod_arg[Dsql::e_join_boolean]; + rightRecursive = pass1RelProcIsRecursive(*join_table); + + if (rightRecursive) + remove = true; + } + + if (rightRecursive && join_type != Dsql::nod_join_inner) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // Recursive member of CTE can't be member of an outer join + Arg::Gds(isc_dsql_cte_outer_join)); + } + + if (leftRecursive && rightRecursive) + { + ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << + // Recursive member of CTE can't reference itself more than once + Arg::Gds(isc_dsql_cte_mult_references)); + } + + if (leftRecursive) + { + if (remove) + input = input->nod_arg[Dsql::e_join_rght_rel]; + + return leftBool; + } + + if (rightRecursive) + { + if (remove) + input = input->nod_arg[Dsql::e_join_left_rel]; + + return rightBool; + } + + return NULL; +} diff --git a/src/dsql/DsqlCompilerScratch.h b/src/dsql/DsqlCompilerScratch.h new file mode 100644 index 0000000000..13540ccd09 --- /dev/null +++ b/src/dsql/DsqlCompilerScratch.h @@ -0,0 +1,281 @@ +/* + * + * The contents of this file are subject to the Interbase Public + * License Version 1.0 (the "License"); you may not use this file + * except in compliance with the License. You may obtain a copy + * of the License at http://www.Inprise.com/IPL.html + * + * Software distributed under the License is distributed on an + * "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, either express + * or implied. See the License for the specific language governing + * rights and limitations under the License. + * + * The Original Code was created by Inprise Corporation + * and its predecessors. Portions created by Inprise Corporation are + * Copyright (C) Inprise Corporation. + * + * All Rights Reserved. + * Contributor(s): ______________________________________. + * Adriano dos Santos Fernandes + */ + +#ifndef DSQL_COMPILER_SCRATCH_H +#define DSQL_COMPILER_SCRATCH_H + +#include "../jrd/common.h" +#include "../dsql/dsql.h" +#include "../dsql/BlrWriter.h" +#include "../common/classes/array.h" +#include "../common/classes/MetaName.h" +#include "../common/classes/stack.h" +#include "../common/classes/alloc.h" + +namespace Jrd +{ + +// DSQL Compiler scratch block - may be discarded after compilation in the future. +class DsqlCompilerScratch : public BlrWriter +{ +public: + static const unsigned FLAG_IN_AUTO_TRANS_BLOCK = 0x001; + static const unsigned FLAG_RETURNING_INTO = 0x002; + static const unsigned FLAG_METADATA_SAVED = 0x004; + static const unsigned FLAG_PROCEDURE = 0x008; + static const unsigned FLAG_TRIGGER = 0x010; + static const unsigned FLAG_BLOCK = 0x020; + static const unsigned FLAG_RECURSIVE_CTE = 0x040; + static const unsigned FLAG_UPDATE_OR_INSERT = 0x080; + static const unsigned FLAG_MERGE = 0x100; + static const unsigned FLAG_FUNCTION = 0x200; + +public: + DsqlCompilerScratch(MemoryPool& p, dsql_dbb* aDbb, jrd_tra* aTransaction, + DsqlCompiledStatement* aStatement) + : BlrWriter(p), + dbb(aDbb), + transaction(aTransaction), + statement(aStatement), + flags(0), + ports(p), + relation(NULL), + procedure(NULL), + mainContext(p), + context(&mainContext), + unionContext(p), + derivedContext(p), + outerAggContext(NULL), + contextNumber(0), + derivedContextNumber(0), + scopeLevel(0), + loopLevel(0), + labels(p), + cursorNumber(0), + cursors(p), + inSelectList(0), + inWhereClause(0), + inGroupByClause(0), + inHavingClause(0), + inOrderByClause(0), + errorHandlers(0), + clientDialect(0), + inOuterJoin(0), + aliasRelationPrefix(NULL), + hiddenVars(p), + hiddenVarsNumber(0), + package(p), + currCtes(p), + recursiveCtx(0), + recursiveCtxId(0), + processingWindow(false), + checkConstraintTrigger(false), + variables(p), + outputVariables(p), + ctes(p), + cteAliases(p), + currCteAlias(NULL), + psql(false) + { + domainValue.clear(); + } + +protected: + // DsqlCompilerScratch should never be destroyed using delete. + // It dies together with it's pool in release_request(). + ~DsqlCompilerScratch() + { + } + + virtual bool isDdlDyn() + { + return (statement->getType() == DsqlCompiledStatement::TYPE_DDL || statement->getDdlNode()) && + !(flags & FLAG_BLOCK); + } + +public: + virtual bool isVersion4() + { + return statement->getFlags() & DsqlCompiledStatement::FLAG_BLR_VERSION4; + } + + MemoryPool& getPool() + { + return PermanentStorage::getPool(); + } + + dsql_dbb* getAttachment() + { + return dbb; + } + + jrd_tra* getTransaction() + { + return transaction; + } + + void setTransaction(jrd_tra* value) + { + transaction = value; + } + + DsqlCompiledStatement* getStatement() + { + return statement; + } + + DsqlCompiledStatement* getStatement() const + { + return statement; + } + + void putDtype(const dsql_fld* field, bool useSubType); + void putLocalVariables(const dsql_nod* parameters, SSHORT locals); + void putLocalVariable(dsql_var* variable, dsql_nod* hostParam, const dsql_str* collationName); + dsql_nod* resolveVariable(const dsql_str* varName); + void genReturn(bool eosFlag = false); + + void addCTEs(dsql_nod* list); + dsql_nod* findCTE(const dsql_str* name); + void clearCTEs(); + void checkUnusedCTEs() const; + + // hvlad: each member of recursive CTE can refer to CTE itself (only once) via + // CTE name or via alias. We need to substitute this aliases when processing CTE + // member to resolve field names. Therefore we store all aliases in order of + // occurrence and later use it in backward order (since our parser is right-to-left). + // Also we put CTE name after all such aliases to distinguish aliases for + // different CTE's. + // We also need to repeat this process if main select expression contains union with + // recursive CTE + void addCTEAlias(const dsql_str* alias) + { + cteAliases.add(alias); + } + + const dsql_str* getNextCTEAlias() + { + return *(--currCteAlias); + } + + void resetCTEAlias(const dsql_str* alias) + { + const dsql_str* const* begin = cteAliases.begin(); + + currCteAlias = cteAliases.end() - 1; + fb_assert(currCteAlias >= begin); + + const dsql_str* curr = *(currCteAlias); + while (strcmp(curr->str_data, alias->str_data)) + { + currCteAlias--; + fb_assert(currCteAlias >= begin); + + curr = *(currCteAlias); + } + } + + bool isPsql() const { return psql; } + void setPsql(bool value) { psql = value; } + +private: + dsql_nod* pass1RecursiveCte(dsql_nod* input); + dsql_nod* pass1RseIsRecursive(dsql_nod* input); + bool pass1RelProcIsRecursive(dsql_nod* input); + dsql_nod* pass1JoinIsRecursive(dsql_nod*& input); + +private: + dsql_dbb* dbb; // DSQL attachment + jrd_tra* transaction; // Transaction + DsqlCompiledStatement* statement; // Compiled statement + +public: + unsigned flags; // flags + Firebird::Array ports; // Port messages + dsql_rel* relation; // relation created by this request (for DDL) + dsql_prc* procedure; // procedure created by this request (for DDL) + DsqlContextStack mainContext; + DsqlContextStack* context; + DsqlContextStack unionContext; // Save contexts for views of unions + DsqlContextStack derivedContext; // Save contexts for views of derived tables + dsql_ctx* outerAggContext; // agg context for outer ref + USHORT contextNumber; // Next available context number + USHORT derivedContextNumber; // Next available context number for derived tables + USHORT scopeLevel; // Scope level for parsing aliases in subqueries + USHORT loopLevel; // Loop level + DsqlStrStack labels; // Loop labels + USHORT cursorNumber; // Cursor number + DsqlNodStack cursors; // Cursors + USHORT inSelectList; // now processing "select list" + USHORT inWhereClause; // processing "where clause" + USHORT inGroupByClause; // processing "group by clause" + USHORT inHavingClause; // processing "having clause" + USHORT inOrderByClause; // processing "order by clause" + USHORT errorHandlers; // count of active error handlers + USHORT clientDialect; // dialect passed into the API call + USHORT inOuterJoin; // processing inside outer-join part + dsql_str* aliasRelationPrefix; // prefix for every relation-alias. + DsqlNodStack hiddenVars; // hidden variables + USHORT hiddenVarsNumber; // next hidden variable number + Firebird::MetaName package; // package being defined + DsqlNodStack currCtes; // current processing CTE's + class dsql_ctx* recursiveCtx; // context of recursive CTE + USHORT recursiveCtxId; // id of recursive union stream context + bool processingWindow; // processing window functions + bool checkConstraintTrigger; // compiling a check constraint trigger + dsc domainValue; // VALUE in the context of domain's check constraint + Firebird::Array variables; + Firebird::Array outputVariables; + +private: + Firebird::HalfStaticArray ctes; // common table expressions + Firebird::HalfStaticArray cteAliases; // CTE aliases in recursive members + const dsql_str* const* currCteAlias; + bool psql; +}; + +class PsqlChanger +{ +public: + PsqlChanger(DsqlCompilerScratch* aDsqlScratch, bool value) + : dsqlScratch(aDsqlScratch), + oldValue(dsqlScratch->isPsql()) + { + dsqlScratch->setPsql(value); + } + + ~PsqlChanger() + { + dsqlScratch->setPsql(oldValue); + } + +private: + // copying is prohibited + PsqlChanger(const PsqlChanger&); + PsqlChanger& operator =(const PsqlChanger&); + + DsqlCompilerScratch* dsqlScratch; + const bool oldValue; +}; + +} // namespace Jrd + +#endif // DSQL_COMPILER_SCRATCH_H diff --git a/src/dsql/Nodes.h b/src/dsql/Nodes.h index 59230740eb..8cebfaea68 100644 --- a/src/dsql/Nodes.h +++ b/src/dsql/Nodes.h @@ -24,7 +24,7 @@ #define DSQL_NODES_H #include "../jrd/common.h" -#include "../dsql/dsql.h" +#include "../dsql/DsqlCompilerScratch.h" #include "../dsql/node.h" #include "../dsql/Visitors.h" #include "../common/classes/array.h" @@ -611,39 +611,6 @@ private: }; -// Common node for all "code blocks" (i.e.: procedures, triggers and execute block) -class BlockNode -{ -public: - explicit BlockNode(MemoryPool& pool, bool aHasEos) - : hasEos(aHasEos), - variables(pool), - outputVariables(pool) - { - } - - virtual ~BlockNode() - { - } - - static void putDtype(DsqlCompilerScratch* dsqlScratch, const dsql_fld* field, bool useSubType); - - void putLocalVariables(DsqlCompilerScratch* dsqlScratch, const dsql_nod* parameters, - SSHORT locals); - void putLocalVariable(DsqlCompilerScratch* dsqlScratch, dsql_var* variable, - dsql_nod* hostParam, const dsql_str* collationName); - dsql_nod* resolveVariable(const dsql_str* varName); - void genReturn(DsqlCompilerScratch* dsqlScratch, bool eosFlag = false); - -private: - bool hasEos; - -protected: - Firebird::Array variables; - Firebird::Array outputVariables; -}; - - } // namespace #endif // DSQL_NODES_H diff --git a/src/dsql/StmtNodes.cpp b/src/dsql/StmtNodes.cpp index f420f96ab0..5f8e00a889 100644 --- a/src/dsql/StmtNodes.cpp +++ b/src/dsql/StmtNodes.cpp @@ -52,278 +52,6 @@ using namespace Jrd; namespace Jrd { -// Write out field data type. -// Taking special care to declare international text. -void BlockNode::putDtype(DsqlCompilerScratch* dsqlScratch, const dsql_fld* field, bool useSubType) -{ -#ifdef DEV_BUILD - // Check if the field describes a known datatype - - if (field->fld_dtype > FB_NELEM(blr_dtypes) || !blr_dtypes[field->fld_dtype]) - { - SCHAR buffer[100]; - - sprintf(buffer, "Invalid dtype %d in BlockNode::putDtype", field->fld_dtype); - ERRD_bugcheck(buffer); - } -#endif - - if (field->fld_not_nullable) - dsqlScratch->appendUChar(blr_not_nullable); - - if (field->fld_type_of_name.hasData()) - { - if (field->fld_type_of_table) - { - if (field->fld_explicit_collation) - { - dsqlScratch->appendUChar(blr_column_name2); - dsqlScratch->appendUChar(field->fld_full_domain ? blr_domain_full : blr_domain_type_of); - dsqlScratch->appendMetaString(field->fld_type_of_table->str_data); - dsqlScratch->appendMetaString(field->fld_type_of_name.c_str()); - dsqlScratch->appendUShort(field->fld_ttype); - } - else - { - dsqlScratch->appendUChar(blr_column_name); - dsqlScratch->appendUChar(field->fld_full_domain ? blr_domain_full : blr_domain_type_of); - dsqlScratch->appendMetaString(field->fld_type_of_table->str_data); - dsqlScratch->appendMetaString(field->fld_type_of_name.c_str()); - } - } - else - { - if (field->fld_explicit_collation) - { - dsqlScratch->appendUChar(blr_domain_name2); - dsqlScratch->appendUChar(field->fld_full_domain ? blr_domain_full : blr_domain_type_of); - dsqlScratch->appendMetaString(field->fld_type_of_name.c_str()); - dsqlScratch->appendUShort(field->fld_ttype); - } - else - { - dsqlScratch->appendUChar(blr_domain_name); - dsqlScratch->appendUChar(field->fld_full_domain ? blr_domain_full : blr_domain_type_of); - dsqlScratch->appendMetaString(field->fld_type_of_name.c_str()); - } - } - - return; - } - - switch (field->fld_dtype) - { - case dtype_cstring: - case dtype_text: - case dtype_varying: - case dtype_blob: - if (!useSubType) - dsqlScratch->appendUChar(blr_dtypes[field->fld_dtype]); - else if (field->fld_dtype == dtype_varying) - { - dsqlScratch->appendUChar(blr_varying2); - dsqlScratch->appendUShort(field->fld_ttype); - } - else if (field->fld_dtype == dtype_cstring) - { - dsqlScratch->appendUChar(blr_cstring2); - dsqlScratch->appendUShort(field->fld_ttype); - } - else if (field->fld_dtype == dtype_blob) - { - dsqlScratch->appendUChar(blr_blob2); - dsqlScratch->appendUShort(field->fld_sub_type); - dsqlScratch->appendUShort(field->fld_ttype); - } - else - { - dsqlScratch->appendUChar(blr_text2); - dsqlScratch->appendUShort(field->fld_ttype); - } - - if (field->fld_dtype == dtype_varying) - dsqlScratch->appendUShort(field->fld_length - sizeof(USHORT)); - else if (field->fld_dtype != dtype_blob) - dsqlScratch->appendUShort(field->fld_length); - break; - - default: - dsqlScratch->appendUChar(blr_dtypes[field->fld_dtype]); - if (DTYPE_IS_EXACT(field->fld_dtype) || (dtype_quad == field->fld_dtype)) - dsqlScratch->appendUChar(field->fld_scale); - break; - } -} - -// Emit dyn for the local variables declared in a procedure or trigger. -void BlockNode::putLocalVariables(DsqlCompilerScratch* dsqlScratch, const dsql_nod* parameters, - SSHORT locals) -{ - if (!parameters) - return; - - dsql_nod* const* ptr = parameters->nod_arg; - for (const dsql_nod* const* const end = ptr + parameters->nod_count; ptr < end; ptr++) - { - dsql_nod* parameter = *ptr; - - dsqlScratch->putDebugSrcInfo(parameter->nod_line, parameter->nod_column); - - if (parameter->nod_type == Dsql::nod_def_field) - { - dsql_fld* field = (dsql_fld*) parameter->nod_arg[Dsql::e_dfl_field]; - const dsql_nod* const* rest = ptr; - while (++rest != end) - { - if ((*rest)->nod_type == Dsql::nod_def_field) - { - const dsql_fld* rest_field = (dsql_fld*) (*rest)->nod_arg[Dsql::e_dfl_field]; - if (field->fld_name == rest_field->fld_name) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-637) << - Arg::Gds(isc_dsql_duplicate_spec) << Arg::Str(field->fld_name)); - } - } - } - - dsql_nod* var_node = MAKE_variable(field, field->fld_name.c_str(), VAR_local, 0, 0, locals); - variables.add(var_node); - - dsql_var* variable = (dsql_var*) var_node->nod_arg[Dsql::e_var_variable]; - putLocalVariable(dsqlScratch, variable, parameter, - reinterpret_cast(parameter->nod_arg[Dsql::e_dfl_collate])); - - // Some field attributes are calculated inside - // putLocalVariable(), so we reinitialize the - // descriptor - MAKE_desc_from_field(&var_node->nod_desc, field); - - locals++; - } - else if (parameter->nod_type == Dsql::nod_cursor) - { - PASS1_statement(dsqlScratch, parameter); - GEN_statement(dsqlScratch, parameter); - } - } -} - -// Write out local variable field data type. -void BlockNode::putLocalVariable(DsqlCompilerScratch* dsqlScratch, dsql_var* variable, - dsql_nod* hostParam, const dsql_str* collationName) -{ - dsql_fld* field = variable->var_field; - - dsqlScratch->appendUChar(blr_dcl_variable); - dsqlScratch->appendUShort(variable->var_variable_number); - DDL_resolve_intl_type(dsqlScratch, field, collationName); - - //const USHORT dtype = field->fld_dtype; - - putDtype(dsqlScratch, field, true); - //field->fld_dtype = dtype; - - // Check for a default value, borrowed from define_domain - dsql_nod* node = hostParam ? hostParam->nod_arg[Dsql::e_dfl_default] : NULL; - - if (node || (!field->fld_full_domain && !field->fld_not_nullable)) - { - dsqlScratch->appendUChar(blr_assignment); - - if (node) - { - fb_assert(node->nod_type == Dsql::nod_def_default); - PsqlChanger psqlChanger(dsqlScratch, false); - node = PASS1_node(dsqlScratch, node->nod_arg[Dsql::e_dft_default]); - GEN_expr(dsqlScratch, node); - } - else - dsqlScratch->appendUChar(blr_null); // Initialize variable to NULL - - dsqlScratch->appendUChar(blr_variable); - dsqlScratch->appendUShort(variable->var_variable_number); - } - else - { - dsqlScratch->appendUChar(blr_init_variable); - dsqlScratch->appendUShort(variable->var_variable_number); - } - - if (variable->var_name[0]) // Not a function return value - dsqlScratch->putDebugVariable(variable->var_variable_number, variable->var_name); - - ++dsqlScratch->hiddenVarsNumber; -} - -// Try to resolve variable name against parameters and local variables. -dsql_nod* BlockNode::resolveVariable(const dsql_str* varName) -{ - for (dsql_nod* const* i = variables.begin(); i != variables.end(); ++i) - { - dsql_nod* var_node = *i; - fb_assert(var_node->nod_type == Dsql::nod_variable); - - if (var_node->nod_type == Dsql::nod_variable) - { - const dsql_var* variable = (dsql_var*) var_node->nod_arg[Dsql::e_var_variable]; - DEV_BLKCHK(variable, dsql_type_var); - - if (!strcmp(varName->str_data, variable->var_name)) - return var_node; - } - } - - return NULL; -} - -// Generate BLR for a return. -void BlockNode::genReturn(DsqlCompilerScratch* dsqlScratch, bool eosFlag) -{ - if (hasEos && !eosFlag) - dsqlScratch->appendUChar(blr_begin); - - dsqlScratch->appendUChar(blr_send); - dsqlScratch->appendUChar(1); - dsqlScratch->appendUChar(blr_begin); - - for (Array::const_iterator i = outputVariables.begin(); i != outputVariables.end(); ++i) - { - const dsql_nod* parameter = *i; - const dsql_var* variable = (dsql_var*) parameter->nod_arg[Dsql::e_var_variable]; - dsqlScratch->appendUChar(blr_assignment); - dsqlScratch->appendUChar(blr_variable); - dsqlScratch->appendUShort(variable->var_variable_number); - dsqlScratch->appendUChar(blr_parameter2); - dsqlScratch->appendUChar(variable->var_msg_number); - dsqlScratch->appendUShort(variable->var_msg_item); - dsqlScratch->appendUShort(variable->var_msg_item + 1); - } - - if (hasEos) - { - dsqlScratch->appendUChar(blr_assignment); - dsqlScratch->appendUChar(blr_literal); - dsqlScratch->appendUChar(blr_short); - dsqlScratch->appendUChar(0); - dsqlScratch->appendUShort((eosFlag ? 0 : 1)); - dsqlScratch->appendUChar(blr_parameter); - dsqlScratch->appendUChar(1); - dsqlScratch->appendUShort(USHORT(2 * outputVariables.getCount())); - } - - dsqlScratch->appendUChar(blr_end); - - if (hasEos && !eosFlag) - { - dsqlScratch->appendUChar(blr_stall); - dsqlScratch->appendUChar(blr_end); - } -} - - -//-------------------- - - DmlNode* DmlNode::pass1(thread_db* tdbb, CompilerScratch* csb, jrd_nod* aNode) { node = aNode; @@ -678,8 +406,6 @@ ExecBlockNode* ExecBlockNode::internalDsqlPass() { DsqlCompiledStatement* statement = dsqlScratch->getStatement(); - statement->setBlockNode(this); - if (returns.hasData()) statement->setType(DsqlCompiledStatement::TYPE_SELECT_BLOCK); else @@ -793,11 +519,6 @@ void ExecBlockNode::print(string& text, Array& nodes) const void ExecBlockNode::genBlr() { - DsqlCompiledStatement* statement = dsqlScratch->getStatement(); - - // Update blockNode, because we have a reference to the original unprocessed node. - statement->setBlockNode(this); - dsqlScratch->beginDebug(); // now do the input parameters @@ -808,10 +529,10 @@ void ExecBlockNode::genBlr() dsql_nod* var = MAKE_variable(parameter.legacyField, parameter.name.c_str(), VAR_input, 0, (USHORT) (2 * i), 0); - variables.add(var); + dsqlScratch->variables.add(var); } - const unsigned returnsPos = variables.getCount(); + const unsigned returnsPos = dsqlScratch->variables.getCount(); // now do the output parameters for (size_t i = 0; i < returns.getCount(); ++i) @@ -821,10 +542,12 @@ void ExecBlockNode::genBlr() dsql_nod* var = MAKE_variable(parameter.legacyField, parameter.name.c_str(), VAR_output, 1, (USHORT) (2 * i), i); - variables.add(var); - outputVariables.add(var); + dsqlScratch->variables.add(var); + dsqlScratch->outputVariables.add(var); } + DsqlCompiledStatement* statement = dsqlScratch->getStatement(); + dsqlScratch->appendUChar(blr_begin); if (parameters.hasData()) @@ -835,10 +558,12 @@ void ExecBlockNode::genBlr() else statement->setSendMsg(NULL); - for (Array::const_iterator i = outputVariables.begin(); i != outputVariables.end(); ++i) + for (Array::const_iterator i = dsqlScratch->outputVariables.begin(); + i != dsqlScratch->outputVariables.end(); + ++i) { dsql_par* param = MAKE_parameter(statement->getReceiveMsg(), true, true, - (i - outputVariables.begin()) + 1, *i); + (i - dsqlScratch->outputVariables.begin()) + 1, *i); param->par_node = *i; MAKE_desc(dsqlScratch, ¶m->par_desc, *i, NULL); param->par_desc.dsc_flags |= DSC_nullable; @@ -864,7 +589,7 @@ void ExecBlockNode::genBlr() for (unsigned i = 0; i < returnsPos; ++i) { - const dsql_nod* parameter = variables[i]; + const dsql_nod* parameter = dsqlScratch->variables[i]; const dsql_var* variable = (dsql_var*) parameter->nod_arg[Dsql::e_var_variable]; const dsql_fld* field = variable->var_field; @@ -875,7 +600,7 @@ void ExecBlockNode::genBlr() // connection charset influence. So to validate, we cast them and assign to null. dsqlScratch->appendUChar(blr_assignment); dsqlScratch->appendUChar(blr_cast); - BlockNode::putDtype(dsqlScratch, field, true); + dsqlScratch->putDtype(field, true); dsqlScratch->appendUChar(blr_parameter2); dsqlScratch->appendUChar(0); dsqlScratch->appendUShort(variable->var_msg_item); @@ -884,16 +609,18 @@ void ExecBlockNode::genBlr() } } - for (Array::const_iterator i = outputVariables.begin(); i != outputVariables.end(); ++i) + for (Array::const_iterator i = dsqlScratch->outputVariables.begin(); + i != dsqlScratch->outputVariables.end(); + ++i) { dsql_nod* parameter = *i; dsql_var* variable = (dsql_var*) parameter->nod_arg[Dsql::e_var_variable]; - putLocalVariable(dsqlScratch, variable, 0, NULL); + dsqlScratch->putLocalVariable(variable, 0, NULL); } dsqlScratch->setPsql(true); - putLocalVariables(dsqlScratch, localDeclList, USHORT(returns.getCount())); + dsqlScratch->putLocalVariables(localDeclList, USHORT(returns.getCount())); dsqlScratch->loopLevel = 0; @@ -912,7 +639,7 @@ void ExecBlockNode::genBlr() statement->setType(DsqlCompiledStatement::TYPE_EXEC_BLOCK); dsqlScratch->appendUChar(blr_end); - genReturn(dsqlScratch, true); + dsqlScratch->genReturn(true); dsqlScratch->appendUChar(blr_end); dsqlScratch->endDebug(); @@ -1799,8 +1526,6 @@ SuspendNode* SuspendNode::internalDsqlPass() statement->addFlags(DsqlCompiledStatement::FLAG_SELECTABLE); - blockNode = statement->getBlockNode(); - return this; } @@ -1813,8 +1538,7 @@ void SuspendNode::print(string& text, Array& /*nodes*/) const void SuspendNode::genBlr() { - if (blockNode) - blockNode->genReturn(dsqlScratch); + dsqlScratch->genReturn(); } @@ -1879,7 +1603,6 @@ ReturnNode* ReturnNode::internalDsqlPass() ReturnNode* node = FB_NEW(getPool()) ReturnNode(getPool()); node->dsqlScratch = dsqlScratch; - node->blockNode = statement->getBlockNode(); node->value = PASS1_node(dsqlScratch, value); return node; @@ -1899,9 +1622,7 @@ void ReturnNode::genBlr() GEN_expr(dsqlScratch, value); dsqlScratch->appendUChar(blr_variable); dsqlScratch->appendUShort(0); - - blockNode->genReturn(dsqlScratch); - + dsqlScratch->genReturn(); dsqlScratch->appendUChar(blr_leave); dsqlScratch->appendUChar(0); } diff --git a/src/dsql/StmtNodes.h b/src/dsql/StmtNodes.h index c1db3ac875..bef79a0b15 100644 --- a/src/dsql/StmtNodes.h +++ b/src/dsql/StmtNodes.h @@ -103,12 +103,11 @@ public: }; -class ExecBlockNode : public DsqlOnlyStmtNode, public BlockNode +class ExecBlockNode : public DsqlOnlyStmtNode { public: explicit ExecBlockNode(MemoryPool& pool) : DsqlOnlyStmtNode(pool), - BlockNode(pool, true), parameters(pool), returns(pool), localDeclList(NULL), @@ -319,7 +318,6 @@ class SuspendNode : public StmtNode public: explicit SuspendNode(MemoryPool& pool) : StmtNode(pool), - blockNode(NULL), message(NULL), statement(NULL) { @@ -339,7 +337,6 @@ public: virtual const jrd_nod* execute(thread_db* tdbb, jrd_req* request) const; public: - BlockNode* blockNode; NestConst message; NestConst statement; }; @@ -361,7 +358,6 @@ public: virtual void genBlr(); public: - BlockNode* blockNode; dsql_nod* value; }; diff --git a/src/dsql/dsql.h b/src/dsql/dsql.h index a96b4d3f90..326128eb19 100644 --- a/src/dsql/dsql.h +++ b/src/dsql/dsql.h @@ -66,14 +66,14 @@ const char* const TEMP_CONTEXT = "TEMP"; namespace Jrd { - class Database; class Attachment; + class Database; + class DsqlCompilerScratch; class jrd_tra; class jrd_req; class blb; struct bid; - class BlockNode; class dsql_blb; class dsql_ctx; class dsql_msg; @@ -103,8 +103,6 @@ namespace Firebird namespace Jrd { -class DsqlCompilerScratch; - //! generic data type used to store strings class dsql_str : public pool_alloc_rpt { @@ -416,7 +414,6 @@ public: type(TYPE_SELECT), flags(0), ddlData(p), - blockNode(NULL), ddlNode(NULL), blob(NULL), sendMsg(NULL), @@ -448,10 +445,6 @@ public: const Firebird::HalfStaticArray& getDdlData() const { return ddlData; } void setDdlData(Firebird::HalfStaticArray& value) { ddlData = value; } - BlockNode* getBlockNode() { return blockNode; } - const BlockNode* getBlockNode() const { return blockNode; } - void setBlockNode(BlockNode* value) { blockNode = value; } - dsql_nod* getDdlNode() { return ddlNode; } const dsql_nod* getDdlNode() const { return ddlNode; } void setDdlNode(dsql_nod* value) { ddlNode = value; } @@ -496,7 +489,6 @@ private: ULONG flags; // generic flag Firebird::RefStrPtr sqlText; Firebird::HalfStaticArray ddlData; - BlockNode* blockNode; // Defining block dsql_nod* ddlNode; // Store metadata statement dsql_blb* blob; // Blob info for blob statements dsql_msg* sendMsg; // Message to be sent to start request @@ -580,236 +572,6 @@ protected: friend class Firebird::MemoryPool; }; - -// DSQL Compiler scratch block - may be discarded after compilation in the future. -class DsqlCompilerScratch : public BlrWriter -{ -public: - static const unsigned FLAG_IN_AUTO_TRANS_BLOCK = 0x001; - static const unsigned FLAG_RETURNING_INTO = 0x002; - static const unsigned FLAG_METADATA_SAVED = 0x004; - static const unsigned FLAG_PROCEDURE = 0x008; - static const unsigned FLAG_TRIGGER = 0x010; - static const unsigned FLAG_BLOCK = 0x020; - static const unsigned FLAG_RECURSIVE_CTE = 0x040; - static const unsigned FLAG_UPDATE_OR_INSERT = 0x080; - static const unsigned FLAG_MERGE = 0x100; - static const unsigned FLAG_FUNCTION = 0x200; - -public: - explicit DsqlCompilerScratch(MemoryPool& p, dsql_dbb* aDbb, jrd_tra* aTransaction, - DsqlCompiledStatement* aStatement) - : BlrWriter(p), - dbb(aDbb), - transaction(aTransaction), - statement(aStatement), - flags(0), - ports(p), - relation(NULL), - procedure(NULL), - mainContext(p), - context(&mainContext), - unionContext(p), - derivedContext(p), - outerAggContext(NULL), - contextNumber(0), - derivedContextNumber(0), - scopeLevel(0), - loopLevel(0), - labels(p), - cursorNumber(0), - cursors(p), - inSelectList(0), - inWhereClause(0), - inGroupByClause(0), - inHavingClause(0), - inOrderByClause(0), - errorHandlers(0), - clientDialect(0), - inOuterJoin(0), - aliasRelationPrefix(NULL), - hiddenVars(p), - hiddenVarsNumber(0), - package(p), - currCtes(p), - recursiveCtx(0), - recursiveCtxId(0), - processingWindow(false), - checkConstraintTrigger(false), - ctes(p), - cteAliases(p), - currCteAlias(NULL), - psql(false) - { - domainValue.clear(); - } - -protected: - // DsqlCompilerScratch should never be destroyed using delete. - // It dies together with it's pool in release_request(). - ~DsqlCompilerScratch() - { - } - - virtual bool isDdlDyn() - { - return (statement->getType() == DsqlCompiledStatement::TYPE_DDL || statement->getDdlNode()) && - !statement->getBlockNode(); - } - -public: - virtual bool isVersion4() - { - return statement->getFlags() & DsqlCompiledStatement::FLAG_BLR_VERSION4; - } - - MemoryPool& getPool() - { - return PermanentStorage::getPool(); - } - - dsql_dbb* getAttachment() - { - return dbb; - } - - jrd_tra* getTransaction() - { - return transaction; - } - - void setTransaction(jrd_tra* value) - { - transaction = value; - } - - DsqlCompiledStatement* getStatement() - { - return statement; - } - - DsqlCompiledStatement* getStatement() const - { - return statement; - } - - void addCTEs(dsql_nod* list); - dsql_nod* findCTE(const dsql_str* name); - void clearCTEs(); - void checkUnusedCTEs() const; - - // hvlad: each member of recursive CTE can refer to CTE itself (only once) via - // CTE name or via alias. We need to substitute this aliases when processing CTE - // member to resolve field names. Therefore we store all aliases in order of - // occurrence and later use it in backward order (since our parser is right-to-left). - // Also we put CTE name after all such aliases to distinguish aliases for - // different CTE's. - // We also need to repeat this process if main select expression contains union with - // recursive CTE - void addCTEAlias(const dsql_str* alias) - { - cteAliases.add(alias); - } - - const dsql_str* getNextCTEAlias() - { - return *(--currCteAlias); - } - - void resetCTEAlias(const dsql_str* alias) - { - const dsql_str* const* begin = cteAliases.begin(); - - currCteAlias = cteAliases.end() - 1; - fb_assert(currCteAlias >= begin); - - const dsql_str* curr = *(currCteAlias); - while (strcmp(curr->str_data, alias->str_data)) - { - currCteAlias--; - fb_assert(currCteAlias >= begin); - - curr = *(currCteAlias); - } - } - - bool isPsql() const { return psql; } - void setPsql(bool value) { psql = value; } - -private: - dsql_dbb* dbb; // DSQL attachment - jrd_tra* transaction; // Transaction - DsqlCompiledStatement* statement; // Compiled statement - -public: - unsigned flags; // flags - Firebird::Array ports; // Port messages - dsql_rel* relation; // relation created by this request (for DDL) - dsql_prc* procedure; // procedure created by this request (for DDL) - DsqlContextStack mainContext; - DsqlContextStack* context; - DsqlContextStack unionContext; // Save contexts for views of unions - DsqlContextStack derivedContext; // Save contexts for views of derived tables - dsql_ctx* outerAggContext; // agg context for outer ref - USHORT contextNumber; // Next available context number - USHORT derivedContextNumber; // Next available context number for derived tables - USHORT scopeLevel; // Scope level for parsing aliases in subqueries - USHORT loopLevel; // Loop level - DsqlStrStack labels; // Loop labels - USHORT cursorNumber; // Cursor number - DsqlNodStack cursors; // Cursors - USHORT inSelectList; // now processing "select list" - USHORT inWhereClause; // processing "where clause" - USHORT inGroupByClause; // processing "group by clause" - USHORT inHavingClause; // processing "having clause" - USHORT inOrderByClause; // processing "order by clause" - USHORT errorHandlers; // count of active error handlers - USHORT clientDialect; // dialect passed into the API call - USHORT inOuterJoin; // processing inside outer-join part - dsql_str* aliasRelationPrefix; // prefix for every relation-alias. - DsqlNodStack hiddenVars; // hidden variables - USHORT hiddenVarsNumber; // next hidden variable number - Firebird::MetaName package; // package being defined - DsqlNodStack currCtes; // current processing CTE's - class dsql_ctx* recursiveCtx; // context of recursive CTE - USHORT recursiveCtxId; // id of recursive union stream context - bool processingWindow; // processing window functions - bool checkConstraintTrigger; // compiling a check constraint trigger - dsc domainValue; // VALUE in the context of domain's check constraint - -private: - Firebird::HalfStaticArray ctes; // common table expressions - Firebird::HalfStaticArray cteAliases; // CTE aliases in recursive members - const dsql_str* const* currCteAlias; - bool psql; -}; - - -class PsqlChanger -{ -public: - PsqlChanger(DsqlCompilerScratch* aStatement, bool value) - : statement(aStatement), - oldValue(statement->isPsql()) - { - statement->setPsql(value); - } - - ~PsqlChanger() - { - statement->setPsql(oldValue); - } - -private: - // copying is prohibited - PsqlChanger(const PsqlChanger&); - PsqlChanger& operator =(const PsqlChanger&); - - DsqlCompilerScratch* statement; - const bool oldValue; -}; - - // Blob class dsql_blb : public pool_alloc { @@ -904,7 +666,7 @@ public: bool getImplicitJoinField(const Firebird::MetaName& name, dsql_nod*& node); PartitionMap* getPartitionMap(DsqlCompilerScratch* dsqlScratch, dsql_nod* partitionNode, - dsql_nod* orderNode); + dsql_nod* orderNode); }; // Flag values for ctx_flags diff --git a/src/dsql/gen.cpp b/src/dsql/gen.cpp index 96c2e9244d..f22abfb419 100644 --- a/src/dsql/gen.cpp +++ b/src/dsql/gen.cpp @@ -1316,7 +1316,7 @@ static void gen_cast( DsqlCompilerScratch* dsqlScratch, const dsql_nod* node) { dsqlScratch->appendUChar(blr_cast); const dsql_fld* field = (dsql_fld*) node->nod_arg[e_cast_target]; - BlockNode::putDtype(dsqlScratch, field, true); + dsqlScratch->putDtype(field, true); GEN_expr(dsqlScratch, node->nod_arg[e_cast_source]); } diff --git a/src/dsql/misc_func.cpp b/src/dsql/misc_func.cpp index cbf56e4449..ca03d61768 100644 --- a/src/dsql/misc_func.cpp +++ b/src/dsql/misc_func.cpp @@ -21,7 +21,7 @@ */ #include "firebird.h" -#include "../dsql/dsql.h" +#include "../dsql/DsqlCompilerScratch.h" #include "../dsql/misc_func.h" using namespace Jrd; diff --git a/src/dsql/pass1.cpp b/src/dsql/pass1.cpp index 72e3971892..4f88c59d66 100644 --- a/src/dsql/pass1.cpp +++ b/src/dsql/pass1.cpp @@ -186,7 +186,6 @@ static void DSQL_pretty(const dsql_nod*, int); static dsql_nod* ambiguity_check(DsqlCompilerScratch*, dsql_nod*, const dsql_str*, const DsqlContextStack&); static void assign_fld_dtype_from_dsc(dsql_fld*, const dsc*); -static dsql_nod* compose(dsql_nod*, dsql_nod*, NOD_TYPE); static dsql_nod* explode_fields(dsql_rel*); static dsql_nod* explode_outputs(DsqlCompilerScratch*, const dsql_prc*); static void field_appears_once(const dsql_nod*, const dsql_nod*, const bool, const char*); @@ -243,11 +242,6 @@ static dsql_nod* resolve_using_field(DsqlCompilerScratch* dsqlScratch, dsql_str* static void set_parameters_name(dsql_nod*, const dsql_nod*); static void set_parameter_name(dsql_nod*, const dsql_nod*, const dsql_rel*); static dsql_nod* pass1_savepoint(const DsqlCompilerScratch*, dsql_nod*); - -static bool pass1_relproc_is_recursive(DsqlCompilerScratch*, dsql_nod*); -static dsql_nod* pass1_join_is_recursive(DsqlCompilerScratch*, dsql_nod*&); -static dsql_nod* pass1_rse_is_recursive(DsqlCompilerScratch*, dsql_nod*); -static dsql_nod* pass1_recursive_cte(DsqlCompilerScratch*, dsql_nod*); static dsql_nod* process_returning(DsqlCompilerScratch*, dsql_nod*); const char* const DB_KEY_STRING = "DB_KEY"; // NTX: pseudo field name @@ -1607,7 +1601,7 @@ dsql_nod* PASS1_node(DsqlCompilerScratch* dsqlScratch, dsql_nod* input) dsql_nod* temp = MAKE_node(input->nod_type, 2); temp->nod_arg[0] = input->nod_arg[0]; temp->nod_arg[1] = *ptr; - node = compose(node, PASS1_node(dsqlScratch, temp), nod_or); + node = PASS1_compose(node, PASS1_node(dsqlScratch, temp), nod_or); } return node; @@ -2623,19 +2617,8 @@ void PASS1_check_unique_fields_names(StrArray& names, const dsql_nod* fields) } -/** - - compose - - @brief Compose two booleans. - - - @param expr1 - @param expr2 - @param dsql_operator - - **/ -static dsql_nod* compose( dsql_nod* expr1, dsql_nod* expr2, NOD_TYPE dsql_operator) +// Compose two booleans. +dsql_nod* PASS1_compose( dsql_nod* expr1, dsql_nod* expr2, NOD_TYPE dsql_operator) { DEV_BLKCHK(expr1, dsql_type_nod); DEV_BLKCHK(expr2, dsql_type_nod); @@ -3897,7 +3880,7 @@ static dsql_nod* pass1_cursor_reference( DsqlCompilerScratch* dsqlScratch, const temp->nod_arg[e_par_parameter] = (dsql_nod*) parameter; parameter->par_desc = rv_source->par_desc; - rse->nod_arg[e_rse_boolean] = compose(rse->nod_arg[e_rse_boolean], node, nod_and); + rse->nod_arg[e_rse_boolean] = PASS1_compose(rse->nod_arg[e_rse_boolean], node, nod_and); } return rse; @@ -4084,415 +4067,6 @@ static dsql_nod* pass1_delete( DsqlCompilerScratch* dsqlScratch, dsql_nod* input } -/** - - pass1_relproc_is_recursive - - @brief check if table reference is recursive i.e. its name is equal - to the name of current processing CTE - - @param dsqlScratch - @param input - - **/ -static bool pass1_relproc_is_recursive(DsqlCompilerScratch* dsqlScratch, dsql_nod* input) -{ - const dsql_str* rel_name = NULL; - const dsql_str* rel_alias = NULL; - - switch (input->nod_type) - { - case nod_rel_proc_name: - rel_name = (dsql_str*) input->nod_arg[e_rpn_name]; - rel_alias = (dsql_str*) input->nod_arg[e_rpn_alias]; - break; - - case nod_relation_name: - rel_name = (dsql_str*) input->nod_arg[e_rln_name]; - rel_alias = (dsql_str*) input->nod_arg[e_rln_alias]; - break; - - default: - return false; - } - - fb_assert(dsqlScratch->currCtes.hasData()); - const dsql_nod* curr_cte = dsqlScratch->currCtes.object(); - const dsql_str* cte_name = (dsql_str*) curr_cte->nod_arg[e_derived_table_alias]; - - const bool recursive = (cte_name->str_length == rel_name->str_length) && - (strncmp(rel_name->str_data, cte_name->str_data, cte_name->str_length) == 0); - - if (recursive) { - dsqlScratch->addCTEAlias(rel_alias ? rel_alias : rel_name); - } - - return recursive; -} - - -/** - - pass1_join_is_recursive - - @brief check if join have recursive members. If found remove this member - from join and return its boolean (to be added into WHERE clause). - We must remove member only if it is a table reference. - Punt if recursive reference is found in outer join or more than one - recursive reference is found - - @param dsqlScratch - @param input - - **/ -static dsql_nod* pass1_join_is_recursive(DsqlCompilerScratch* dsqlScratch, dsql_nod*& input) -{ - const NOD_TYPE join_type = input->nod_arg[e_join_type]->nod_type; - bool remove = false; - - bool leftRecursive = false; - dsql_nod* leftBool = NULL; - dsql_nod** join_table = &input->nod_arg[e_join_left_rel]; - if ((*join_table)->nod_type == nod_join) - { - leftBool = pass1_join_is_recursive(dsqlScratch, *join_table); - leftRecursive = (leftBool != NULL); - } - else - { - leftBool = input->nod_arg[e_join_boolean]; - leftRecursive = pass1_relproc_is_recursive(dsqlScratch, *join_table); - if (leftRecursive) - remove = true; - } - - if (leftRecursive && join_type != nod_join_inner) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // Recursive member of CTE can't be member of an outer join - Arg::Gds(isc_dsql_cte_outer_join)); - } - - bool rightRecursive = false; - dsql_nod* rightBool = NULL; - join_table = &input->nod_arg[e_join_rght_rel]; - if ((*join_table)->nod_type == nod_join) - { - rightBool = pass1_join_is_recursive(dsqlScratch, *join_table); - rightRecursive = (rightBool != NULL); - } - else - { - rightBool = input->nod_arg[e_join_boolean]; - rightRecursive = pass1_relproc_is_recursive(dsqlScratch, *join_table); - if (rightRecursive) - remove = true; - } - - if (rightRecursive && join_type != nod_join_inner) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // Recursive member of CTE can't be member of an outer join - Arg::Gds(isc_dsql_cte_outer_join)); - } - - if (leftRecursive && rightRecursive) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // Recursive member of CTE can't reference itself more than once - Arg::Gds(isc_dsql_cte_mult_references)); - } - - if (leftRecursive) - { - if (remove) - input = input->nod_arg[e_join_rght_rel]; - - return leftBool; - } - - if (rightRecursive) - { - if (remove) - input = input->nod_arg[e_join_left_rel]; - - return rightBool; - } - - return 0; -} - - -/** - - pass1_rse_is_recursive - - @brief check if rse is recursive. If recursive reference is a table - in the FROM list remove it. If recursive reference is a part of - join add join boolean (returned by pass1_join_is_recursive) to the - WHERE clause. Punt if more than one recursive reference is found - - @param dsqlScratch - @param input - - **/ -static dsql_nod* pass1_rse_is_recursive(DsqlCompilerScratch* dsqlScratch, dsql_nod* input) -{ - fb_assert(input->nod_type == nod_query_spec); - - dsql_nod* result = MAKE_node(nod_query_spec, e_qry_count); - memcpy(result->nod_arg, input->nod_arg, e_qry_count * sizeof(dsql_nod*)); - - dsql_nod* src_tables = input->nod_arg[e_qry_from]; - dsql_nod* dst_tables = MAKE_node(nod_list, src_tables->nod_count); - result->nod_arg[e_qry_from] = dst_tables; - - dsql_nod** p_dst_table = dst_tables->nod_arg; - dsql_nod** p_src_table = src_tables->nod_arg; - dsql_nod** end = src_tables->nod_arg + src_tables->nod_count; - - bool found = false; - for (dsql_nod** prev = p_dst_table; p_src_table < end; p_src_table++, p_dst_table++) - { - *prev++ = *p_dst_table = *p_src_table; - - switch ((*p_dst_table)->nod_type) - { - case nod_rel_proc_name: - case nod_relation_name: - if (pass1_relproc_is_recursive(dsqlScratch, *p_dst_table)) - { - if (found) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // Recursive member of CTE can't reference itself more than once - Arg::Gds(isc_dsql_cte_mult_references)); - } - found = true; - - prev--; - dst_tables->nod_count--; - } - break; - - case nod_join: - { - *p_dst_table = MAKE_node(nod_join, e_join_count); - memcpy((*p_dst_table)->nod_arg, (*p_src_table)->nod_arg, - e_join_count * sizeof(dsql_nod*)); - - dsql_nod* joinBool = pass1_join_is_recursive(dsqlScratch, *p_dst_table); - if (joinBool) - { - if (found) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // Recursive member of CTE can't reference itself more than once - Arg::Gds(isc_dsql_cte_mult_references)); - } - found = true; - - result->nod_arg[e_qry_where] = - compose(result->nod_arg[e_qry_where], joinBool, nod_and); - } - } - break; - - case nod_derived_table: - break; - - default: - fb_assert(false); - } - } - - return found ? result : NULL; -} - - -/** - - pass1_recursive_cte - - @brief Process derived table which can be recursive CTE - If it is non-recursive return input node unchanged - If it is recursive return new derived table which is an union of - union of anchor (non-recursive) queries and union of recursive - queries. Check recursive queries to satisfy various criterias. - Note that our parser is right-to-left therefore nested list linked - as first node in parent list and second node is always query spec. - - For example, if we have 4 CTE's where first two is non-recursive - and last two is recursive : - - list union - [0] [1] [0] [1] - list cte3 ===> anchor recursive - [0] [1] [0] [1] [0] [1] - list cte3 cte1 cte2 cte3 cte4 - [0] [1] - cte1 cte2 - - Also, we should not change layout of original parse tree to allow it to - be parsed again if needed. Therefore recursive part is built using newly - allocated list nodes. - - @param dsqlScratch - @param input - - **/ -static dsql_nod* pass1_recursive_cte(DsqlCompilerScratch* dsqlScratch, dsql_nod* input) -{ - dsql_str* const cte_alias = (dsql_str*) input->nod_arg[e_derived_table_alias]; - dsql_nod* const select_expr = input->nod_arg[e_derived_table_rse]; - dsql_nod* query = select_expr->nod_arg[e_sel_query_spec]; - - if (query->nod_type != nod_list && pass1_rse_is_recursive(dsqlScratch, query)) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // Recursive CTE (%s) must be an UNION - Arg::Gds(isc_dsql_cte_not_a_union) << Arg::Str(cte_alias->str_data)); - } - - // split queries list on two parts: anchor and recursive - dsql_nod* anchor_rse = 0, *recursive_rse = 0; - dsql_nod* qry = query; - - dsql_nod* new_qry = MAKE_node(nod_list, 2); - new_qry->nod_flags = query->nod_flags; - while (true) - { - dsql_nod* rse = 0; - if (qry->nod_type == nod_list) - rse = qry->nod_arg[1]; - else - rse = qry; - - dsql_nod* new_rse = pass1_rse_is_recursive(dsqlScratch, rse); - if (new_rse) // rse is recursive - { - if (anchor_rse) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // CTE '%s' defined non-recursive member after recursive - Arg::Gds(isc_dsql_cte_nonrecurs_after_recurs) << Arg::Str(cte_alias->str_data)); - } - if (new_rse->nod_arg[e_qry_distinct]) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // Recursive member of CTE '%s' has %s clause - Arg::Gds(isc_dsql_cte_wrong_clause) << Arg::Str(cte_alias->str_data) << - Arg::Str("DISTINCT")); - } - if (new_rse->nod_arg[e_qry_group]) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // Recursive member of CTE '%s' has %s clause - Arg::Gds(isc_dsql_cte_wrong_clause) << Arg::Str(cte_alias->str_data) << - Arg::Str("GROUP BY")); - } - if (new_rse->nod_arg[e_qry_having]) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // Recursive member of CTE '%s' has %s clause - Arg::Gds(isc_dsql_cte_wrong_clause) << Arg::Str(cte_alias->str_data) << - Arg::Str("HAVING")); - } - // hvlad: we need also forbid any aggregate function here - // but for now i have no idea how to do it simple - - if ((new_qry->nod_type == nod_list) && !(new_qry->nod_flags & NOD_UNION_ALL)) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // Recursive members of CTE (%s) must be linked with another members via UNION ALL - Arg::Gds(isc_dsql_cte_union_all) << Arg::Str(cte_alias->str_data)); - } - if (!recursive_rse) - { - recursive_rse = new_qry; - } - new_rse->nod_flags |= NOD_SELECT_EXPR_RECURSIVE; - - if (qry->nod_type == nod_list) - new_qry->nod_arg[1] = new_rse; - else - new_qry->nod_arg[0] = new_rse; - } - else - { - if (qry->nod_type == nod_list) - new_qry->nod_arg[1] = rse; - else - new_qry->nod_arg[0] = rse; - - if (!anchor_rse) - { - if (qry->nod_type == nod_list) - anchor_rse = new_qry; - else - anchor_rse = rse; - } - } - - if (qry->nod_type != nod_list) - break; - - qry = qry->nod_arg[0]; - - if (qry->nod_type == nod_list) - { - new_qry->nod_arg[0] = MAKE_node(nod_list, 2); - new_qry = new_qry->nod_arg[0]; - new_qry->nod_flags = qry->nod_flags; - } - } - - if (!recursive_rse) { - return input; - } - if (!anchor_rse) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // Non-recursive member is missing in CTE '%s' - Arg::Gds(isc_dsql_cte_miss_nonrecursive) << Arg::Str(cte_alias->str_data)); - } - - qry = recursive_rse; - dsql_nod* list = 0; - while (qry->nod_arg[0] != anchor_rse) - { - list = qry; - qry = qry->nod_arg[0]; - } - qry->nod_arg[0] = 0; - if (list) { - list->nod_arg[0] = qry->nod_arg[1]; - } - else { - recursive_rse = qry->nod_arg[1]; - } - - dsql_nod* union_node = MAKE_node(nod_list, 2); - union_node->nod_flags = NOD_UNION_ALL | NOD_UNION_RECURSIVE; - union_node->nod_arg[0] = anchor_rse; - union_node->nod_arg[1] = recursive_rse; - - dsql_nod* select = MAKE_node(nod_select_expr, e_sel_count); - select->nod_arg[e_sel_query_spec] = union_node; - select->nod_arg[e_sel_order] = select->nod_arg[e_sel_rows] = - select->nod_arg[e_sel_with_list] = NULL; - - dsql_nod* node = MAKE_node(nod_derived_table, e_derived_table_count); - dsql_str* alias = (dsql_str*) input->nod_arg[e_derived_table_alias]; - node->nod_arg[e_derived_table_alias] = (dsql_nod*) alias; - node->nod_arg[e_derived_table_column_alias] = input->nod_arg[e_derived_table_column_alias]; - node->nod_arg[e_derived_table_rse] = select; - node->nod_arg[e_derived_table_context] = input->nod_arg[e_derived_table_context]; - - return node; -} - - /** process_returning @@ -9011,13 +8585,9 @@ static dsql_nod* pass1_variable( DsqlCompilerScratch* dsqlScratch, dsql_nod* inp DEV_BLKCHK(var_name, dsql_type_str); - BlockNode* block = dsqlScratch->getStatement()->getBlockNode(); - if (block) - { - dsql_nod* varNode = block->resolveVariable(var_name); - if (varNode) - return varNode; - } + dsql_nod* varNode = dsqlScratch->resolveVariable(var_name); + if (varNode) + return varNode; // field unresolved // CVC: That's all [the fix], folks! @@ -9676,88 +9246,6 @@ static dsql_nod* pass1_savepoint(const DsqlCompilerScratch* dsqlScratch, dsql_no } -void DsqlCompilerScratch::addCTEs(dsql_nod* with) -{ - DEV_BLKCHK(with, dsql_type_nod); - fb_assert(with->nod_type == nod_with); - - if (ctes.getCount()) - { - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - // WITH clause can't be nested - Arg::Gds(isc_dsql_cte_nested_with)); - } - - if (with->nod_flags & NOD_UNION_RECURSIVE) - flags |= DsqlCompilerScratch::FLAG_RECURSIVE_CTE; - - const dsql_nod* list = with->nod_arg[0]; - const dsql_nod* const* end = list->nod_arg + list->nod_count; - for (dsql_nod* const* cte = list->nod_arg; cte < end; cte++) - { - fb_assert((*cte)->nod_type == nod_derived_table); - - if (with->nod_flags & NOD_UNION_RECURSIVE) - { - currCtes.push(*cte); - PsqlChanger changer(this, false); - ctes.add(pass1_recursive_cte(this, *cte)); - currCtes.pop(); - - // Add CTE name into CTE aliases stack. It allows later to search for - // aliases of given CTE. - const dsql_str* cte_name = (dsql_str*) (*cte)->nod_arg[e_derived_table_alias]; - addCTEAlias(cte_name); - } - else { - ctes.add(*cte); - } - } -} - - -dsql_nod* DsqlCompilerScratch::findCTE(const dsql_str* name) -{ - for (size_t i = 0; i < ctes.getCount(); i++) - { - dsql_nod* cte = ctes[i]; - const dsql_str* cte_name = (dsql_str*) cte->nod_arg[e_derived_table_alias]; - - if (name->str_length == cte_name->str_length && - strncmp(name->str_data, cte_name->str_data, cte_name->str_length) == 0) - { - return cte; - } - } - - return NULL; -} - - -void DsqlCompilerScratch::clearCTEs() -{ - flags &= ~DsqlCompilerScratch::FLAG_RECURSIVE_CTE; - ctes.clear(); - cteAliases.clear(); -} - - -void DsqlCompilerScratch::checkUnusedCTEs() const -{ - for (size_t i = 0; i < ctes.getCount(); i++) - { - const dsql_nod* cte = ctes[i]; - - if (!(cte->nod_flags & NOD_DT_CTE_USED)) - { - const dsql_str* cte_name = (dsql_str*) cte->nod_arg[e_derived_table_alias]; - - ERRD_post(Arg::Gds(isc_sqlerr) << Arg::Num(-104) << - Arg::Gds(isc_dsql_cte_not_used) << Arg::Str(cte_name->str_data)); - } - } -} - // Returns false for hidden fields and true for non-hidden. // For non-hidden, change "node" if the field is part of an // implicit join. diff --git a/src/dsql/pass1_proto.h b/src/dsql/pass1_proto.h index 719bf3ed9a..a538ad08af 100644 --- a/src/dsql/pass1_proto.h +++ b/src/dsql/pass1_proto.h @@ -25,6 +25,7 @@ #define DSQL_PASS1_PROTO_H void PASS1_check_unique_fields_names(Jrd::StrArray& names, const Jrd::dsql_nod* fields); +Jrd::dsql_nod* PASS1_compose(Jrd::dsql_nod*, Jrd::dsql_nod*, Jrd::NOD_TYPE); Jrd::dsql_nod* PASS1_cursor_name(Jrd::DsqlCompilerScratch*, const Jrd::dsql_str*, USHORT, bool); Jrd::dsql_nod* PASS1_label(Jrd::DsqlCompilerScratch*, Jrd::dsql_nod*); Jrd::dsql_nod* PASS1_label2(Jrd::DsqlCompilerScratch*, Jrd::dsql_nod*, Jrd::dsql_nod*);