Files
GameList/Endpoints/EndpointHelpers.cs

376 lines
13 KiB
C#

using GameList.Data;
using GameList.Domain;
using Microsoft.EntityFrameworkCore;
using System.Security.Claims;
namespace GameList.Endpoints;
/// <summary>
/// Shared helpers for the endpoint handlers: resolving the signed-in player, computing and
/// reconciling game phases, building ProblemDetails error results, validating user-supplied
/// URLs (including a network reachability probe for image links), and grouping linked suggestions.
/// </summary>
internal static class EndpointHelpers
{
    /// <summary>
    /// Resolves the <see cref="Player"/> for the current request from its
    /// <see cref="ClaimTypes.NameIdentifier"/> claim.
    /// </summary>
    /// <remarks>
    /// The player is cached in <c>ctx.Items</c> under <c>nameof(Player)</c>, so repeated calls within one
    /// request query the database only once. When the request is authenticated but the claim is missing,
    /// is not a GUID, or matches no player row, the player is signed out and <c>null</c> is returned.
    /// </remarks>
    /// <returns>The player loaded via <c>FindAsync</c>, or <c>null</c> for anonymous/stale identities.</returns>
    public static async Task<Player?> GetAuthenticatedPlayer(HttpContext ctx, AppDbContext db)
    {
        if (ctx.User.Identity?.IsAuthenticated != true)
            return null;

        if (ctx.Items.TryGetValue(nameof(Player), out var cached) && cached is Player cachedPlayer)
            return cachedPlayer;

        var idValue = ctx.User.FindFirstValue(ClaimTypes.NameIdentifier);
        if (string.IsNullOrWhiteSpace(idValue) || !Guid.TryParse(idValue, out var playerId))
        {
            // Auth cookie is present but malformed; clear and reject.
            await Infrastructure.PlayerIdentityExtensions.SignOutPlayerAsync(ctx);
            return null;
        }

        var existing = await db.Players.FindAsync(playerId);
        if (existing is null)
        {
            // The identity points at a player row that does not exist; clear it and reject.
            await Infrastructure.PlayerIdentityExtensions.SignOutPlayerAsync(ctx);
            return null;
        }

        ctx.Items[nameof(Player)] = existing;
        return existing;
    }

    /// <summary>
    /// Reads the stored phase of <paramref name="playerId"/> together with the global
    /// <c>ResultsOpen</c> flag and returns the phase the player is effectively in.
    /// </summary>
    /// <returns><see cref="Phase.Suggest"/> when no such player exists; otherwise see <see cref="GetCurrentPhase"/>.</returns>
    /// <exception cref="InvalidOperationException">Thrown by <c>SingleAsync</c> unless exactly one AppState row exists.</exception>
    public static async Task<Phase> GetCurrentPhaseAsync(AppDbContext db, Guid playerId)
    {
        var playerPhase = await db.Players
            .AsNoTracking()
            .Where(p => p.Id == playerId)
            .Select(p => (Phase?)p.CurrentPhase)
            .FirstOrDefaultAsync();
        if (playerPhase is null)
            return Phase.Suggest;

        var resultsOpen = await db.AppState.AsNoTracking().Select(s => s.ResultsOpen).SingleAsync();
        return GetCurrentPhase(playerPhase.Value, resultsOpen);
    }

    /// <summary>
    /// Computes the effective phase without touching the database.
    /// </summary>
    /// <returns>
    /// <see cref="Phase.Results"/> whenever results are open. Otherwise the normalized stored phase,
    /// with a stored <see cref="Phase.Results"/> downgraded to <see cref="Phase.Vote"/>.
    /// </returns>
    public static Phase GetCurrentPhase(Phase phase, bool resultsOpen)
    {
        if (resultsOpen)
            return Phase.Results;

        var normalized = NormalizePhase(phase);
        return normalized == Phase.Results ? Phase.Vote : normalized;
    }

    /// <summary>
    /// Brings <paramref name="player"/>'s stored phase in line with the global results flag, mutating it in place.
    /// </summary>
    /// <remarks>
    /// Unknown phase values are first normalized. Leaving the Results phase (results closed) also
    /// clears <c>VotesFinal</c>. Nothing is persisted here; presumably the caller saves when this returns true.
    /// </remarks>
    /// <returns><c>true</c> if any property of <paramref name="player"/> was changed.</returns>
    public static bool ReconcilePlayerPhase(Player player, bool resultsOpen)
    {
        var changed = false;

        var normalized = NormalizePhase(player.CurrentPhase);
        if (player.CurrentPhase != normalized)
        {
            player.CurrentPhase = normalized;
            changed = true;
        }

        if (resultsOpen && player.CurrentPhase != Phase.Results)
        {
            player.CurrentPhase = Phase.Results;
            changed = true;
        }
        else if (!resultsOpen && player.CurrentPhase == Phase.Results)
        {
            player.CurrentPhase = Phase.Vote;
            player.VotesFinal = false;
            changed = true;
        }

        return changed;
    }

    /// <summary>Maps any value outside the known phases (legacy/invalid data) to <see cref="Phase.Vote"/>.</summary>
    private static Phase NormalizePhase(Phase phase) => phase switch
    {
        Phase.Suggest => Phase.Suggest,
        Phase.Vote => Phase.Vote,
        Phase.Results => Phase.Results,
        _ => Phase.Vote // legacy/invalid phase fallback
    };

    /// <summary>400 result explaining that the endpoint requires <paramref name="required"/> but the player is in <paramref name="current"/>.</summary>
    public static IResult PhaseMismatch(Phase required, Phase current) =>
        BadRequestError($"This endpoint is available in the {required} phase. Your current phase is {current}.");

    /// <summary>400 ProblemDetails result.</summary>
    public static IResult BadRequestError(string detail) => Problem(StatusCodes.Status400BadRequest, "Bad Request", detail);

    /// <summary>404 ProblemDetails result.</summary>
    public static IResult NotFoundError(string detail) => Problem(StatusCodes.Status404NotFound, "Not Found", detail);

    /// <summary>409 ProblemDetails result.</summary>
    public static IResult ConflictError(string detail) => Problem(StatusCodes.Status409Conflict, "Conflict", detail);

    /// <summary>401 ProblemDetails result.</summary>
    public static IResult UnauthorizedError(string detail = "Unauthorized") => Problem(StatusCodes.Status401Unauthorized, "Unauthorized", detail);

    /// <summary>
    /// Builds a ProblemDetails result. The detail is duplicated into an <c>error</c> extension member,
    /// presumably for clients that read a flat <c>error</c> field — confirm against the frontend.
    /// </summary>
    private static IResult Problem(int statusCode, string title, string detail)
    {
        return Results.Problem(
            statusCode: statusCode,
            title: title,
            detail: detail,
            extensions: new Dictionary<string, object?>
            {
                ["error"] = detail
            }
        );
    }

    /// <summary>
    /// Trims <paramref name="input"/> and truncates it to at most <paramref name="max"/> UTF-16 code units.
    /// </summary>
    /// <remarks>
    /// If the cut would fall between the two halves of a surrogate pair, the dangling high surrogate is
    /// dropped as well, so the result is never malformed UTF-16 (it may then be one unit shorter than <paramref name="max"/>).
    /// </remarks>
    /// <returns><c>null</c> for null/empty/whitespace input; otherwise the trimmed, truncated text.</returns>
    public static string? TrimTo(string? input, int max)
    {
        if (string.IsNullOrWhiteSpace(input))
            return null;

        var trimmed = input.Trim();
        var length = Math.Min(trimmed.Length, max);
        if (length > 0 && length < trimmed.Length && char.IsHighSurrogate(trimmed[length - 1]))
            length--;

        return trimmed[..length];
    }

    // Accepted image file extensions for user-supplied image links (matched case-insensitively).
    private static readonly string[] ImageExtensions = [".png", ".jpg", ".jpeg", ".gif", ".webp", ".avif"];

    /// <summary>
    /// Syntactic check for an image link: an absolute http(s) URL whose path ends in a known image extension.
    /// </summary>
    /// <returns><c>true</c> for null/empty/whitespace input (the field is optional).</returns>
    public static bool IsValidImageUrl(string? url)
    {
        if (string.IsNullOrWhiteSpace(url))
            return true; // empty is acceptable
        if (!Uri.TryCreate(url, UriKind.Absolute, out var uri))
            return false;
        if (uri.Scheme is not ("http" or "https"))
            return false;

        var path = uri.AbsolutePath;
        return ImageExtensions.Any(ext => path.EndsWith(ext, StringComparison.OrdinalIgnoreCase));
    }

