Skip to content

Commit 1c9ed13

Browse files
Commit message: Fix missing scaling
1 parent: 55ab0fe; commit: 1c9ed13

File tree

2 files changed, with 15 additions (+15) and 4 deletions (-4).

src/FaceAiSharp/FaceAiSharp.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ This package contains just FaceAiSharp's managed code and does not include any O
2929
<PackageReference Include="SimpleSIMD" />
3030
<PackageReference Include="SixLabors.ImageSharp" />
3131
<PackageReference Include="SixLabors.ImageSharp.Drawing" />
32+
<PackageReference Include="System.Numerics.Tensors" />
3233
</ItemGroup>
3334

3435
</Project>

src/FaceAiSharp/ScrfdDetector.cs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Georg Jung. All rights reserved.
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

4+
using System.Numerics.Tensors;
45
using System.Runtime.CompilerServices;
56
using FaceAiSharp.Extensions;
67
using FaceAiSharp.Simd;
@@ -170,7 +171,7 @@ internal IReadOnlyCollection<FaceDetectorResult> Detect(DenseTensor<float> input
170171
var strideResults = new List<FaceDetectorResult>();
171172
foreach (var (idx, stride) in _modelParameters.FeatStrideFpn.Select((val, idx) => (idx, val)))
172173
{
173-
var strideRes = HandleStride(idx, stride, outputs.ToList(), imgSize);
174+
var strideRes = HandleStride(idx, stride, outputs.ToList(), imgSize, scale);
174175
strideResults.AddRange(strideRes ?? []);
175176
}
176177

@@ -214,15 +215,15 @@ private static List<int> IndicesOfElementsLargerThanOrEqual(float[] input, float
214215
return indices;
215216
}
216217

217-
private static IReadOnlyList<PointF> Kps(ReadOnlySpan<float> flatKps, int anchorX, int anchorY, int stride) => [
218+
private static IReadOnlyList<PointF> Kps(ReadOnlySpan<float> flatKps, float anchorX, float anchorY, int stride) => [
218219
new(anchorX + (flatKps[0] * stride), anchorY + (flatKps[1] * stride)),
219220
new(anchorX + (flatKps[2] * stride), anchorY + (flatKps[3] * stride)),
220221
new(anchorX + (flatKps[4] * stride), anchorY + (flatKps[5] * stride)),
221222
new(anchorX + (flatKps[6] * stride), anchorY + (flatKps[7] * stride)),
222223
new(anchorX + (flatKps[8] * stride), anchorY + (flatKps[9] * stride)),
223224
];
224225

225-
private List<FaceDetectorResult>? HandleStride(int strideIndex, int stride, IReadOnlyList<NamedOnnxValue> outputs, Size inputSize)
226+
private List<FaceDetectorResult>? HandleStride(int strideIndex, int stride, IReadOnlyList<NamedOnnxValue> outputs, Size inputSize, float scale)
226227
{
227228
var thresh = Options.ConfidenceThreshold;
228229
var scores = outputs[strideIndex].ToArray<float>();
@@ -234,11 +235,20 @@ private static IReadOnlyList<PointF> Kps(ReadOnlySpan<float> flatKps, int anchor
234235

235236
var bboxPreds = outputs[strideIndex + _modelParameters.Fmc].ToArray<float>();
236237
var kpsPreds = outputs.ElementAtOrDefault(strideIndex + (_modelParameters.Fmc * 2))?.ToArray<float>();
238+
if (scale != 1.0f)
239+
{
240+
TensorPrimitives.Multiply(bboxPreds, scale, bboxPreds);
241+
if (kpsPreds is not null)
242+
{
243+
TensorPrimitives.Multiply(kpsPreds, scale, kpsPreds);
244+
}
245+
}
237246

238247
var returnValues = new List<FaceDetectorResult>(indicesAboveThreshold.Count);
239248
foreach (var anchorIdx in indicesAboveThreshold)
240249
{
241-
var (x, y) = GetAnchorCenter(inputSize, stride, _modelParameters.NumAnchors, anchorIdx);
250+
(float x, float y) = GetAnchorCenter(inputSize, stride, _modelParameters.NumAnchors, anchorIdx);
251+
(x, y) = (x * scale, y * scale);
242252
var bboxBaseIdx = anchorIdx * 4;
243253
var (x0diff, y0diff, x1diff, y1diff) = (bboxPreds[bboxBaseIdx + 0] * stride, bboxPreds[bboxBaseIdx + 1] * stride, bboxPreds[bboxBaseIdx + 2] * stride, bboxPreds[bboxBaseIdx + 3] * stride);
244254
var bbox = new RectangleF(x - x0diff, y - y0diff, x0diff + x1diff, y0diff + y1diff);

0 commit comments

Comments
 (0)