Refactor Temporal Anti-Aliasing shader to have less ghosting and quality scalability

Fix lots of ghosting artifacts.
Fix aliasing on small cables/lines during camera movement.
Add scalability via AA Quality setting for TAA.
This commit is contained in:
2026-04-30 17:58:19 +02:00
parent 3a5d831e71
commit 24675ace93
5 changed files with 281 additions and 93 deletions
Binary file not shown.
+24 -13
View File
@@ -9,22 +9,23 @@
#include "Engine/Renderer/RenderList.h"
#include "Engine/Renderer/GBufferPass.h"
#include "Engine/Engine/Engine.h"
#include "Engine/Graphics/Graphics.h"
#include "Engine/Graphics/RenderTools.h"
GPU_CB_STRUCT(Data {
Float2 ScreenSizeInv;
Float2 JitterInv;
float Sharpness;
Float2 MotionScale;
float StationaryBlending;
float MotionBlending;
float Dummy0;
Float3 QuantizationError;
float Dummy1;
float Sharpness;
ShaderGBufferData GBuffer;
});
bool TAA::Init()
{
_psTAA.CreatePipelineStates();
_shader = Content::LoadAsyncInternal<Shader>(TEXT("Shaders/TAA"));
if (_shader == nullptr)
return true;
@@ -40,14 +41,11 @@ bool TAA::setupResources()
return true;
const auto shader = _shader->GetShader();
CHECK_INVALID_SHADER_PASS_CB_SIZE(shader, 0, Data);
if (!_psTAA)
_psTAA = GPUDevice::Instance->CreatePipelineState();
GPUPipelineState::Description psDesc;
if (!_psTAA->IsValid())
if (!_psTAA.IsValid())
{
psDesc = GPUPipelineState::Description::DefaultFullscreenTriangle;
psDesc.PS = shader->GetPS("PS");
if (_psTAA->Init(psDesc))
if (_psTAA.Create(psDesc, shader, "PS"))
return true;
}
return false;
@@ -58,15 +56,13 @@ void TAA::Dispose()
// Base
RendererPass::Dispose();
SAFE_DELETE_GPU_RESOURCE(_psTAA);
_psTAA.Delete();
_shader = nullptr;
}
void TAA::Render(const RenderContext& renderContext, GPUTexture* input, GPUTextureView* output)
{
auto context = GPUDevice::Instance->GetMainContext();
// Ensure to have valid data
if (checkIfSkipPass())
{
// Resources are missing. Do not perform rendering, just copy source frame.
@@ -121,9 +117,10 @@ void TAA::Render(const RenderContext& renderContext, GPUTexture* input, GPUTextu
data.ScreenSizeInv.Y = renderContext.View.ScreenSize.W;
data.JitterInv.X = renderContext.View.TemporalAAJitter.X / (float)tempDesc.Width;
data.JitterInv.Y = renderContext.View.TemporalAAJitter.Y / (float)tempDesc.Height;
data.Sharpness = settings.TAA_Sharpness;
data.Sharpness = settings.TAA_Sharpness * 3; // Hardcoded scale
data.StationaryBlending = settings.TAA_StationaryBlending * blendStrength;
data.MotionBlending = settings.TAA_MotionBlending * blendStrength;
data.MotionScale = 0.1f / data.ScreenSizeInv; // Hardcoded scale
data.QuantizationError = RenderTools::GetColorQuantizationError(tempDesc.Format);
GBufferPass::SetInputs(renderContext.View, data.GBuffer);
const auto cb = _shader->GetShader()->GetCB(0);
@@ -136,7 +133,21 @@ void TAA::Render(const RenderContext& renderContext, GPUTexture* input, GPUTextu
// Render
context->SetRenderTarget(output);
context->SetState(_psTAA);
int qualityLevel;
switch (Graphics::AAQuality)
{
case Quality::Low:
qualityLevel = 0;
break;
case Quality::Medium:
qualityLevel = 1;
break;
case Quality::High:
case Quality::Ultra:
qualityLevel = 2;
break;
}
context->SetState(_psTAA.Get(qualityLevel));
context->DrawFullscreenTriangle();
// Update the history
+2 -7
View File
@@ -11,9 +11,8 @@
class TAA : public RendererPass<TAA>
{
private:
AssetReference<Shader> _shader;
GPUPipelineState* _psTAA;
GPUPipelineStatePermutationsPs<3> _psTAA;
public:
/// <summary>
@@ -25,18 +24,15 @@ public:
void Render(const RenderContext& renderContext, GPUTexture* input, GPUTextureView* output);
private:
#if COMPILE_WITH_DEV_ENV
void OnShaderReloading(Asset* obj)
{
if (_psTAA)
_psTAA->ReleaseGPU();
_psTAA.Release();
invalidateResources();
}
#endif
public:
// [RendererPass]
String ToString() const override
{
@@ -46,7 +42,6 @@ public:
void Dispose() override;
protected:
// [RendererPass]
bool setupResources() override;
};
+1
View File
@@ -160,6 +160,7 @@ float4 LoadTextureWGSL(Texture2D tex, float2 uv)
#define HDR_CLAMP_MAX 65472.0
#define PI 3.1415926535897932
#define UNITS_TO_METERS_SCALE 0.01f
#define REVERSE_Z 0
// Structure that contains information about GBuffer
struct GBufferData
+252 -71
View File
@@ -1,7 +1,28 @@
// Copyright (c) Wojciech Figat. All rights reserved.
#define DEBUG_HISTORY_REJECTION 0
// Implementation based on:
// "Temporal Reprojection Anti-Aliasing in INSIDE", Lasse Jon Fuglsang Pedersen at GDC 2026
// Source: https://github.com/playdeadgames/temporal, MIT Licence, 2015 Playdead
// Configs
#define UNJITTER_INPUT 1
#define UNJITTER_NEIGHBORHOOD 0
#define UNJITTER_VELOCITY_DEPTH 0
#define MINMAX_3X3 1
//#define MINMAX_3X3_ROUNDED 1
//#define MINMAX_4TAP_VARYING 1
#define DEBUG_LUMINANCE_DIFF 0
#define DEBUG_MOTION 0
#define DEBUG_VELOCITY_REJECTION 0
#define TAA_EPSILON 0.000001f
#define NO_GBUFFER_SAMPLING
#define NEED_DEPTH_VELOCITY (MINMAX_4TAP_VARYING)
#if NEED_DEPTH_VELOCITY
#define VelocityDepth float3
#else
#define VelocityDepth float2
#endif
#include "./Flax/Common.hlsl"
#include "./Flax/GBuffer.hlsl"
@@ -11,12 +32,11 @@
META_CB_BEGIN(0, Data)
float2 ScreenSizeInv;
float2 JitterInv;
float Sharpness;
float2 MotionScale;
float StationaryBlending;
float MotionBlending;
float Dummy0;
float3 QuantizationError;
float Dummy1;
float Sharpness;
GBufferData GBuffer;
META_CB_END
@@ -25,82 +45,243 @@ Texture2D InputHistory : register(t1);
Texture2D MotionVectors : register(t2);
Texture2D Depth : register(t3);
// Samples nearby pixels and returns (uv, raw depth) of the closest one to camera.
float3 FindClosestDepth3x3(float2 uv)
{
float2 du = float2(ScreenSizeInv.x, 0.0);
float2 dv = float2(0.0, ScreenSizeInv.y);
float3 dtl = float3(-1, -1, SAMPLE_RT_DEPTH(Depth, uv - dv - du));
float3 dtc = float3( 0, -1, SAMPLE_RT_DEPTH(Depth, uv - dv));
float3 dtr = float3( 1, -1, SAMPLE_RT_DEPTH(Depth, uv - dv + du));
float3 dml = float3(-1, 0, SAMPLE_RT_DEPTH(Depth, uv - du));
float3 dmc = float3( 0, 0, SAMPLE_RT_DEPTH(Depth, uv).x);
float3 dmr = float3( 1, 0, SAMPLE_RT_DEPTH(Depth, uv + du));
float3 dbl = float3(-1, 1, SAMPLE_RT_DEPTH(Depth, uv + dv - du));
float3 dbc = float3( 0, 1, SAMPLE_RT_DEPTH(Depth, uv + dv));
float3 dbr = float3( 1, 1, SAMPLE_RT_DEPTH(Depth, uv + dv + du));
float3 dmin = dtl;
#if REVERSE_Z
#define FIND_MIN(a, b) a < b
#else
#define FIND_MIN(a, b) a > b
#endif
if (FIND_MIN(dmin.z, dtc.z)) dmin = dtc;
if (FIND_MIN(dmin.z, dtr.z)) dmin = dtr;
if (FIND_MIN(dmin.z, dml.z)) dmin = dml;
if (FIND_MIN(dmin.z, dmc.z)) dmin = dmc;
if (FIND_MIN(dmin.z, dmr.z)) dmin = dmr;
if (FIND_MIN(dmin.z, dbl.z)) dmin = dbl;
if (FIND_MIN(dmin.z, dbc.z)) dmin = dbc;
if (FIND_MIN(dmin.z, dbr.z)) dmin = dbr;
#undef FIND_MIN
return float3(uv + ScreenSizeInv.xy * dmin.xy, dmin.z);
}
// Samples nearby pixels and returns (velocity, uv) of the pixel with largest velocity.
float4 FindLargestVelocity(float2 uv, int range)
{
float4 result = float4(0, 0, uv.x, uv.y);
float vLargest = 0.0f;
for (int y = -range; y <= range; y++)
{
for (int x = -range; x <= range; x++)
{
float2 vUV = uv + ScreenSizeInv * float2(x, y);
float2 v = SAMPLE_RT_LINEAR(MotionVectors, vUV).xy;
float vSize = dot(v, v);
if (vSize > vLargest)
{
result = float4(v, vUV);
vLargest = vSize;
}
}
}
return result;
}
VelocityDepth SampleVelocityDepth(float2 uv)
{
#if UNJITTER_VELOCITY_DEPTH
uv -= JitterInv.xy;
#endif
VelocityDepth velocityDepth;
#if QUALITY >= 2
// 3x3 search for the largest velocity
float4 largestVelocity = FindLargestVelocity(uv, 3);
velocityDepth.xy = largestVelocity.xy;
#if NEED_DEPTH_VELOCITY
velocityDepth.z = SAMPLE_RT_DEPTH(Depth, largestVelocity.zw);
#endif
#elif QUALITY >= 1
// 3x3 search of the closest depth
float3 closestDepth = FindClosestDepth3x3(uv);
velocityDepth.xy = SAMPLE_RT_LINEAR(MotionVectors, closestDepth.xy).xy;
#if NEED_DEPTH_VELOCITY
velocityDepth.z = closestDepth.z;
#endif
#else // QUALITY == 0
// Current sample
velocityDepth.xy = SAMPLE_RT_LINEAR(MotionVectors, uv).xy;
#if NEED_DEPTH_VELOCITY
velocityDepth.z = SAMPLE_RT_DEPTH(Depth, uv);
#endif
#endif
#if NEED_DEPTH_VELOCITY
// Linearize depth
velocityDepth.z = LinearizeZ(GBuffer, velocityDepth.z);
#endif
return velocityDepth;
}
float4 ClipAAB(float3 aabbMin, float3 aabbMax, float4 p, float4 q)
{
// only clips towards aabb center
float3 pClip = 0.5 * (aabbMax + aabbMin);
float3 eClip = 0.5 * (aabbMax - aabbMin) + TAA_EPSILON;
float4 vClip = q - float4(pClip, p.w);
float3 vUnit = vClip.xyz / eClip;
float3 aUnit = abs(vUnit);
float maUnit = max(aUnit.x, max(aUnit.y, aUnit.z));
if (maUnit > 1.0)
return float4(pClip, p.w) + vClip / maUnit;
else
return q; // point inside aabb
}
// Pixel Shader for Temporal Anti-Aliasing
META_PS(true, FEATURE_LEVEL_ES2)
META_PERMUTATION_1(QUALITY=0)
META_PERMUTATION_1(QUALITY=1)
META_PERMUTATION_1(QUALITY=2)
float4 PS(Quad_VS2PS input) : SV_Target0
{
float2 tanHalfFOV = float2(GBuffer.InvProjectionMatrix[0][0], GBuffer.InvProjectionMatrix[1][1]);
// Sample velocity and depth
VelocityDepth velocityDepth = SampleVelocityDepth(input.TexCoord);
// Calculate previous frame UVs based on per-pixel velocity
float2 velocity = SAMPLE_RT_LINEAR(MotionVectors, input.TexCoord).xy;
float velocityLength = length(velocity);
float2 prevUV = input.TexCoord - velocity;
float prevDepth = LinearizeZ(GBuffer, SAMPLE_RT_DEPTH(Depth, prevUV));
// Find the closest pixel in 3x3 neighborhood
float currentDepth = 1;
float4 neighborhoodMin = 100000;
float4 neighborhoodMax = -10000;
float4 current;
float4 neighborhoodSum = 0;
float minDepthDiff = 100000;
for (int x = -1; x <= 1; ++x)
{
for (int y = -1; y <= 1; ++y)
{
float2 sampleUV = input.TexCoord + float2(x, y) * ScreenSizeInv;
float4 neighbor = SAMPLE_RT(Input, sampleUV);
neighborhoodMin = min(neighborhoodMin, neighbor);
neighborhoodMax = max(neighborhoodMax, neighbor);
neighborhoodSum += neighbor;
float neighborDepth = LinearizeZ(GBuffer, SAMPLE_RT_DEPTH(Depth, sampleUV));
float depthDiff = abs(max(neighborDepth - prevDepth, 0));
minDepthDiff = min(minDepthDiff, depthDiff);
if (x == 0 && y == 0)
{
current = neighbor;
currentDepth = neighborDepth;
}
}
}
// Apply sharpening
float4 neighborhoodAvg = neighborhoodSum / 9.0;
current += (current - neighborhoodAvg) * Sharpness;
// Sample history by clamp it to the nearby colors range to reduce artifacts
current = clamp(current, 0, HDR_CLAMP_MAX);
float4 history = SAMPLE_RT_LINEAR(InputHistory, prevUV);
float lumaOffset = abs(Luminance(neighborhoodAvg.rgb) - Luminance(current.rgb));
float aabbMargin = lerp(4.0, 0.25, saturate(velocityLength * 100.0)) * lumaOffset;
history = ClipToAABB(history, neighborhoodMin - aabbMargin, neighborhoodMax + aabbMargin);
//history = clamp(history, neighborhoodMin, neighborhoodMax);
// Calculate history blending factor
float motion = saturate(velocityLength * 1000.0f);
float blendfactor = lerp(StationaryBlending, MotionBlending, motion);
// Perform linear accumulation of the previous samples with a current one
float4 color = lerp(current, history, blendfactor);
// Reduce history blend in favor of neighborhood blend when sample has no valid prevous frame data
float miss = any(abs(prevUV * 2 - 1) >= 1.0f) ? 1 : 0;
float currentDepthWorld = currentDepth * GBuffer.ViewFar;
float minDepthDiffWorld = minDepthDiff * GBuffer.ViewFar;
float depthError = tanHalfFOV.x * ScreenSizeInv.x * 200.0f * (currentDepthWorld + 10.0f);
miss += minDepthDiffWorld > depthError ? 1 : 0;
float4 neighborhoodSharp = lerp(neighborhoodAvg, current, 0.5f);
#if DEBUG_HISTORY_REJECTION
neighborhoodSharp = float4(1, 0, 0, 1);
// Sample image and history
#if UNJITTER_INPUT
float4 current = SAMPLE_RT(Input, input.TexCoord - JitterInv.xy);
#else
float4 current = SAMPLE_RT(Input, input.TexCoord);
#endif
color = lerp(color, neighborhoodSharp, saturate(miss));
color = clamp(color, 0, HDR_CLAMP_MAX);
float4 history = SAMPLE_RT_LINEAR(InputHistory, input.TexCoord - velocityDepth.xy);
// Calculate min-max of current pixel neighbourhood
#if UNJITTER_NEIGHBORHOOD
float2 uv = input.TexCoord - JitterInv.xy;
#else
float2 uv = input.TexCoord;
#endif
#if MINMAX_3X3 || MINMAX_3X3_ROUNDED
float2 du = float2(ScreenSizeInv.x, 0.0);
float2 dv = float2(0.0, ScreenSizeInv.y);
float4 ctl = SAMPLE_RT(Input, uv - dv - du);
float4 ctc = SAMPLE_RT(Input, uv - dv);
float4 ctr = SAMPLE_RT(Input, uv - dv + du);
float4 cml = SAMPLE_RT(Input, uv - du);
float4 cmc = SAMPLE_RT(Input, uv);
float4 cmr = SAMPLE_RT(Input, uv + du);
float4 cbl = SAMPLE_RT(Input, uv + dv - du);
float4 cbc = SAMPLE_RT(Input, uv + dv);
float4 cbr = SAMPLE_RT(Input, uv + dv + du);
float4 cMin = min(ctl, min(ctc, min(ctr, min(cml, min(cmc, min(cmr, min(cbl, min(cbc, cbr))))))));
float4 cMax = max(ctl, max(ctc, max(ctr, max(cml, max(cmc, max(cmr, max(cbl, max(cbc, cbr))))))));
float4 cAvg = (ctl + ctc + ctr + cml + cmc + cmr + cbl + cbc + cbr) / 9.0f;
float4 corners = 4.0f * (ctl + cbr) - 2.0f * current;
#if MINMAX_3X3_ROUNDED
float4 cMin5 = min(ctc, min(cml, min(cmc, min(cmr, cbc))));
float4 cMax5 = max(ctc, max(cml, max(cmc, max(cmr, cbc))));
float4 cAvg5 = (ctc + cml + cmc + cmr + cbc) / 5.0;
cMin = 0.5 * (cMin + cMin5);
cMax = 0.5 * (cMax + cMax5);
cAvg = 0.5 * (cAvg + cAvg5);
#endif
#elif MINMAX_4TAP_VARYING
const float SubpixelThreshold = 0.5;
const float GatherBase = 0.5;
const float GatherSubpixelMotion = 0.1666;
float velocityMagnitude = length(velocityDepth.xy / ScreenSizeInv.xy) * velocityDepth.z;
float subpixelMotion = saturate(SubpixelThreshold / (velocityMagnitude + TAA_EPSILON));
float minMaxSupport = GatherBase + GatherSubpixelMotion * subpixelMotion;
float2 ssOffset01 = minMaxSupport * float2(-ScreenSizeInv.x, ScreenSizeInv.y);
float2 ssOffset11 = minMaxSupport * float2(ScreenSizeInv.x, ScreenSizeInv.y);
float4 c00 = SAMPLE_RT_LINEAR(Input, uv - ssOffset11);
float4 c10 = SAMPLE_RT_LINEAR(Input, uv - ssOffset01);
float4 c01 = SAMPLE_RT_LINEAR(Input, uv + ssOffset01);
float4 c11 = SAMPLE_RT_LINEAR(Input, uv + ssOffset11);
float4 corners = 4.0f * (c00 + c11) - 2.0f * current;
float4 cMin = min(c00, min(c10, min(c01, c11)));
float4 cMax = max(c00, max(c10, max(c01, c11)));
float4 cAvg = (c00 + c10 + c01 + c11) / 4.0f;
#endif
// Apply sharpening
current += (current - (corners * 0.166667)) * Sharpness;
current = clamp(current, 0, HDR_CLAMP_MAX);
// Clamp current sample to neighbourhood pixels nearby colors to reduce ghosting artifacts
history = ClipAAB(cMin.xyz, cMax.xyz, clamp(cAvg, cMin, cMax), history);
//history = clamp(history, cMin, cMax);
// Calculate history weight from unbiased luminance diff
// [Reference: "TSSAA (Temporal Super-Sampling AA)" by Timothy Lottes (2011)]
float currentLum = Luminance(current.rgb);
float historyLum = Luminance(history.rgb);
float unbiasedDiff = abs(currentLum - historyLum) / max(currentLum, max(historyLum, 0.2f));
float unbiasedWeight = 1.0 - unbiasedDiff;
float unbiasedWeightSqr = unbiasedWeight * unbiasedWeight;
#if DEBUG_LUMINANCE_DIFF
return unbiasedWeightSqr.xxxx;
#endif
float historyBlend = lerp(MotionBlending, min(MotionBlending + 0.2f, 0.97f), unbiasedWeightSqr);
// Higher history blend when there is no motion
float motion = saturate(length(velocityDepth.xy) * 1000.0f);
#if DEBUG_MOTION
return motion.xxxx;
#endif
historyBlend = lerp(StationaryBlending, historyBlend, motion);
// Lower history blend when motion goes outside the view
const float velocityMin = 2.0f;
const float velocityMax = 20.0f;
const float velocityRange = velocityMax - velocityMin;
float velocityRejection = clamp(length(velocityDepth.xy * MotionScale) - velocityMin, 0.0, velocityRange) / velocityRange;
#if DEBUG_VELOCITY_REJECTION
return velocityRejection.xxxx;
#endif
historyBlend = lerp(historyBlend, 0.0f, velocityRejection);
// Blend current frame with history
float4 color = lerp(current, history, historyBlend);
// Apply quantization error to reduce yellowish artifacts due to R11G11B10 format
float noise = rand2dTo1d(input.TexCoord);
color.rgb = QuantizeColor(color.rgb, noise, QuantizationError);
return color;
return color;
}