Skip to content

Commit 60685a7

Browse files
PrathamAditya, markwallace-microsoft, rogerbarreto
authored
.Net: Fix: include taskType in Google AI embedding request (fixes #13250) (#13277)
### Motivation and Context

This PR fixes issue [#13250](#13250) by ensuring that the Google AI embedding connector correctly includes the taskType field in outgoing requests when EmbeddingGenerationOptions contains a task_type value. Previously, the connector ignored this field, causing task-specific embeddings (RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, etc.) to default to generic embeddings.

### Description

Changes:

- Extended GoogleAIEmbeddingRequest.FromData() with an optional EmbeddingGenerationOptions parameter; the request's taskType is read from its AdditionalProperties under the key "task_type" or "tasktype" (keys are matched case-insensitively).
- Updated GoogleAIEmbeddingClient to accept EmbeddingGenerationOptions and pass it through to FromData(); ModelId and Dimensions set on the options override the client's defaults.
- Added a new overload in GoogleAITextEmbeddingGenerationService to pass EmbeddingGenerationOptions while maintaining backward compatibility.
- Added unit test FromDataShouldIncludeTaskTypeWhenProvided to verify that "taskType" appears correctly in serialized JSON.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Mark Wallace <[email protected]>
Co-authored-by: Roger Barreto <[email protected]>
1 parent e1fbafe commit 60685a7

File tree

5 files changed

+132
-19
lines changed

5 files changed

+132
-19
lines changed

dotnet/src/Connectors/Connectors.Google.UnitTests/Core/GoogleAI/GoogleAIClientEmbeddingsGenerationTests.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,35 @@ public async Task ShouldNotIncludeDimensionsInAllRequestsWhenNotProvidedAsync()
219219
Assert.All(request.Requests, item => Assert.Null(item.Dimensions));
220220
}
221221

222+
[Fact]
223+
public async Task GenerateEmbeddingsUsingEmbeddingGenerationOptionsShouldOverrideDimensionsAndModelAsync()
224+
{
225+
// Arrange
226+
var client = this.CreateEmbeddingsClient();
227+
var dataToEmbed = new List<string>()
228+
{
229+
"First text to embed",
230+
"Second text to embed",
231+
"Third text to embed"
232+
};
233+
234+
var options = new Microsoft.Extensions.AI.EmbeddingGenerationOptions { Dimensions = 10, ModelId = "override-model" };
235+
236+
// Act
237+
await client.GenerateEmbeddingsAsync(dataToEmbed, options);
238+
239+
// Assert
240+
var request = JsonSerializer.Deserialize<GoogleAIEmbeddingRequest>(this._messageHandlerStub.RequestContent);
241+
Assert.NotNull(request);
242+
Assert.Equal(dataToEmbed.Count, request.Requests.Count);
243+
Assert.All(request.Requests,
244+
item =>
245+
{
246+
Assert.Contains(options.ModelId, item.Model);
247+
Assert.Equal(options.Dimensions, item.Dimensions);
248+
});
249+
}
250+
222251
private GoogleAIEmbeddingClient CreateEmbeddingsClient(
223252
string modelId = "fake-model",
224253
int? dimensions = null)

dotnet/src/Connectors/Connectors.Google.UnitTests/Core/GoogleAI/GoogleAIEmbeddingRequestTests.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
using System.Text.Json;
4+
using Microsoft.Extensions.AI;
45
using Microsoft.SemanticKernel.Connectors.Google.Core;
56
using Xunit;
67

@@ -83,4 +84,32 @@ public void FromDataJsonIncludesDimensionsWhenProvided()
8384
// Assert
8485
Assert.Contains($"{DimensionalityJsonPropertyName}:{Dimensions}", json);
8586
}
87+
88+
[Theory]
89+
[InlineData("TaskType")]
90+
[InlineData("Task_Type")]
91+
[InlineData("taskType")]
92+
[InlineData("task_Type")]
93+
[InlineData("tasktype")]
94+
[InlineData("task_type")]
95+
public void FromDataShouldIncludeTaskTypeWhenProvided(string additionalPropertyKeyName)
96+
{
97+
// Arrange
98+
var input = new[] { "This is a retrieval document." };
99+
var modelId = "embedding-001";
100+
var dimensions = 1024;
101+
var taskType = "RETRIEVAL_DOCUMENT";
102+
103+
var options = new EmbeddingGenerationOptions { AdditionalProperties = new AdditionalPropertiesDictionary { [additionalPropertyKeyName] = taskType } };
104+
105+
// Act
106+
var request = GoogleAIEmbeddingRequest.FromData(input, modelId, dimensions, options);
107+
108+
// Serialize to JSON (this is what would be sent in the HTTP request)
109+
var json = System.Text.Json.JsonSerializer.Serialize(request);
110+
111+
// Assert
112+
Assert.Contains("\"taskType\":\"RETRIEVAL_DOCUMENT\"", json);
113+
Assert.Contains("\"model\":\"models/embedding-001\"", json);
114+
}
86115
}