    /// <summary>
    /// Best-effort network check that <paramref name="url"/> points at a reasonably small image on a public host.
    /// </summary>
    /// <remarks>
    /// <para>The host must resolve only to public addresses (<see cref="IsSafePublicHostAsync"/>). A HEAD request is
    /// tried first; if it is inconclusive, a ranged GET for the first 1 KiB is used, and as a last resort the leading
    /// bytes are sniffed for PNG/JPEG/GIF/WEBP/AVIF signatures. Responses that were redirected, or whose advertised
    /// size exceeds <see cref="MaxImageBytes"/>, are rejected. Both requests share a 3-second budget.</para>
    /// <para>NOTE(review): the host is resolved here and again by HttpClient, so DNS rebinding between the two
    /// lookups is not prevented. Redirects are only detected after they have been followed unless the
    /// "imageValidation" client (or the supplied <paramref name="handler"/>) disables auto-redirect — confirm
    /// its configuration at registration.</para>
    /// </remarks>
    /// <param name="handler">Optional handler (not disposed) used instead of the named client, e.g. for tests.</param>
    /// <returns><c>true</c> for null/empty/whitespace input, or when the URL looks like a reachable image; <c>false</c> otherwise.
    /// Network failures and timeouts yield <c>false</c> rather than an exception.</returns>
    public static async Task<bool> IsReachableImageAsync(string? url, IHttpClientFactory httpFactory, HttpMessageHandler? handler = null, CancellationToken ct = default)
    {
        if (string.IsNullOrWhiteSpace(url))
            return true;
        if (!Uri.TryCreate(url, UriKind.Absolute, out var uri))
            return false;
        if (uri.Scheme is not ("http" or "https"))
            return false;
        if (!await IsSafePublicHostAsync(uri, ct))
            return false;

        using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        cts.CancelAfter(TimeSpan.FromSeconds(3));

        // Disposing is safe for factory clients too: they do not own their pooled handler.
        using var client = handler is null
            ? httpFactory.CreateClient("imageValidation")
            : new HttpClient(handler, disposeHandler: false);

        if (await ProbeWithHeadAsync(client, uri, cts.Token) is bool verdict)
            return verdict;

        return await ProbeWithRangedGetAsync(client, uri, cts.Token);
    }

