diff --git a/Tests/LibMatrix.HomeserverEmulator/Controllers/Media/MediaController.cs b/Tests/LibMatrix.HomeserverEmulator/Controllers/Media/MediaController.cs
index 4820a65..7899ada 100644
--- a/Tests/LibMatrix.HomeserverEmulator/Controllers/Media/MediaController.cs
+++ b/Tests/LibMatrix.HomeserverEmulator/Controllers/Media/MediaController.cs
@@ -1,9 +1,8 @@
using System.Text.Json.Nodes;
using System.Text.RegularExpressions;
-using ArcaneLibs.Extensions;
+using ArcaneLibs.Collections;
using LibMatrix.HomeserverEmulator.Services;
using LibMatrix.Services;
-using Microsoft.AspNetCore.Html;
using Microsoft.AspNetCore.Mvc;
namespace LibMatrix.HomeserverEmulator.Controllers.Media;
@@ -41,60 +40,10 @@ public class MediaController(
return media;
}
- private Dictionary<string, SemaphoreSlim> downloadLocks = new();
-
[HttpGet("download/{serverName}/{mediaId}")]
public async Task DownloadMedia(string serverName, string mediaId) {
- while (true)
- try {
- if (cfg.StoreData) {
- SemaphoreSlim ss;
- if (!downloadLocks.ContainsKey(serverName + mediaId))
- downloadLocks[serverName + mediaId] = new SemaphoreSlim(1);
- ss = downloadLocks[serverName + mediaId];
- await ss.WaitAsync();
- var serverMediaPath = Path.Combine(cfg.DataStoragePath, "media", serverName);
- Directory.CreateDirectory(serverMediaPath);
- var mediaPath = Path.Combine(serverMediaPath, mediaId);
- if (System.IO.File.Exists(mediaPath)) {
- ss.Release();
- await using var stream = new FileStream(mediaPath, FileMode.Open);
- await stream.CopyToAsync(Response.Body);
- return;
- }
- else {
- var mediaUrl = await hsResolver.ResolveMediaUri(serverName, $"mxc://{serverName}/{mediaId}");
- if (mediaUrl is null)
- throw new MatrixException() {
- ErrorCode = "M_NOT_FOUND",
- Error = "Media not found"
- };
- await using var stream = System.IO.File.OpenWrite(mediaPath);
- using var response = await new HttpClient().GetAsync(mediaUrl);
- await response.Content.CopyToAsync(stream);
- await stream.FlushAsync();
- ss.Release();
- await DownloadMedia(serverName, mediaId);
- return;
- }
- }
- else {
- var mediaUrl = await hsResolver.ResolveMediaUri(serverName, $"mxc://{serverName}/{mediaId}");
- if (mediaUrl is null)
- throw new MatrixException() {
- ErrorCode = "M_NOT_FOUND",
- Error = "Media not found"
- };
- using var response = await new HttpClient().GetAsync(mediaUrl);
- await response.Content.CopyToAsync(Response.Body);
- return;
- }
-
- return;
- }
- catch (IOException) {
- //ignored
- }
+ var stream = await DownloadRemoteMedia(serverName, mediaId);
+ await stream.CopyToAsync(Response.Body);
}
[HttpGet("thumbnail/{serverName}/{mediaId}")]
@@ -118,4 +67,33 @@ public class MediaController(
return data;
}
+
+ private async Task<Stream> DownloadRemoteMedia(string serverName, string mediaId) {
+ if (cfg.StoreData) {
+ var path = Path.Combine(cfg.DataStoragePath, "media", serverName, mediaId);
+ if (!System.IO.File.Exists(path)) {
+ var mediaUrl = await hsResolver.ResolveMediaUri(serverName, $"mxc://{serverName}/{mediaId}");
+ if (mediaUrl is null)
+ throw new MatrixException() {
+ ErrorCode = "M_NOT_FOUND",
+ Error = "Media not found"
+ };
+ using var client = new HttpClient();
+ var stream = await client.GetStreamAsync(mediaUrl);
+ await using var fs = System.IO.File.Create(path);
+ await stream.CopyToAsync(fs);
+ }
+ return new FileStream(path, FileMode.Open);
+ }
+ else {
+ var mediaUrl = await hsResolver.ResolveMediaUri(serverName, $"mxc://{serverName}/{mediaId}");
+ if (mediaUrl is null)
+ throw new MatrixException() {
+ ErrorCode = "M_NOT_FOUND",
+ Error = "Media not found"
+ };
+ using var client = new HttpClient();
+ return await client.GetStreamAsync(mediaUrl);
+ }
+ }
}
\ No newline at end of file
|