dotnet/src/Connectors/Connectors.Google/Core/GoogleAI/GoogleAIEmbeddingClient.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Net.Http;
77
using System.Threading;
88
using System.Threading.Tasks;
9+
using Microsoft.Extensions.AI;
910
using Microsoft.Extensions.Logging;
1011

1112
namespace Microsoft.SemanticKernel.Connectors.Google.Core;
@@ -54,15 +55,18 @@ public GoogleAIEmbeddingClient(
5455
/// Generates embeddings for the given data asynchronously.
5556
/// </summary>
5657
/// <param name="data">The list of strings to generate embeddings for.</param>
58+
/// <param name="options">The embedding generation options.</param>
5759
/// <param name="cancellationToken">The cancellation token to cancel the operation.</param>
5860
/// <returns>Result contains a list of read-only memories of floats representing the generated embeddings.</returns>
5961
public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
6062
IList<string> data,
63+
EmbeddingGenerationOptions? options = null,
6164
CancellationToken cancellationToken = default)
6265
{
6366
Verify.NotNullOrEmpty(data);
6467

65-
var geminiRequest = this.GetEmbeddingRequest(data);
68+
var geminiRequest = this.GetEmbeddingRequest(data, options);
69+
6670
using var httpRequestMessage = await this.CreateHttpRequestAsync(geminiRequest, this._embeddingEndpoint).ConfigureAwait(false);
6771

6872
string body = await this.SendRequestAndGetStringBodyAsync(httpRequestMessage, cancellationToken)
@@ -71,8 +75,8 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
7175
return DeserializeAndProcessEmbeddingsResponse(body);
7276
}
7377

74-
private GoogleAIEmbeddingRequest GetEmbeddingRequest(IEnumerable<string> data)
75-
=> GoogleAIEmbeddingRequest.FromData(data, this._embeddingModelId, this._dimensions);
78+
private GoogleAIEmbeddingRequest GetEmbeddingRequest(IEnumerable<string> data, EmbeddingGenerationOptions? options = null)
79+
=> GoogleAIEmbeddingRequest.FromData(data, options?.ModelId ?? this._embeddingModelId, options?.Dimensions ?? this._dimensions, options);
7680

7781
private static List<ReadOnlyMemory<float>> DeserializeAndProcessEmbeddingsResponse(string body)
7882
=> ProcessEmbeddingsResponse(DeserializeResponse<GoogleAIEmbeddingResponse>(body));

dotnet/src/Connectors/Connectors.Google/Core/GoogleAI/GoogleAIEmbeddingRequest.cs

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Collections.Generic;
44
using System.Linq;
55
using System.Text.Json.Serialization;
6+
using Microsoft.Extensions.AI;
67

78
namespace Microsoft.SemanticKernel.Connectors.Google.Core;
89

