using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Primitives;

namespace GameList.Infrastructure;

/// <summary>
/// Rejects state-changing API requests from authenticated users unless the request's
/// <c>Origin</c> header (or, when no Origin is present, its <c>Referer</c> header) matches
/// the scheme, host and port the request was received on.
/// </summary>
/// <remarks>
/// Only requests whose path starts with the <c>/api</c> segment, whose method is POST, PUT,
/// DELETE or PATCH, and whose user is authenticated are checked; all other requests are passed
/// straight to the next middleware. A failed check short-circuits the pipeline with a
/// 400 <c>application/problem+json</c> response.
/// </remarks>
/// <param name="next">The next middleware in the pipeline.</param>
public sealed class CsrfProtectionMiddleware(RequestDelegate next)
{
    private const string ProblemJsonContentType = "application/problem+json";
    private const string FailureMessage = "CSRF validation failed.";

    /// <summary>
    /// Runs the CSRF check for the current request and either continues the pipeline or
    /// writes a failure response.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task InvokeAsync(HttpContext context)
    {
        // Evaluation order matches the original: the origin check only runs for requests
        // that need validation.
        if (!ShouldValidate(context) || IsSameOriginRequest(context))
        {
            await next(context);
            return;
        }

        await WriteCsrfFailureAsync(context);
    }

    /// <summary>
    /// Returns true when the request targets <c>/api</c>, uses a state-changing method and
    /// comes from an authenticated user.
    /// </summary>
    private static bool ShouldValidate(HttpContext context)
    {
        var request = context.Request;

        if (!request.Path.StartsWithSegments("/api", StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        if (!IsStateChangingMethod(request.Method))
        {
            return false;
        }

        // Only authenticated requests are checked; presumably anonymous requests carry no
        // credentials worth forging — confirm against the auth setup.
        return context.User.Identity?.IsAuthenticated == true;
    }

    /// <summary>Returns true for the HTTP methods this middleware guards (POST, PUT, DELETE, PATCH).</summary>
    private static bool IsStateChangingMethod(string method) =>
        HttpMethods.IsPost(method)
        || HttpMethods.IsPut(method)
        || HttpMethods.IsDelete(method)
        || HttpMethods.IsPatch(method);

    /// <summary>
    /// Decides whether the request originates from the same origin it was sent to.
    /// The Origin header is used when present; otherwise the Referer header is used.
    /// A request with neither header is rejected.
    /// </summary>
    private static bool IsSameOriginRequest(HttpContext context)
    {
        var headers = context.Request.Headers;

        var originValues = headers.Origin;
        if (!StringValues.IsNullOrEmpty(originValues))
        {
            return AllValuesSameOrigin(originValues, context);
        }

        var refererValues = headers.Referer;
        if (StringValues.IsNullOrEmpty(refererValues))
        {
            return false;
        }

        // Validate each Referer value on its own instead of parsing a comma-joined string,
        // whose outcome would depend only on the first value's host.
        return AllValuesSameOrigin(refererValues, context);
    }

    /// <summary>
    /// Returns true only if every header value is non-blank and same-origin with the request.
    /// </summary>
    private static bool AllValuesSameOrigin(StringValues values, HttpContext context)
    {
        foreach (var value in values)
        {
            if (string.IsNullOrWhiteSpace(value) || !IsSameOrigin(value, context))
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Parses <paramref name="raw"/> as an absolute URI and compares its scheme, host
    /// (case-insensitively) and effective port with those of the current request.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the comparison uses <c>Request.Scheme</c> and <c>Request.Host</c> as seen by
    /// this server; behind a reverse proxy these depend on forwarded-headers configuration.
    /// </remarks>
    private static bool IsSameOrigin(string raw, HttpContext context)
    {
        if (!Uri.TryCreate(raw, UriKind.Absolute, out var uri))
        {
            return false;
        }

        var request = context.Request;
        var requestScheme = request.Scheme;

        if (!string.Equals(uri.Scheme, requestScheme, StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        if (!string.Equals(uri.Host, request.Host.Host, StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        // Normalize missing/default ports on both sides so "https://host" equals "https://host:443".
        var uriPort = uri.IsDefaultPort ? GetDefaultPort(uri.Scheme) : uri.Port;
        var requestPort = request.Host.Port ?? GetDefaultPort(requestScheme);
        return uriPort == requestPort;
    }

    /// <summary>Returns 443 for "https" (case-insensitive) and 80 for any other scheme.</summary>
    private static int GetDefaultPort(string scheme) =>
        string.Equals(scheme, "https", StringComparison.OrdinalIgnoreCase) ? 443 : 80;

    /// <summary>
    /// Writes a 400 problem-details response describing the CSRF failure, unless the
    /// response has already started.
    /// </summary>
    private static Task WriteCsrfFailureAsync(HttpContext context)
    {
        var response = context.Response;
        if (response.HasStarted)
        {
            // Status and headers can no longer be changed; leave the response untouched.
            return Task.CompletedTask;
        }

        response.StatusCode = StatusCodes.Status400BadRequest;

        var problem = new ProblemDetails
        {
            Status = StatusCodes.Status400BadRequest,
            Title = "Bad Request",
            Detail = FailureMessage,
            Extensions = { ["error"] = FailureMessage },
        };

        // WriteAsJsonAsync assigns Response.ContentType itself (defaulting to application/json),
        // so a ContentType set beforehand would be overwritten; pass the media type explicitly.
        return response.WriteAsJsonAsync(
            problem,
            options: null,
            contentType: ProblemJsonContentType,
            cancellationToken: context.RequestAborted);
    }
}