8
0
mirror of https://github.com/FirebirdSQL/firebird.git synced 2025-01-23 06:03:02 +01:00

Fixed #7056 (Fetching from a scrollable cursor may overwrite user-specified buffer and corrupt memory) and #7057 (Client-side positioned updates work wrongly with scrollable cursors) with a single shot

This commit is contained in:
Dmitry Yemanov 2021-12-01 11:44:50 +03:00
parent 35f0f1e70f
commit c4ad5afb77
4 changed files with 49 additions and 52 deletions

View File

@ -34,10 +34,10 @@ static const char* const SCRATCH = "fb_cursor_";
static const ULONG PREFETCH_SIZE = 65536; // 64 KB static const ULONG PREFETCH_SIZE = 65536; // 64 KB
DsqlCursor::DsqlCursor(dsql_req* req, ULONG flags) DsqlCursor::DsqlCursor(dsql_req* req, ULONG flags)
: m_request(req), m_resultSet(NULL), m_flags(flags), : m_request(req), m_message(req->getStatement()->getReceiveMsg()),
m_resultSet(NULL), m_flags(flags),
m_space(req->getPool(), SCRATCH), m_space(req->getPool(), SCRATCH),
m_state(BOS), m_eof(false), m_position(0), m_cachedCount(0), m_state(BOS), m_eof(false), m_position(0), m_cachedCount(0)
m_messageSize(req->getStatement()->getReceiveMsg()->msg_length)
{ {
TRA_link_cursor(m_request->req_transaction, this); TRA_link_cursor(m_request->req_transaction, this);
} }
@ -232,9 +232,14 @@ int DsqlCursor::fetchFromCache(thread_db* tdbb, UCHAR* buffer, FB_UINT64 positio
fb_assert(position < m_cachedCount); fb_assert(position < m_cachedCount);
const FB_UINT64 offset = position * m_messageSize; UCHAR* const msgBuffer = m_request->req_msg_buffers[m_message->msg_buffer_number];
const FB_UINT64 readBytes = m_space.read(offset, buffer, m_messageSize);
fb_assert(readBytes == m_messageSize); const FB_UINT64 offset = position * m_message->msg_length;
const FB_UINT64 readBytes = m_space.read(offset, msgBuffer, m_message->msg_length);
fb_assert(readBytes == m_message->msg_length);
m_request->mapInOut(tdbb, true, m_message, NULL, buffer);
m_position = position; m_position = position;
m_state = POSITIONED; m_state = POSITIONED;
return 0; return 0;
@ -244,34 +249,23 @@ bool DsqlCursor::cacheInput(thread_db* tdbb, FB_UINT64 position)
{ {
fb_assert(!m_eof); fb_assert(!m_eof);
const ULONG prefetchCount = MAX(PREFETCH_SIZE / m_messageSize, 1); const ULONG prefetchCount = MAX(PREFETCH_SIZE / m_message->msg_length, 1);
const ULONG prefetchSize = prefetchCount * m_messageSize; const UCHAR* const msgBuffer = m_request->req_msg_buffers[m_message->msg_buffer_number];
UCharBuffer messageBuffer;
UCHAR* const buffer = messageBuffer.getBuffer(prefetchSize);
while (position >= m_cachedCount) while (position >= m_cachedCount)
{ {
ULONG count = 0; for (ULONG count = 0; count < prefetchCount; count++)
for (; count < prefetchCount; count++)
{ {
UCHAR* const ptr = buffer + count * m_messageSize; if (!m_request->fetch(tdbb, NULL))
if (!m_request->fetch(tdbb, ptr))
{ {
m_eof = true; m_eof = true;
break; break;
} }
}
if (count) const FB_UINT64 offset = m_cachedCount * m_message->msg_length;
{ const FB_UINT64 writtenBytes = m_space.write(offset, msgBuffer, m_message->msg_length);
const FB_UINT64 offset = m_cachedCount * m_messageSize; fb_assert(writtenBytes == m_message->msg_length);
const ULONG fetchedSize = count * m_messageSize; m_cachedCount++;
const FB_UINT64 writtenBytes = m_space.write(offset, buffer, fetchedSize);
fb_assert(writtenBytes == fetchedSize);
m_cachedCount += count;
} }
if (m_eof) if (m_eof)

View File

@ -66,6 +66,7 @@ private:
bool cacheInput(thread_db* tdbb, FB_UINT64 position = MAX_UINT64); bool cacheInput(thread_db* tdbb, FB_UINT64 position = MAX_UINT64);
dsql_req* const m_request; dsql_req* const m_request;
const dsql_msg* const m_message;
JResultSet* m_resultSet; JResultSet* m_resultSet;
const ULONG m_flags; const ULONG m_flags;
TempSpace m_space; TempSpace m_space;
@ -73,7 +74,6 @@ private:
bool m_eof; bool m_eof;
FB_UINT64 m_position; FB_UINT64 m_position;
FB_UINT64 m_cachedCount; FB_UINT64 m_cachedCount;
ULONG m_messageSize;
}; };
} // namespace } // namespace

View File

@ -80,9 +80,6 @@ using namespace Firebird;
static ULONG get_request_info(thread_db*, dsql_req*, ULONG, UCHAR*); static ULONG get_request_info(thread_db*, dsql_req*, ULONG, UCHAR*);
static dsql_dbb* init(Jrd::thread_db*, Jrd::Attachment*); static dsql_dbb* init(Jrd::thread_db*, Jrd::Attachment*);
static void map_in_out(Jrd::thread_db*, dsql_req*, bool, const dsql_msg*, IMessageMetadata*, UCHAR*,
const UCHAR* = NULL);
static USHORT parse_metadata(dsql_req*, IMessageMetadata*, const Array<dsql_par*>&);
static dsql_req* prepareRequest(thread_db*, dsql_dbb*, jrd_tra*, ULONG, const TEXT*, USHORT, bool); static dsql_req* prepareRequest(thread_db*, dsql_dbb*, jrd_tra*, ULONG, const TEXT*, USHORT, bool);
static dsql_req* prepareStatement(thread_db*, dsql_dbb*, jrd_tra*, ULONG, const TEXT*, USHORT, bool); static dsql_req* prepareStatement(thread_db*, dsql_dbb*, jrd_tra*, ULONG, const TEXT*, USHORT, bool);
static UCHAR* put_item(UCHAR, const USHORT, const UCHAR*, UCHAR*, const UCHAR* const); static UCHAR* put_item(UCHAR, const USHORT, const UCHAR*, UCHAR*, const UCHAR* const);
@ -272,6 +269,12 @@ bool DsqlDmlRequest::fetch(thread_db* tdbb, UCHAR* msgBuffer)
dsql_msg* message = (dsql_msg*) statement->getReceiveMsg(); dsql_msg* message = (dsql_msg*) statement->getReceiveMsg();
if (delayedFormat && message)
{
parseMetadata(delayedFormat, message->msg_parameters);
delayedFormat = NULL;
}
// Set up things for tracing this call // Set up things for tracing this call
Jrd::Attachment* att = req_dbb->dbb_attachment; Jrd::Attachment* att = req_dbb->dbb_attachment;
TraceDSQLFetch trace(att, this); TraceDSQLFetch trace(att, this);
@ -285,13 +288,12 @@ bool DsqlDmlRequest::fetch(thread_db* tdbb, UCHAR* msgBuffer)
if (eofReached) if (eofReached)
{ {
delayedFormat = NULL;
trace.fetch(true, ITracePlugin::RESULT_SUCCESS); trace.fetch(true, ITracePlugin::RESULT_SUCCESS);
return false; return false;
} }
map_in_out(tdbb, this, true, message, delayedFormat, msgBuffer); if (msgBuffer)
delayedFormat = NULL; mapInOut(tdbb, true, message, NULL, msgBuffer);
trace.fetch(false, ITracePlugin::RESULT_SUCCESS); trace.fetch(false, ITracePlugin::RESULT_SUCCESS);
return true; return true;
@ -679,9 +681,9 @@ void DsqlDmlRequest::execute(thread_db* tdbb, jrd_tra** traHandle,
const dsql_msg* message = statement->getSendMsg(); const dsql_msg* message = statement->getSendMsg();
if (message) if (message)
map_in_out(tdbb, this, false, message, inMetadata, NULL, inMsg); mapInOut(tdbb, false, message, inMetadata, NULL, inMsg);
// we need to map_in_out before tracing of execution start to let trace // we need to mapInOut before tracing of execution start to let trace
// manager know statement parameters values // manager know statement parameters values
TraceDSQLExecute trace(req_dbb->dbb_attachment, this); TraceDSQLExecute trace(req_dbb->dbb_attachment, this);
@ -713,7 +715,7 @@ void DsqlDmlRequest::execute(thread_db* tdbb, jrd_tra** traHandle,
} }
if (outMetadata && message) if (outMetadata && message)
parse_metadata(this, outMetadata, message->msg_parameters); parseMetadata(outMetadata, message->msg_parameters);
if ((outMsg && message) || isBlock) if ((outMsg && message) || isBlock)
{ {
@ -736,7 +738,7 @@ void DsqlDmlRequest::execute(thread_db* tdbb, jrd_tra** traHandle,
JRD_receive(tdbb, req_request, message->msg_number, message->msg_length, msgBuffer); JRD_receive(tdbb, req_request, message->msg_number, message->msg_length, msgBuffer);
if (outMsg) if (outMsg)
map_in_out(tdbb, this, true, message, NULL, outMsg); mapInOut(tdbb, true, message, NULL, outMsg);
// if this is a singleton select, make sure there's in fact one record // if this is a singleton select, make sure there's in fact one record
@ -989,13 +991,12 @@ static dsql_dbb* init(thread_db* tdbb, Jrd::Attachment* attachment)
/** /**
map_in_out mapInOut
@brief Map data from external world into message or @brief Map data from external world into message or
from message to external world. from message to external world.
@param request
@param toExternal @param toExternal
@param message @param message
@param meta @param meta
@ -1003,10 +1004,10 @@ static dsql_dbb* init(thread_db* tdbb, Jrd::Attachment* attachment)
@param in_dsql_msg_buf @param in_dsql_msg_buf
**/ **/
static void map_in_out(thread_db* tdbb, dsql_req* request, bool toExternal, const dsql_msg* message, void dsql_req::mapInOut(thread_db* tdbb, bool toExternal, const dsql_msg* message,
IMessageMetadata* meta, UCHAR* dsql_msg_buf, const UCHAR* in_dsql_msg_buf) IMessageMetadata* meta, UCHAR* dsql_msg_buf, const UCHAR* in_dsql_msg_buf)
{ {
USHORT count = parse_metadata(request, meta, message->msg_parameters); USHORT count = parseMetadata(meta, message->msg_parameters);
// Sanity check // Sanity check
@ -1041,7 +1042,7 @@ static void map_in_out(thread_db* tdbb, dsql_req* request, bool toExternal, cons
// Make sure the message given to us is long enough // Make sure the message given to us is long enough
dsc desc; dsc desc;
if (!request->req_user_descs.get(parameter, desc)) if (!req_user_descs.get(parameter, desc))
desc.clear(); desc.clear();
/*** /***
@ -1059,14 +1060,14 @@ static void map_in_out(thread_db* tdbb, dsql_req* request, bool toExternal, cons
Arg::Gds(isc_dsql_sqlvar_index) << Arg::Num(parameter->par_index-1)); Arg::Gds(isc_dsql_sqlvar_index) << Arg::Num(parameter->par_index-1));
} }
UCHAR* msgBuffer = request->req_msg_buffers[parameter->par_message->msg_buffer_number]; UCHAR* msgBuffer = req_msg_buffers[parameter->par_message->msg_buffer_number];
SSHORT* flag = NULL; SSHORT* flag = NULL;
dsql_par* const null_ind = parameter->par_null; dsql_par* const null_ind = parameter->par_null;
if (null_ind != NULL) if (null_ind != NULL)
{ {
dsc userNullDesc; dsc userNullDesc;
if (!request->req_user_descs.get(null_ind, userNullDesc)) if (!req_user_descs.get(null_ind, userNullDesc))
userNullDesc.clear(); userNullDesc.clear();
const ULONG null_offset = (IPTR) userNullDesc.dsc_address; const ULONG null_offset = (IPTR) userNullDesc.dsc_address;
@ -1129,7 +1130,7 @@ static void map_in_out(thread_db* tdbb, dsql_req* request, bool toExternal, cons
Arg::Gds(isc_dsql_wrong_param_num) << Arg::Num(count) <<Arg::Num(count2)); Arg::Gds(isc_dsql_wrong_param_num) << Arg::Num(count) <<Arg::Num(count2));
} }
const DsqlCompiledStatement* statement = request->getStatement(); const DsqlCompiledStatement* statement = getStatement();
const dsql_par* parameter; const dsql_par* parameter;
const dsql_par* dbkey; const dsql_par* dbkey;
@ -1138,7 +1139,7 @@ static void map_in_out(thread_db* tdbb, dsql_req* request, bool toExternal, cons
{ {
UCHAR* parentMsgBuffer = statement->getParentRequest() ? UCHAR* parentMsgBuffer = statement->getParentRequest() ?
statement->getParentRequest()->req_msg_buffers[dbkey->par_message->msg_buffer_number] : NULL; statement->getParentRequest()->req_msg_buffers[dbkey->par_message->msg_buffer_number] : NULL;
UCHAR* msgBuffer = request->req_msg_buffers[parameter->par_message->msg_buffer_number]; UCHAR* msgBuffer = req_msg_buffers[parameter->par_message->msg_buffer_number];
fb_assert(parentMsgBuffer); fb_assert(parentMsgBuffer);
@ -1168,7 +1169,7 @@ static void map_in_out(thread_db* tdbb, dsql_req* request, bool toExternal, cons
UCHAR* parentMsgBuffer = statement->getParentRequest() ? UCHAR* parentMsgBuffer = statement->getParentRequest() ?
statement->getParentRequest()->req_msg_buffers[rec_version->par_message->msg_buffer_number] : statement->getParentRequest()->req_msg_buffers[rec_version->par_message->msg_buffer_number] :
NULL; NULL;
UCHAR* msgBuffer = request->req_msg_buffers[parameter->par_message->msg_buffer_number]; UCHAR* msgBuffer = req_msg_buffers[parameter->par_message->msg_buffer_number];
fb_assert(parentMsgBuffer); fb_assert(parentMsgBuffer);
@ -1195,18 +1196,16 @@ static void map_in_out(thread_db* tdbb, dsql_req* request, bool toExternal, cons
/** /**
parse_metadata parseMetadata
@brief Parse the message of a request. @brief Parse the message of a request.
@param request
@param meta @param meta
@param parameters_list @param parameters_list
**/ **/
static USHORT parse_metadata(dsql_req* request, IMessageMetadata* meta, USHORT dsql_req::parseMetadata(IMessageMetadata* meta, const Array<dsql_par*>& parameters_list)
const Array<dsql_par*>& parameters_list)
{ {
HalfStaticArray<const dsql_par*, 16> parameters; HalfStaticArray<const dsql_par*, 16> parameters;
@ -1280,7 +1279,7 @@ static USHORT parse_metadata(dsql_req* request, IMessageMetadata* meta,
if (desc.isText() && desc.getTextType() == ttype_dynamic) if (desc.isText() && desc.getTextType() == ttype_dynamic)
desc.setTextType(ttype_none); desc.setTextType(ttype_none);
request->req_user_descs.put(parameter, desc); req_user_descs.put(parameter, desc);
dsql_par* null = parameter->par_null; dsql_par* null = parameter->par_null;
if (null) if (null)
@ -1291,7 +1290,7 @@ static USHORT parse_metadata(dsql_req* request, IMessageMetadata* meta,
desc.dsc_length = sizeof(SSHORT); desc.dsc_length = sizeof(SSHORT);
desc.dsc_address = (UCHAR*)(IPTR) nullOffset; desc.dsc_address = (UCHAR*)(IPTR) nullOffset;
request->req_user_descs.put(null, desc); req_user_descs.put(null, desc);
} }
} }

View File

@ -573,6 +573,10 @@ public:
virtual void setDelayedFormat(thread_db* tdbb, Firebird::IMessageMetadata* metadata); virtual void setDelayedFormat(thread_db* tdbb, Firebird::IMessageMetadata* metadata);
USHORT parseMetadata(Firebird::IMessageMetadata* meta, const Firebird::Array<dsql_par*>& parameters_list);
void mapInOut(Jrd::thread_db* tdbb, bool toExternal, const dsql_msg* message, Firebird::IMessageMetadata* meta,
UCHAR* dsql_msg_buf, const UCHAR* in_dsql_msg_buf = NULL);
static void destroy(thread_db* tdbb, dsql_req* request, bool drop); static void destroy(thread_db* tdbb, dsql_req* request, bool drop);
private: private: