about summary refs log tree commit diff
path: root/Tests/LibMatrix.HomeserverEmulator/Controllers/Media/MediaController.cs
diff options
context:
space:
mode:
Diffstat (limited to 'Tests/LibMatrix.HomeserverEmulator/Controllers/Media/MediaController.cs')
-rw-r--r--Tests/LibMatrix.HomeserverEmulator/Controllers/Media/MediaController.cs86
1 files changed, 32 insertions, 54 deletions
diff --git a/Tests/LibMatrix.HomeserverEmulator/Controllers/Media/MediaController.cs b/Tests/LibMatrix.HomeserverEmulator/Controllers/Media/MediaController.cs
index 4820a65..7899ada 100644
--- a/Tests/LibMatrix.HomeserverEmulator/Controllers/Media/MediaController.cs
+++ b/Tests/LibMatrix.HomeserverEmulator/Controllers/Media/MediaController.cs
@@ -1,9 +1,8 @@
 using System.Text.Json.Nodes;
 using System.Text.RegularExpressions;
-using ArcaneLibs.Extensions;
+using ArcaneLibs.Collections;
 using LibMatrix.HomeserverEmulator.Services;
 using LibMatrix.Services;
-using Microsoft.AspNetCore.Html;
 using Microsoft.AspNetCore.Mvc;
 
 namespace LibMatrix.HomeserverEmulator.Controllers.Media;
@@ -41,60 +40,10 @@ public class MediaController(
         return media;
     }
 
-    private Dictionary<string, SemaphoreSlim> downloadLocks = new();
-
     [HttpGet("download/{serverName}/{mediaId}")]
     public async Task DownloadMedia(string serverName, string mediaId) {
-        while (true)
-            try {
-                if (cfg.StoreData) {
-                    SemaphoreSlim ss;
-                    if (!downloadLocks.ContainsKey(serverName + mediaId))
-                        downloadLocks[serverName + mediaId] = new SemaphoreSlim(1);
-                    ss = downloadLocks[serverName + mediaId];
-                    await ss.WaitAsync();
-                    var serverMediaPath = Path.Combine(cfg.DataStoragePath, "media", serverName);
-                    Directory.CreateDirectory(serverMediaPath);
-                    var mediaPath = Path.Combine(serverMediaPath, mediaId);
-                    if (System.IO.File.Exists(mediaPath)) {
-                        ss.Release();
-                        await using var stream = new FileStream(mediaPath, FileMode.Open);
-                        await stream.CopyToAsync(Response.Body);
-                        return;
-                    }
-                    else {
-                        var mediaUrl = await hsResolver.ResolveMediaUri(serverName, $"mxc://{serverName}/{mediaId}");
-                        if (mediaUrl is null)
-                            throw new MatrixException() {
-                                ErrorCode = "M_NOT_FOUND",
-                                Error = "Media not found"
-                            };
-                        await using var stream = System.IO.File.OpenWrite(mediaPath);
-                        using var response = await new HttpClient().GetAsync(mediaUrl);
-                        await response.Content.CopyToAsync(stream);
-                        await stream.FlushAsync();
-                        ss.Release();
-                        await DownloadMedia(serverName, mediaId);
-                        return;
-                    }
-                }
-                else {
-                    var mediaUrl = await hsResolver.ResolveMediaUri(serverName, $"mxc://{serverName}/{mediaId}");
-                    if (mediaUrl is null)
-                        throw new MatrixException() {
-                            ErrorCode = "M_NOT_FOUND",
-                            Error = "Media not found"
-                        };
-                    using var response = await new HttpClient().GetAsync(mediaUrl);
-                    await response.Content.CopyToAsync(Response.Body);
-                    return;
-                }
-
-                return;
-            }
-            catch (IOException) {
-                //ignored
-            }
+        var stream = await DownloadRemoteMedia(serverName, mediaId);
+        await stream.CopyToAsync(Response.Body);
     }
 
     [HttpGet("thumbnail/{serverName}/{mediaId}")]
@@ -118,4 +67,33 @@ public class MediaController(
 
         return data;
     }
+
+    private async Task<Stream> DownloadRemoteMedia(string serverName, string mediaId) {
+        if (cfg.StoreData) {
+            var path = Path.Combine(cfg.DataStoragePath, "media", serverName, mediaId);
+            if (!System.IO.File.Exists(path)) {
+                var mediaUrl = await hsResolver.ResolveMediaUri(serverName, $"mxc://{serverName}/{mediaId}");
+                if (mediaUrl is null)
+                    throw new MatrixException() {
+                        ErrorCode = "M_NOT_FOUND",
+                        Error = "Media not found"
+                    };
+                using var client = new HttpClient();
+                var stream = await client.GetStreamAsync(mediaUrl);
+                await using var fs = System.IO.File.Create(path);
+                await stream.CopyToAsync(fs);
+            }
+            return new FileStream(path, FileMode.Open);
+        }
+        else {
+            var mediaUrl = await hsResolver.ResolveMediaUri(serverName, $"mxc://{serverName}/{mediaId}");
+            if (mediaUrl is null)
+                throw new MatrixException() {
+                    ErrorCode = "M_NOT_FOUND",
+                    Error = "Media not found"
+                };
+            using var client = new HttpClient();
+            return await client.GetStreamAsync(mediaUrl);
+        }
+    }
 }
\ No newline at end of file