    /// <summary>
    /// HEAD probe. Returns a definitive <c>true</c>/<c>false</c>, or <c>null</c> when the GET fallback should decide
    /// (non-success status, missing/non-image content type, or any request failure).
    /// </summary>
    private static async Task<bool?> ProbeWithHeadAsync(HttpClient client, Uri uri, CancellationToken ct)
    {
        try
        {
            using var request = new HttpRequestMessage(HttpMethod.Head, uri);
            using var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ct);
            if (WasRedirected(uri, response))
                return false;
            if (!response.IsSuccessStatusCode)
                return null;
            if (response.Content.Headers.ContentLength is > MaxImageBytes)
                return false;
            if (HasImageContentType(response))
                return true;
            return null;
        }
        catch
        {
            // Deliberately best-effort: many servers reject or mishandle HEAD, so defer to the GET probe.
            return null;
        }
    }

    /// <summary>
    /// Ranged GET probe for the first 1 KiB. Accepts on an image content type or a recognised file signature.
    /// </summary>
    private static async Task<bool> ProbeWithRangedGetAsync(HttpClient client, Uri uri, CancellationToken ct)
    {
        try
        {
            using var request = new HttpRequestMessage(HttpMethod.Get, uri);
            request.Headers.Range = new System.Net.Http.Headers.RangeHeaderValue(0, 1023);
            using var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ct);
            if (WasRedirected(uri, response))
                return false;
            if (!response.IsSuccessStatusCode)
                return false;

            // On a 206 reply Content-Length is only the size of the returned slice; Content-Range carries the
            // full resource length. Servers that ignore Range reply 200 with the full Content-Length instead.
            var totalLength = response.Content.Headers.ContentRange?.Length ?? response.Content.Headers.ContentLength;
            if (totalLength is > MaxImageBytes)
                return false;

            if (HasImageContentType(response))
                return true;

            await using var stream = await response.Content.ReadAsStreamAsync(ct);
            var header = new byte[12];
            // A single Read may legally return fewer bytes than are available; keep reading until 12 or EOF.
            var read = await stream.ReadAtLeastAsync(header, header.Length, throwOnEndOfStream: false, ct);
            return HasImageSignature(header.AsSpan(0, read));
        }
        catch
        {
            // Deliberately best-effort: any network/timeout failure means "not a reachable image".
            return false;
        }
    }

    /// <summary>True when the response's Content-Type media type starts with <c>image/</c> (case-insensitive).</summary>
    private static bool HasImageContentType(HttpResponseMessage response)
    {
        var mediaType = response.Content.Headers.ContentType?.MediaType;
        return mediaType is not null && mediaType.StartsWith("image/", StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>Recognises PNG, JPEG, GIF, WEBP and AVIF by their leading bytes.</summary>
    private static bool HasImageSignature(ReadOnlySpan<byte> sig) =>
        IsMagic(sig, [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) // PNG: \x89 "PNG" \r \n \x1A \n
        || IsMagic(sig, [0xFF, 0xD8])                                  // JPEG start-of-image marker
        || IsMagic(sig, "GIF8")                                        // GIF87a / GIF89a
        || IsRiffWithTag(sig, "WEBP")
        || ContainsFtyp(sig, "avif");

    private const long MaxImageBytes = 5 * 1024 * 1024; // 5 MB guard

    /// <summary>
    /// True when the response's final request URI differs from <paramref name="requestedUri"/>, which happens
    /// when HttpClient followed a redirect.
    /// </summary>
    private static bool WasRedirected(Uri requestedUri, HttpResponseMessage response)
    {
        var finalUri = response.RequestMessage?.RequestUri;
        return finalUri is not null && !requestedUri.Equals(finalUri);
    }

    /// <summary>
    /// Resolves the URI's host and returns <c>true</c> only if it resolves to at least one address and none of
    /// them is loopback, private, link-local or otherwise non-public (<see cref="IsPrivate"/>).
    /// Fails closed: unsupported host types and resolution errors return <c>false</c>.
    /// </summary>
    private static async Task<bool> IsSafePublicHostAsync(Uri uri, CancellationToken ct)
    {
        if (uri.HostNameType is not (UriHostNameType.Dns or UriHostNameType.IPv4 or UriHostNameType.IPv6))
            return false;

        try
        {
            // IdnHost strips IPv6 brackets and punycode-encodes internationalized names for the resolver.
            var addresses = await System.Net.Dns.GetHostAddressesAsync(uri.IdnHost, ct);
            return addresses.Length > 0 && !addresses.Any(IsPrivate);
        }
        catch
        {
            return false;
        }
    }

    /// <summary>
    /// True for addresses a server-side fetch must never reach. IPv4-mapped IPv6 addresses are unwrapped first.
    /// </summary>
    /// <remarks>
    /// IPv4 ranges blocked: 0/8, 10/8, 100.64/10, 127/8, 169.254/16, 172.16/12, 192.168/16 and 224/3
    /// (multicast, reserved, broadcast). IPv6: loopback, unspecified, link-local, site-local,
    /// unique-local (fc00::/7) and multicast.
    /// </remarks>
    private static bool IsPrivate(System.Net.IPAddress ip)
    {
        // ::ffff:a.b.c.d connects to a.b.c.d on dual-stack sockets, so judge the embedded IPv4 address.
        if (ip.IsIPv4MappedToIPv6)
            ip = ip.MapToIPv4();

        if (ip.AddressFamily == System.Net.Sockets.AddressFamily.InterNetwork)
        {
            var bytes = ip.GetAddressBytes();
            return bytes[0] switch
            {
                0 => true,                                  // "this network" (0.0.0.0 often reaches localhost)
                10 => true,                                 // RFC 1918
                100 when bytes[1] is >= 64 and <= 127 => true, // RFC 6598 shared/CGNAT space
                127 => true,                                // loopback
                169 when bytes[1] == 254 => true,           // link-local, incl. cloud metadata 169.254.169.254
                172 when bytes[1] is >= 16 and <= 31 => true, // RFC 1918
                192 when bytes[1] == 168 => true,           // RFC 1918
                >= 224 => true,                             // multicast, reserved, broadcast
                _ => false
            };
        }

        if (ip.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6)
        {
            return System.Net.IPAddress.IsLoopback(ip)
                || ip.Equals(System.Net.IPAddress.IPv6Any)
                || ip.IsIPv6LinkLocal
                || ip.IsIPv6SiteLocal
                || ip.IsIPv6UniqueLocal
                || ip.IsIPv6Multicast;
        }

        return false;
    }

    /// <summary>True when <paramref name="data"/> starts with the ASCII bytes of <paramref name="ascii"/>.</summary>
    private static bool IsMagic(ReadOnlySpan<byte> data, string ascii)
    {
        var bytes = System.Text.Encoding.ASCII.GetBytes(ascii);
        return data.StartsWith(bytes);
    }

    /// <summary>True when <paramref name="data"/> starts with <paramref name="magic"/>.</summary>
    private static bool IsMagic(ReadOnlySpan<byte> data, ReadOnlySpan<byte> magic) => data.StartsWith(magic);

    /// <summary>True for a RIFF container ("RIFF", 4-byte size, then <paramref name="tag"/> at offset 8).</summary>
    private static bool IsRiffWithTag(ReadOnlySpan<byte> data, string tag)
    {
        if (data.Length < 12)
            return false;
        if (!data.StartsWith("RIFF"u8))
            return false;

        var tagBytes = System.Text.Encoding.ASCII.GetBytes(tag);
        return data[8..].StartsWith(tagBytes);
    }

    /// <summary>
    /// True for an ISO-BMFF file whose first box is <c>ftyp</c> (offset 4) with major brand
    /// <paramref name="brand"/> (offset 8). Compatible brands are not inspected.
    /// </summary>
    private static bool ContainsFtyp(ReadOnlySpan<byte> data, string brand)
    {
        if (data.Length < 12)
            return false;
        if (!data[4..].StartsWith("ftyp"u8))
            return false;

        var brandBytes = System.Text.Encoding.ASCII.GetBytes(brand);
        return data[8..].StartsWith(brandBytes);
    }

    /// <summary>True for null/empty/whitespace input or an absolute http(s) URL.</summary>
    public static bool IsValidHttpUrl(string? url)
    {
        if (string.IsNullOrWhiteSpace(url))
            return true; // empty is allowed
        if (!Uri.TryCreate(url, UriKind.Absolute, out var uri))
            return false;
        return uri.Scheme is "http" or "https";
    }

    /// <summary>True when the request's authenticated player exists and has <c>IsAdmin</c> set.</summary>
    public static async Task<bool> IsAdmin(HttpContext ctx, AppDbContext db)
    {
        var player = await GetAuthenticatedPlayer(ctx, db);
        return player?.IsAdmin == true;
    }

    /// <summary>
    /// Creates the default application-state row (Id 1, results closed, UpdatedAt at the Unix epoch).
    /// <see cref="GetCurrentPhaseAsync"/> reads AppState with <c>SingleAsync</c>, so exactly one row is expected.
    /// </summary>
    public static AppState NewAppState() => new()
    {
        Id = 1,
        ResultsOpen = false,
        UpdatedAt = DateTimeOffset.UnixEpoch
    };

    /// <summary>
    /// Maps every item id to the root of its parent chain (see <see cref="FindRootId"/>).
    /// </summary>
    /// <exception cref="ArgumentException">Thrown by <c>ToDictionary</c> when two items share an id.</exception>
    public static Dictionary<int, int> BuildLinkRoots(IEnumerable<(int Id, int? ParentId)> items)
    {
        var parentMap = items.ToDictionary(x => x.Id, x => x.ParentId);
        var roots = new Dictionary<int, int>();
        foreach (var id in parentMap.Keys)
        {
            roots[id] = FindRootId(id, parentMap);
        }
        return roots;
    }

    /// <summary>
    /// Follows parent links from <paramref name="suggestionId"/> until an item without a parent is reached.
    /// </summary>
    /// <remarks>
    /// A parent id missing from <paramref name="parentMap"/> is itself returned as the root. If the walk enters a
    /// cycle, the smallest id on that cycle is returned. Every member of the cycle, and every chain feeding into it,
    /// therefore agrees on the same root.
    /// </remarks>
    public static int FindRootId(int suggestionId, IReadOnlyDictionary<int, int?> parentMap)
    {
        var walk = new List<int>();
        // Index of each visited id within the walk, so a revisit pinpoints where the cycle starts.
        var positions = new Dictionary<int, int>();
        var current = suggestionId;

        while (true)
        {
            positions[current] = walk.Count;
            walk.Add(current);

            if (!parentMap.TryGetValue(current, out var parent) || parent is not { } next)
                return current;

            if (positions.TryGetValue(next, out var cycleStart))
                return walk.Skip(cycleStart).Min();

            current = next;
        }
    }

    /// <summary>
    /// All ids that share <paramref name="suggestionId"/>'s root in <paramref name="rootIndex"/>, including itself.
    /// </summary>
    /// <returns>An empty list when <paramref name="suggestionId"/> is not in the index.</returns>
    public static List<int> LinkedIdsFor(int suggestionId, IReadOnlyDictionary<int, int> rootIndex)
    {
        if (!rootIndex.TryGetValue(suggestionId, out var root))
            return [];
        return rootIndex.Where(kv => kv.Value == root).Select(kv => kv.Key).ToList();
    }
}