Files
GameList/Infrastructure/CsrfProtectionMiddleware.cs

106 lines
3.2 KiB
C#

using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Primitives;
namespace GameList.Infrastructure;
/// <summary>
/// Middleware that rejects cross-origin, state-changing API requests made by authenticated users.
/// </summary>
/// <remarks>
/// A request is checked only when its path is under <c>/api</c>, its method is POST, PUT, DELETE or PATCH,
/// and <see cref="HttpContext.User"/> is authenticated. A checked request is forwarded only if its
/// <c>Origin</c> header (or, when <c>Origin</c> is absent, its <c>Referer</c> header) has the same scheme,
/// host and port as the request itself. Otherwise a 400 <c>application/problem+json</c> response is
/// written and the rest of the pipeline is skipped.
/// NOTE(review): the check reads <c>context.User</c>, so this middleware presumably has to be registered
/// after authentication — confirm its position in the pipeline.
/// </remarks>
/// <param name="next">The next delegate in the request pipeline.</param>
public sealed class CsrfProtectionMiddleware(RequestDelegate next)
{
    private const string ProblemJsonContentType = "application/problem+json";
    private const string CsrfFailureMessage = "CSRF validation failed.";

    /// <summary>
    /// Forwards the request when it is exempt from the CSRF check or passes it; otherwise writes a
    /// 400 problem response without invoking the rest of the pipeline.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task InvokeAsync(HttpContext context)
    {
        // Short-circuit: the origin check only runs for requests that actually need validation.
        if (!ShouldValidate(context) || IsSameOriginRequest(context))
        {
            await next(context);
            return;
        }

        await WriteCsrfFailureAsync(context);
    }

    /// <summary>
    /// Returns <c>true</c> when the request targets <c>/api</c> (case-insensitive segment match),
    /// uses POST, PUT, DELETE or PATCH, and comes from an authenticated user.
    /// </summary>
    private static bool ShouldValidate(HttpContext context)
    {
        var request = context.Request;
        if (!request.Path.StartsWithSegments("/api", StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        var method = request.Method;
        var isStateChanging = HttpMethods.IsPost(method)
            || HttpMethods.IsPut(method)
            || HttpMethods.IsDelete(method)
            || HttpMethods.IsPatch(method);
        if (!isStateChanging)
        {
            return false;
        }

        return context.User.Identity?.IsAuthenticated == true;
    }

    /// <summary>
    /// Decides whether the request originates from the same origin as the server.
    /// </summary>
    /// <remarks>
    /// If any <c>Origin</c> value is present, every value must be non-blank and same-origin.
    /// Only when <c>Origin</c> is absent is the <c>Referer</c> header consulted. A request with
    /// neither header is rejected.
    /// </remarks>
    private static bool IsSameOriginRequest(HttpContext context)
    {
        var originValues = context.Request.Headers.Origin;
        if (!StringValues.IsNullOrEmpty(originValues))
        {
            foreach (var origin in originValues)
            {
                // Values that are blank or not absolute URIs (e.g. the literal "null") fail here.
                if (string.IsNullOrWhiteSpace(origin) || !IsSameOrigin(origin, context))
                {
                    return false;
                }
            }

            return true;
        }

        var referer = context.Request.Headers.Referer.ToString();
        if (string.IsNullOrWhiteSpace(referer))
        {
            return false;
        }

        return IsSameOrigin(referer, context);
    }

    /// <summary>
    /// Compares the scheme, host and effective port of <paramref name="raw"/> with those of the
    /// current request. Scheme and host comparisons are case-insensitive.
    /// </summary>
    /// <param name="raw">An <c>Origin</c> or <c>Referer</c> header value.</param>
    /// <param name="context">The current HTTP context.</param>
    /// <returns><c>false</c> if <paramref name="raw"/> is not an absolute URI or any component differs.</returns>
    private static bool IsSameOrigin(string raw, HttpContext context)
    {
        if (!Uri.TryCreate(raw, UriKind.Absolute, out var uri))
        {
            return false;
        }

        var requestScheme = context.Request.Scheme;
        if (!string.Equals(uri.Scheme, requestScheme, StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        var requestHost = context.Request.Host.Host;
        if (!string.Equals(uri.Host, requestHost, StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        // An omitted port on either side is normalized to the scheme default before comparing.
        var uriPort = uri.IsDefaultPort ? GetDefaultPort(uri.Scheme) : uri.Port;
        var requestPort = context.Request.Host.Port ?? GetDefaultPort(requestScheme);
        return uriPort == requestPort;
    }

    /// <summary>Returns 443 for <c>https</c> (case-insensitive) and 80 for any other scheme.</summary>
    private static int GetDefaultPort(string scheme)
    {
        return string.Equals(scheme, "https", StringComparison.OrdinalIgnoreCase) ? 443 : 80;
    }

    /// <summary>
    /// Writes a 400 <c>application/problem+json</c> body describing the CSRF failure. Does nothing
    /// if the response has already started.
    /// </summary>
    private static Task WriteCsrfFailureAsync(HttpContext context)
    {
        var response = context.Response;
        if (response.HasStarted)
        {
            // Headers are already sent, so the status code and content type can no longer be set.
            return Task.CompletedTask;
        }

        response.StatusCode = StatusCodes.Status400BadRequest;

        var problem = new ProblemDetails
        {
            Status = StatusCodes.Status400BadRequest,
            Title = "Bad Request",
            Detail = CsrfFailureMessage,
            Extensions = { ["error"] = CsrfFailureMessage },
        };

        // WriteAsJsonAsync assigns Response.ContentType itself (defaulting to
        // "application/json; charset=utf-8"), which would clobber a value set beforehand,
        // so the problem+json media type has to be passed to the call.
        return response.WriteAsJsonAsync(problem, options: null, contentType: ProblemJsonContentType);
    }
}