106 lines
3.2 KiB
C#
106 lines
3.2 KiB
C#
using Microsoft.AspNetCore.Mvc;
|
|
using Microsoft.Extensions.Primitives;
|
|
|
|
namespace GameList.Infrastructure;
|
|
|
|
/// <summary>
/// Rejects state-changing API requests from an authenticated user unless the request's
/// <c>Origin</c> header (or, when no <c>Origin</c> is present, its <c>Referer</c> header)
/// matches the scheme, host and port the request was received on.
/// </summary>
/// <remarks>
/// <para>
/// A request is checked only when all of these hold:
/// <list type="bullet">
/// <item><description>its path starts with the <c>/api</c> segment;</description></item>
/// <item><description>its method is POST, PUT, DELETE or PATCH;</description></item>
/// <item><description><c>context.User</c> is authenticated.</description></item>
/// </list>
/// Every other request is passed straight to the next middleware.
/// </para>
/// <para>
/// A failed check short-circuits the pipeline with a 400 <c>application/problem+json</c>
/// response, unless the response has already started.
/// </para>
/// <para>
/// NOTE(review): the user is read from <c>context.User</c>, so this middleware presumably
/// must be registered after authentication. Confirm the pipeline order.
/// </para>
/// <para>
/// NOTE(review): the comparison uses <c>Request.Scheme</c> and <c>Request.Host</c>.
/// Behind a TLS-terminating reverse proxy those must be restored (e.g. via forwarded
/// headers), or legitimate requests will be rejected.
/// </para>
/// </remarks>
public sealed class CsrfProtectionMiddleware(RequestDelegate next)
{
    /// <summary>
    /// Forwards the request when it does not need validation or passes the same-origin
    /// check; otherwise writes a CSRF failure response and does not call the next middleware.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task InvokeAsync(HttpContext context)
    {
        // Short-circuit evaluation: the origin check only runs for requests that need it.
        if (!ShouldValidate(context) || IsSameOriginRequest(context))
        {
            await next(context);
            return;
        }

        await WriteCsrfFailureAsync(context);
    }

    /// <summary>
    /// Decides whether the request is subject to the CSRF origin check.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <returns>
    /// <c>true</c> for authenticated POST/PUT/DELETE/PATCH requests under <c>/api</c>;
    /// otherwise <c>false</c>.
    /// </returns>
    private static bool ShouldValidate(HttpContext context)
    {
        var request = context.Request;

        // Only the API surface is protected by this middleware.
        if (!request.Path.StartsWithSegments("/api", StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        // Only these state-changing verbs are checked; GET, HEAD, OPTIONS and others pass through.
        var method = request.Method;
        if (!HttpMethods.IsPost(method)
            && !HttpMethods.IsPut(method)
            && !HttpMethods.IsDelete(method)
            && !HttpMethods.IsPatch(method))
        {
            return false;
        }

        // Anonymous requests are not checked (presumably because a forged request has
        // no authenticated session to act on — confirm this matches the auth scheme).
        return context.User.Identity?.IsAuthenticated == true;
    }

    /// <summary>
    /// Checks the request's <c>Origin</c> header values, or its <c>Referer</c> header
    /// when no <c>Origin</c> is present, against the request's own origin.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <returns>
    /// <c>true</c> only if every <c>Origin</c> value (or the <c>Referer</c>) is same-origin.
    /// A request carrying neither header fails closed and returns <c>false</c>.
    /// </returns>
    private static bool IsSameOriginRequest(HttpContext context)
    {
        var originValues = context.Request.Headers.Origin;
        if (!StringValues.IsNullOrEmpty(originValues))
        {
            // A single blank or foreign Origin value fails the whole request.
            foreach (var origin in originValues)
            {
                if (string.IsNullOrWhiteSpace(origin) || !IsSameOrigin(origin, context))
                {
                    return false;
                }
            }

            return true;
        }

        // No Origin header: fall back to Referer. A missing Referer is treated as a failure.
        var referer = context.Request.Headers.Referer.ToString();
        if (string.IsNullOrWhiteSpace(referer))
        {
            return false;
        }

        return IsSameOrigin(referer, context);
    }

    /// <summary>
    /// Compares the scheme, host and port of <paramref name="raw"/> with the current request.
    /// </summary>
    /// <param name="raw">
    /// An <c>Origin</c> or <c>Referer</c> header value. Values that do not parse as absolute
    /// URIs, such as the literal <c>null</c> origin, are rejected.
    /// </param>
    /// <param name="context">The current HTTP context.</param>
    /// <returns><c>true</c> when scheme, host and effective port all match.</returns>
    private static bool IsSameOrigin(string raw, HttpContext context)
    {
        if (!Uri.TryCreate(raw, UriKind.Absolute, out var uri))
        {
            return false;
        }

        var requestScheme = context.Request.Scheme;
        if (!string.Equals(uri.Scheme, requestScheme, StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        var requestHost = context.Request.Host.Host;
        if (!string.Equals(uri.Host, requestHost, StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        // Normalise implicit ports on both sides so that "https://host" and
        // "https://host:443" compare equal.
        var uriPort = uri.IsDefaultPort ? GetDefaultPort(uri.Scheme) : uri.Port;
        var requestPort = context.Request.Host.Port ?? GetDefaultPort(requestScheme);

        return uriPort == requestPort;
    }

    /// <summary>
    /// Returns 443 for <c>https</c> (case-insensitive) and 80 for any other scheme.
    /// </summary>
    /// <param name="scheme">The URI scheme.</param>
    private static int GetDefaultPort(string scheme)
    {
        return string.Equals(scheme, "https", StringComparison.OrdinalIgnoreCase) ? 443 : 80;
    }

    /// <summary>
    /// Writes a 400 <c>application/problem+json</c> body describing the CSRF failure.
    /// Does nothing if the response has already started.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    private static Task WriteCsrfFailureAsync(HttpContext context)
    {
        const string failureMessage = "CSRF validation failed.";
        const string problemContentType = "application/problem+json";

        // Status and headers can no longer be changed once the response has started.
        if (context.Response.HasStarted)
        {
            return Task.CompletedTask;
        }

        context.Response.StatusCode = StatusCodes.Status400BadRequest;

        var problem = new ProblemDetails
        {
            Status = StatusCodes.Status400BadRequest,
            Title = "Bad Request",
            Detail = failureMessage,
            Extensions = { ["error"] = failureMessage }
        };

        // The WriteAsJsonAsync(value) overload overwrites Response.ContentType with
        // "application/json; charset=utf-8". Passing the content type explicitly keeps the
        // intended problem+json type. RequestAborted stops the write if the client disconnects.
        return context.Response.WriteAsJsonAsync(
            problem,
            options: null,
            contentType: problemContentType,
            cancellationToken: context.RequestAborted);
    }
}
|