diff --git a/MatrixMediaGate/Program.cs b/MatrixMediaGate/Program.cs
index 9ebac51..4a62ebc 100644
--- a/MatrixMediaGate/Program.cs
+++ b/MatrixMediaGate/Program.cs
@@ -8,11 +8,19 @@ var builder = WebApplication.CreateBuilder(args);
builder.Services.AddSingleton<ProxyConfiguration>();
builder.Services.AddSingleton<IHttpContextAccessor, HttpContextAccessor>();
-builder.Services.AddScoped<AuthValidator>();
+builder.Services.AddSingleton<AuthValidator>();
+builder.Services.AddSingleton<HttpClient>(services => {
+ var cfg = services.GetRequiredService<ProxyConfiguration>();
+ // var handler = new HttpClientHandler() { AutomaticDecompression = DecompressionMethods.None };
+ return new HttpClient() {
+ BaseAddress = new Uri(cfg.Upstream),
+ MaxResponseContentBufferSize = 1 * 1024 * 1024 // 1MB
+ };
+});
var app = builder.Build();
-async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
+async Task Proxy(HttpClient hc, ProxyConfiguration cfg, HttpContext ctx, ILogger<Program> logger) {
if (ctx is null) return;
var path = ctx.Request.Path.Value;
if (path is null) return;
@@ -20,15 +28,13 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
path = path[1..];
path += ctx.Request.QueryString.Value;
- using var handler = new HttpClientHandler() { AutomaticDecompression = DecompressionMethods.None };
-
- using var hc = new HttpClient(handler) { BaseAddress = new Uri(cfg.Upstream) };
var method = new HttpMethod(ctx.Request.Method);
using var req = new HttpRequestMessage(method, path);
foreach (var header in ctx.Request.Headers) {
- // if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length")
+ if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length")
req.Headers.Add(header.Key, header.Value.ToArray());
}
+ req.Headers.Host = cfg.Host;
if (ctx.Request.ContentLength > 0) {
req.Content = new StreamContent(ctx.Request.Body);
@@ -41,12 +47,13 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
using var response = await hc.SendAsync(req, HttpCompletionOption.ResponseHeadersRead);
ctx.Response.Headers.Clear();
foreach (var header in response.Headers) {
- // if (header.Key != "Transfer-Encoding")
+ if (header.Key != "Transfer-Encoding")
ctx.Response.Headers[header.Key] = header.Value.ToArray();
}
ctx.Response.StatusCode = (int)response.StatusCode;
ctx.Response.ContentType = response.Content.Headers.ContentType?.ToString() ?? "application/json";
+ if (response.Content.Headers.ContentLength != null) ctx.Response.ContentLength = response.Content.Headers.ContentLength;
await ctx.Response.StartAsync();
await using var content = await response.Content.ReadAsStreamAsync();
await content.CopyToAsync(ctx.Response.Body);
@@ -54,17 +61,15 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
await ctx.Response.CompleteAsync();
}
-async Task ProxyMaybeAuth(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
- if (ctx is null) return;
-
- await auth.UpdateAuth();
+async Task ProxyMaybeAuth(HttpClient hc, ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
+ await auth.UpdateAuth(ctx);
- await Proxy(cfg, auth, ctx, logger);
+ await Proxy(hc, cfg, ctx, logger);
}
-async Task ProxyMedia(string serverName, ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
- if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth()) {
- await ProxyMaybeAuth(cfg, auth, ctx, logger); // Some clients may send Authorization header...
+async Task ProxyMedia(string serverName, ProxyConfiguration cfg, HttpClient hc, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
+ if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth(ctx)) {
+ await ProxyMaybeAuth(hc, cfg, auth, ctx, logger); // Some clients may send Authorization header...
}
else {
ctx.Response.StatusCode = 403;
|