@@ -11,24 +12,48 @@ internal sealed class GoogleAIEmbeddingRequest
1112
[JsonPropertyName("requests")]
1213
public IList<RequestEmbeddingContent> Requests { get; set; } = null!;
1314

14-
public static GoogleAIEmbeddingRequest FromData(IEnumerable<string> data, string modelId, int? dimensions = null) => new()
15+
public static GoogleAIEmbeddingRequest FromData(IEnumerable<string> data, string modelId, int? dimensions = null, EmbeddingGenerationOptions? options = null)
1516
{
16-
Requests = data.Select(text => new RequestEmbeddingContent
17+
static string? GetTaskType(EmbeddingGenerationOptions? options)
1718
{
18-
Model = $"models/{modelId}",
19-
Content = new()
19+
if (options?.AdditionalProperties is not null)
2020
{
21-
Parts =
22-
[
23-
new()
24-
{
25-
Text = text
26-
}
27-
]
28-
},
29-
Dimensions = dimensions
30-
}).ToList()
31-
};
21+
object? taskType = null;
22+
object? task_type = null;
23+
24+
// AdditionalProperties is case-insensitive
25+
if (options?.AdditionalProperties.TryGetValue("task_type", out task_type) == true ||
26+
options?.AdditionalProperties.TryGetValue("tasktype", out taskType) == true)
27+
{
28+
return (task_type ?? taskType)?.ToString();
29+
}
30+
}
31+
32+
return null;
33+
}
34+
35+
var request = new GoogleAIEmbeddingRequest
36+
{
37+
Requests = [.. data.Select(text => new RequestEmbeddingContent
38+
{
39+
Model = $"models/{modelId}",
40+
Content = new()
41+
{
42+
Parts =
43+
[
44+
new()
45+
{
46+
Text = text
47+
}
48+
]
49+
},
50+
Dimensions = dimensions,
51+
TaskType = GetTaskType(options)
52+
})]
53+
};
54+
55+
return request;
56+
}
3257

3358
internal sealed class RequestEmbeddingContent
3459
{

dotnet/src/Connectors/Connectors.Google/Services/GoogleAITextEmbeddingGenerationService.cs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Net.Http;
66
using System.Threading;
77
using System.Threading.Tasks;
8+
using Microsoft.Extensions.AI;
89
using Microsoft.Extensions.Logging;
910
using Microsoft.SemanticKernel.Connectors.Google.Core;
1011
using Microsoft.SemanticKernel.Embeddings;
@@ -68,6 +69,31 @@ public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
6869
Kernel? kernel = null,
6970
CancellationToken cancellationToken = default)
7071
{
71-
return this._embeddingClient.GenerateEmbeddingsAsync(data, cancellationToken);
72+
return this._embeddingClient.GenerateEmbeddingsAsync(data, null, cancellationToken);
73+
}
74+
75+
/// <summary>
76+
/// Generates an embedding from the given <paramref name="data"/>.
77+
/// </summary>
78+
/// <param name="data">List of strings to generate embeddings for</param>
79+
/// <param name="options">Additional options for embedding generation</param>
80+
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
81+
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
82+
/// <returns>List of embeddings</returns>
83+
/// <remarks>
84+
/// <para>
85+
/// The <paramref name="options"/> parameter can be used to override default settings such as <see cref="EmbeddingGenerationOptions.ModelId"/> and <see cref="EmbeddingGenerationOptions.Dimensions"/>
86+
/// </para>
87+
/// <para>
88+
/// Additionally a key/value of <c>"taskType"</c> can be provided in the <see cref="EmbeddingGenerationOptions.AdditionalProperties"/> for specific embedding tasks.
89+
/// </para>
90+
/// </remarks>
91+
public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
92+
IList<string> data,
93+
EmbeddingGenerationOptions? options,
94+
Kernel? kernel = null,
95+
CancellationToken cancellationToken = default)
96+
{
97+
return this._embeddingClient.GenerateEmbeddingsAsync(data, options, cancellationToken);
7298
}
7399
}

0 commit comments

Comments
 (0)