Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions src/DurableTask.AzureStorage/Tracking/AzureTableQueryFilter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// ----------------------------------------------------------------------------------
// Copyright Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------------

namespace DurableTask.AzureStorage.Tracking
{
using System.Runtime.CompilerServices;
using Azure.Data.Tables;

/// <summary>
/// Builds individual OData filter conditions for Azure Table Storage queries with proper escaping
/// to prevent OData injection via user-influenced values (e.g. instance IDs, execution IDs, task hub
/// names). Each helper returns a single condition string; callers compose them with " and " / " or ".
/// </summary>
static class AzureTableQueryFilter
{
/// <summary>
/// Builds a <c>PartitionKey eq '...'</c> condition for an orchestration instance ID. The instance ID
/// is <b>unsanitized</b>: this method applies <see cref="KeySanitation.EscapePartitionKey"/> itself and
/// then OData quote-escaping. Do not pre-sanitize the value, because
/// <see cref="KeySanitation.EscapePartitionKey"/> is not idempotent. Contrast with
/// <see cref="PartitionKeyGreaterOrEqual"/> / <see cref="PartitionKeyLessThan"/>, which take an
/// already-sanitized partition key.
/// </summary>
public static string PartitionKeyEquals(string unsanitizedInstanceId)
{
string sanitizedInstanceId = KeySanitation.EscapePartitionKey(unsanitizedInstanceId);
return TableClient.CreateQueryFilter($"PartitionKey eq {sanitizedInstanceId}");
}

/// <summary>
/// Builds a <c>{columnName} eq '...'</c> condition with <paramref name="rawValue"/> OData quote-escaped.
/// This does <b>not</b> apply <see cref="KeySanitation.EscapePartitionKey"/>: it is for
/// non-partition-key columns (e.g. ExecutionId, EventType, RowKey), so <paramref name="rawValue"/> is
/// used as-is aside from OData escaping. The <paramref name="columnName"/> must be a trusted constant
/// (never user input), since it is emitted as literal filter text rather than an escaped value.
/// </summary>
public static string ColumnEquals(string columnName, string rawValue)
{
// CreateQueryFilter takes a FormattableString and escapes/quotes each interpolation hole as a
// value, while leaving the literal text of the format untouched. The column name must remain
// literal (an interpolated {columnName} would be emitted as a quoted value, e.g. 'ExecutionId'
// eq 'x', which is invalid), but it is a runtime parameter rather than a compile-time literal,
// so a normal interpolated string ($"...") can't express that. FormattableStringFactory lets us
// bake the column name into the format text while keeping the value as the only escaped argument.
return TableClient.CreateQueryFilter(FormattableStringFactory.Create(columnName + " eq {0}", rawValue));
}

/// <summary>
/// Builds a <c>PartitionKey ge '...'</c> condition. The value must already be sanitized via
/// <see cref="KeySanitation.EscapePartitionKey"/> by the caller.
/// </summary>
public static string PartitionKeyGreaterOrEqual(string sanitizedPartitionKey)
{
return TableClient.CreateQueryFilter($"PartitionKey ge {sanitizedPartitionKey}");
}

/// <summary>
/// Builds a <c>PartitionKey lt '...'</c> condition. The value must already be sanitized via
/// <see cref="KeySanitation.EscapePartitionKey"/> by the caller.
/// </summary>
public static string PartitionKeyLessThan(string sanitizedPartitionKey)
{
return TableClient.CreateQueryFilter($"PartitionKey lt {sanitizedPartitionKey}");
}
}
}
31 changes: 18 additions & 13 deletions src/DurableTask.AzureStorage/Tracking/AzureTableTrackingStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,13 @@ public override async Task<OrchestrationHistory> GetHistoryEventsAsync(string in

TableQueryResponse<TableEntity> GetHistoryEntitiesResponseInfoAsync(string instanceId, string expectedExecutionId, IList<string> projectionColumns, CancellationToken cancellationToken)
{
string filter = $"{nameof(ITableEntity.PartitionKey)} eq '{KeySanitation.EscapePartitionKey(instanceId)}'";
string filter = AzureTableQueryFilter.PartitionKeyEquals(instanceId);
if (!string.IsNullOrEmpty(expectedExecutionId))
{
filter += $" and ({nameof(ITableEntity.RowKey)} eq '{SentinelRowKey}' or {nameof(OrchestrationInstance.ExecutionId)} eq '{expectedExecutionId}')";
// Use parameterized filters to prevent OData injection via crafted execution IDs
string sentinelCondition = AzureTableQueryFilter.ColumnEquals(nameof(ITableEntity.RowKey), SentinelRowKey);
string executionIdCondition = AzureTableQueryFilter.ColumnEquals(nameof(OrchestrationInstance.ExecutionId), expectedExecutionId);
filter += $" and ({sentinelCondition} or {executionIdCondition})";
}

return this.HistoryTable.ExecuteQueryAsync<TableEntity>(filter, select: projectionColumns, cancellationToken: cancellationToken);
Expand Down Expand Up @@ -278,7 +281,7 @@ public override async IAsyncEnumerable<string> RewindHistoryAsync(string instanc
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

bool hasFailedSubOrchestrations = false;
string partitionFilter = $"{nameof(ITableEntity.PartitionKey)} eq '{KeySanitation.EscapePartitionKey(instanceId)}'";
string partitionFilter = AzureTableQueryFilter.PartitionKeyEquals(instanceId);

string orchestratorStartedFilter = $"{partitionFilter} and {nameof(HistoryEvent.EventType)} eq '{nameof(EventType.OrchestratorStarted)}'";
IReadOnlyList<TableEntity> orchestratorStartedEntities = await this.QueryHistoryAsync(orchestratorStartedFilter, instanceId, cancellationToken);
Expand All @@ -289,7 +292,8 @@ public override async IAsyncEnumerable<string> RewindHistoryAsync(string instanc
string executionId = recentStartRow[0].GetString(nameof(OrchestrationInstance.ExecutionId));
DateTime instanceTimestamp = recentStartRow[0].Timestamp.GetValueOrDefault().DateTime;

string executionIdFilter = $"{nameof(OrchestrationInstance.ExecutionId)} eq '{executionId}'";
// Use parameterized filter to prevent OData injection via crafted execution IDs
string executionIdFilter = AzureTableQueryFilter.ColumnEquals(nameof(OrchestrationInstance.ExecutionId), executionId);

var updateFilterBuilder = new StringBuilder();
updateFilterBuilder.Append($"{partitionFilter}");
Expand Down Expand Up @@ -711,9 +715,8 @@ public override Task PurgeHistoryAsync(DateTime thresholdDateTimeUtc, Orchestrat
/// <inheritdoc />
public override async Task<PurgeHistoryResult> PurgeInstanceHistoryAsync(string instanceId, CancellationToken cancellationToken = default)
{
string sanitizedInstanceId = KeySanitation.EscapePartitionKey(instanceId);

string filter = $"{PartitionKeyProperty} eq '{sanitizedInstanceId}' and {RowKeyProperty} eq ''";
// Use parameterized filters to prevent OData injection via crafted instance IDs
string filter = $"{AzureTableQueryFilter.PartitionKeyEquals(instanceId)} and {AzureTableQueryFilter.ColumnEquals(RowKeyProperty, string.Empty)}";
var results = await this.InstancesTable
.ExecuteQueryAsync<OrchestrationInstanceStatus>(filter, cancellationToken: cancellationToken)
.GetResultsAsync(cancellationToken: cancellationToken);
Expand Down Expand Up @@ -1173,9 +1176,10 @@ static TableEntity GetSingleEntityFromHistoryTableResults(IReadOnlyList<TableEnt
// In the case that the output is too large and is stored in blob storage, extract the blob name from the ExecutionCompleted history entity.
if (this.ExceedsMaxTablePropertySize(runtimeState.Output))
{
string filter = $"{nameof(ITableEntity.PartitionKey)} eq '{KeySanitation.EscapePartitionKey(instanceId)}'" +
$" and {nameof(OrchestrationInstance.ExecutionId)} eq '{executionId}'" +
$" and {nameof(HistoryEvent.EventType)} eq '{nameof(EventType.ExecutionCompleted)}'";
// Use parameterized filters to prevent OData injection via crafted instance/execution IDs
string filter = $"{AzureTableQueryFilter.PartitionKeyEquals(instanceId)}" +
$" and {AzureTableQueryFilter.ColumnEquals(nameof(OrchestrationInstance.ExecutionId), executionId)}" +
$" and {AzureTableQueryFilter.ColumnEquals(nameof(HistoryEvent.EventType), nameof(EventType.ExecutionCompleted))}";
TableEntity executionCompletedEntity = GetSingleEntityFromHistoryTableResults(await this.QueryHistoryAsync(filter, instanceId, cancellationToken), "output");
this.SetInstancesTablePropertyFromHistoryProperty(
executionCompletedEntity,
Expand All @@ -1195,9 +1199,10 @@ static TableEntity GetSingleEntityFromHistoryTableResults(IReadOnlyList<TableEnt
// In the case that the input is too large and is stored in blob storage, extract the blob name from the ExecutionStarted history entity.
if (this.ExceedsMaxTablePropertySize(runtimeState.Input))
{
string filter = $"{nameof(ITableEntity.PartitionKey)} eq '{KeySanitation.EscapePartitionKey(instanceId)}'" +
$" and {nameof(OrchestrationInstance.ExecutionId)} eq '{executionId}'" +
$" and {nameof(HistoryEvent.EventType)} eq '{nameof(EventType.ExecutionStarted)}'";
// Use parameterized filters to prevent OData injection via crafted instance/execution IDs
string filter = $"{AzureTableQueryFilter.PartitionKeyEquals(instanceId)}" +
$" and {AzureTableQueryFilter.ColumnEquals(nameof(OrchestrationInstance.ExecutionId), executionId)}" +
$" and {AzureTableQueryFilter.ColumnEquals(nameof(HistoryEvent.EventType), nameof(EventType.ExecutionStarted))}";
TableEntity executionStartedEntity = GetSingleEntityFromHistoryTableResults(await this.QueryHistoryAsync(filter, instanceId, cancellationToken), "input");
this.SetInstancesTablePropertyFromHistoryProperty(
executionStartedEntity,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ internal ODataCondition ToOData()

if (this.TaskHubNames != null && this.TaskHubNames.Any())
{
conditions.Add($"{string.Join(" or ", this.TaskHubNames.Select(x => $"TaskHubName eq '{x}'"))}");
// Use parameterized filter to prevent OData injection via crafted task hub names
conditions.Add(string.Join(" or ", this.TaskHubNames.Select(x => AzureTableQueryFilter.ColumnEquals("TaskHubName", x))));
}

if (!string.IsNullOrEmpty(this.InstanceIdPrefix))
Expand All @@ -140,8 +141,9 @@ internal ODataCondition ToOData()

string greaterThanPrefix = sanitizedPrefix.Substring(0, length) + incrementedLastChar;

conditions.Add($"{nameof(OrchestrationInstanceStatus.PartitionKey)} ge '{sanitizedPrefix}'");
conditions.Add($"{nameof(OrchestrationInstanceStatus.PartitionKey)} lt '{greaterThanPrefix}'");
// Use parameterized filters to prevent OData injection via crafted instance ID prefixes.
conditions.Add(AzureTableQueryFilter.PartitionKeyGreaterOrEqual(sanitizedPrefix));
conditions.Add(AzureTableQueryFilter.PartitionKeyLessThan(greaterThanPrefix));
}
else if (this.ExcludeEntities)
{
Expand All @@ -150,8 +152,8 @@ internal ODataCondition ToOData()

if (this.InstanceId != null)
{
string sanitizedInstanceId = KeySanitation.EscapePartitionKey(this.InstanceId);
conditions.Add($"{nameof(OrchestrationInstanceStatus.PartitionKey)} eq '{sanitizedInstanceId}'");
// Use parameterized filter to prevent OData injection via crafted instance IDs
conditions.Add(AzureTableQueryFilter.PartitionKeyEquals(this.InstanceId));
}

return conditions.Count switch
Expand Down
32 changes: 19 additions & 13 deletions src/DurableTask.ServiceBus/Tracking/AzureTableClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ static readonly IDictionary<FilterComparisonType, string> ComparisonOperatorMap
{{ FilterComparisonType.Equals, AzureTableConstants.EqualityOperator},
{ FilterComparisonType.NotEquals, AzureTableConstants.InEqualityOperator}};

/// <summary>
/// Escapes a string value for safe use in an OData filter expression
/// by doubling single quotes (<c>'</c> → <c>''</c>).
/// </summary>
static string EscapeODataValue(string value) => value?.Replace("'", "''");

volatile TableClient historyTableClient;
volatile TableClient jumpStartTableClient;

Expand Down Expand Up @@ -234,22 +240,22 @@ string GetPrimaryFilterExpression(OrchestrationStateQueryFilter filter, bool isJ
{
filterExpression = string.Format(CultureInfo.InvariantCulture,
AzureTableConstants.PrimaryInstanceQueryRangeTemplate,
typedFilter.InstanceId, ComputeNextKeyInRange(typedFilter.InstanceId));
EscapeODataValue(typedFilter.InstanceId), EscapeODataValue(ComputeNextKeyInRange(typedFilter.InstanceId)));
}
else
{
if (string.IsNullOrWhiteSpace(typedFilter.ExecutionId))
{
filterExpression = string.Format(CultureInfo.InvariantCulture,
AzureTableConstants.PrimaryInstanceQueryRangeTemplate,
typedFilter.InstanceId, ComputeNextKeyInRange(typedFilter.InstanceId));
EscapeODataValue(typedFilter.InstanceId), EscapeODataValue(ComputeNextKeyInRange(typedFilter.InstanceId)));
}
else
{
filterExpression = string.Format(CultureInfo.InvariantCulture,
AzureTableConstants.PrimaryInstanceQueryExactTemplate,
typedFilter.InstanceId,
typedFilter.ExecutionId);
EscapeODataValue(typedFilter.InstanceId),
EscapeODataValue(typedFilter.ExecutionId));
}
}
}
Expand All @@ -275,20 +281,20 @@ string GetSecondaryFilterExpression(OrchestrationStateQueryFilter filter)
{
query = string.Format(CultureInfo.InvariantCulture,
AzureTableConstants.InstanceQuerySecondaryFilterRangeTemplate,
orchestrationStateInstanceFilter.InstanceId, ComputeNextKeyInRange(orchestrationStateInstanceFilter.InstanceId));
EscapeODataValue(orchestrationStateInstanceFilter.InstanceId), EscapeODataValue(ComputeNextKeyInRange(orchestrationStateInstanceFilter.InstanceId)));
}
else
{
if (string.IsNullOrWhiteSpace(orchestrationStateInstanceFilter.ExecutionId))
{
query = string.Format(CultureInfo.InvariantCulture,
AzureTableConstants.InstanceQuerySecondaryFilterTemplate, orchestrationStateInstanceFilter.InstanceId);
AzureTableConstants.InstanceQuerySecondaryFilterTemplate, EscapeODataValue(orchestrationStateInstanceFilter.InstanceId));
}
else
{
query = string.Format(CultureInfo.InvariantCulture,
AzureTableConstants.InstanceQuerySecondaryFilterExactTemplate, orchestrationStateInstanceFilter.InstanceId,
orchestrationStateInstanceFilter.ExecutionId);
AzureTableConstants.InstanceQuerySecondaryFilterExactTemplate, EscapeODataValue(orchestrationStateInstanceFilter.InstanceId),
EscapeODataValue(orchestrationStateInstanceFilter.ExecutionId));
}
}
}
Expand All @@ -297,20 +303,20 @@ string GetSecondaryFilterExpression(OrchestrationStateQueryFilter filter)
if (orchestrationStateNameVersionFilter.Version == null)
{
query = string.Format(CultureInfo.InvariantCulture,
AzureTableConstants.NameVersionQuerySecondaryFilterTemplate, orchestrationStateNameVersionFilter.Name);
AzureTableConstants.NameVersionQuerySecondaryFilterTemplate, EscapeODataValue(orchestrationStateNameVersionFilter.Name));
}
else
{
query = string.Format(CultureInfo.InvariantCulture,
AzureTableConstants.NameVersionQuerySecondaryFilterExactTemplate, orchestrationStateNameVersionFilter.Name,
orchestrationStateNameVersionFilter.Version);
AzureTableConstants.NameVersionQuerySecondaryFilterExactTemplate, EscapeODataValue(orchestrationStateNameVersionFilter.Name),
EscapeODataValue(orchestrationStateNameVersionFilter.Version));
}
}
else if (filter is OrchestrationStateStatusFilter orchestrationStateStatusFilter)
{
string template = AzureTableConstants.StatusQuerySecondaryFilterTemplate;
query = string.Format(CultureInfo.InvariantCulture,
template, ComparisonOperatorMap[orchestrationStateStatusFilter.ComparisonType], orchestrationStateStatusFilter.Status);
template, ComparisonOperatorMap[orchestrationStateStatusFilter.ComparisonType], EscapeODataValue(orchestrationStateStatusFilter.Status.ToString()));
}
else if (filter is OrchestrationStateTimeRangeFilter orchestrationStateTimeRangeFilter)
{
Expand Down Expand Up @@ -383,7 +389,7 @@ public async Task<IEnumerable<AzureTableOrchestrationHistoryEventEntity>> ReadOr
AzureTableConstants.JoinDelimiterPlusOne;

string filter = string.Format(CultureInfo.InvariantCulture, AzureTableConstants.TableRangeQueryFormat,
partitionKey, rowKeyLower, rowKeyUpper);
EscapeODataValue(partitionKey), EscapeODataValue(rowKeyLower), EscapeODataValue(rowKeyUpper));

var pageableResults = historyTableClient.QueryAsync<AzureTableOrchestrationHistoryEventEntity>(filter);

Expand Down
Loading
Loading