Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/EasyWebSockets/EasyWebSockets.csproj
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<LangVersion>9</LangVersion>
<Nullable>Enable</Nullable>
<WarningsAsErrors>nullable</WarningsAsErrors>
<PackageId>EasyWebSockets</PackageId>
<Version>1.0.0</Version>
<AssemblyName>EasyWebSockets</AssemblyName>
Expand All @@ -12,8 +15,8 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.WebSockets" Version="2.2.0" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.1" />
<PackageReference Include="Microsoft.AspNetCore.WebSockets" Version="2.2.1" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
</ItemGroup>

</Project>
14 changes: 6 additions & 8 deletions src/EasyWebSockets/WebSocketConnectionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,26 @@ namespace EasyWebSockets
{
internal class WebSocketConnectionManager
{
private ConcurrentDictionary<string, WebSocket> _sockets = new ConcurrentDictionary<string, WebSocket>();
private readonly ConcurrentDictionary<string, WebSocket> _sockets = new ConcurrentDictionary<string, WebSocket>();

public WebSocket GetSocketById(string id) =>
_sockets.FirstOrDefault(p => p.Key == id).Value;


public ConcurrentDictionary<string, WebSocket> GetAll() => _sockets;

public string GetId(WebSocket socket) => _sockets.FirstOrDefault(p => p.Value == socket).Key;

public void AddSocket(WebSocket socket) => _sockets.TryAdd(CreateConnectionId(), socket);

public async Task RemoveSocket(string id)
public Task RemoveSocket(string id, CancellationToken cancellationToken = default)
{
WebSocket socket;
_sockets.TryRemove(id, out socket);
_sockets.TryRemove(id, out WebSocket socket);

await socket.CloseAsync(closeStatus: WebSocketCloseStatus.NormalClosure,
return socket.CloseAsync(closeStatus: WebSocketCloseStatus.NormalClosure,
statusDescription: "Closed by the WebSocketManager",
cancellationToken: CancellationToken.None);
cancellationToken: cancellationToken);
}

private string CreateConnectionId() => Guid.NewGuid().ToString();
private static string CreateConnectionId() => Guid.NewGuid().ToString();
}
}
35 changes: 19 additions & 16 deletions src/EasyWebSockets/WebSocketHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,50 +9,53 @@

namespace EasyWebSockets
{
public interface IWebSocketPublisher
public interface IWebSocketPublisher
{
Task SendMessageToAllAsync(object message);
Task SendMessageToAllAsync(object message, CancellationToken cancellationToken = default);
}

internal class WebSocketHandler : IWebSocketPublisher
{
private readonly WebSocketConnectionManager _webSocketConnectionManager;

private JsonSerializerSettings _jsonSerializerSettings = new JsonSerializerSettings()
private readonly JsonSerializerSettings _jsonSerializerSettings = new JsonSerializerSettings()
{
ContractResolver = new CamelCasePropertyNamesContractResolver()
};

public WebSocketHandler(WebSocketConnectionManager webSocketConnectionManager) =>
_webSocketConnectionManager = webSocketConnectionManager;

public async Task SendMessageToAllAsync(object message) =>
await Task.WhenAll(
_webSocketConnectionManager.GetAll()
public Task SendMessageToAllAsync(object message, CancellationToken cancellationToken = default) =>
Task.WhenAll(_webSocketConnectionManager.GetAll()
.Where(pair => pair.Value.State == WebSocketState.Open)
.Select(pair => SendMessageAsync(pair.Value, message)));
public async Task OnConnected(WebSocket socket)
.Select(pair => SendMessageAsync(pair.Value, message, cancellationToken)));

public Task OnConnected(WebSocket socket, CancellationToken cancellationToken = default)
{
_webSocketConnectionManager.AddSocket(socket);
await SendMessageAsync(socket, $"Connected with Id: ${_webSocketConnectionManager.GetId(socket)}");
return SendMessageAsync(socket, $"Connected with Id: ${_webSocketConnectionManager.GetId(socket)}", cancellationToken);
}

public async Task OnDisconnected(WebSocket socket) =>
await _webSocketConnectionManager.RemoveSocket(_webSocketConnectionManager.GetId(socket));
public Task OnDisconnected(WebSocket socket, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
return _webSocketConnectionManager.RemoveSocket(_webSocketConnectionManager.GetId(socket), cancellationToken);
}

private async Task SendMessageAsync(WebSocket socket, object message)
private async Task SendMessageAsync(WebSocket socket, object message, CancellationToken cancellationToken = default)
{
if (socket.State != WebSocketState.Open)
return;

var serializedMessage = JsonConvert.SerializeObject(message, _jsonSerializerSettings);
string? serializedMessage = JsonConvert.SerializeObject(message, _jsonSerializerSettings);
await socket.SendAsync(buffer: new ArraySegment<byte>(
array: Encoding.ASCII.GetBytes(serializedMessage),offset: 0,
array: Encoding.ASCII.GetBytes(serializedMessage),
offset: 0,
count: serializedMessage.Length),
messageType: WebSocketMessageType.Text,
endOfMessage: true,
cancellationToken: CancellationToken.None);
cancellationToken: cancellationToken);
}
}
}
6 changes: 4 additions & 2 deletions src/EasyWebSockets/WebSocketManagerExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ public static IServiceCollection AddEasyWebSockets(this IServiceCollection servi
{
services.AddTransient<WebSocketConnectionManager>();
services.AddSingleton<IWebSocketPublisher, WebSocketHandler>();

return services;
}

public static IApplicationBuilder UseEasyWebSockets(this IApplicationBuilder app, string path = "/ws")
{
app.UseWebSockets();
var wsHandler = app.ApplicationServices.GetService(typeof(IWebSocketPublisher));
return app.Map(new PathString(path), (_app) => _app.UseMiddleware<WebSocketManagerMiddleware>(wsHandler));

object wsHandler = app.ApplicationServices.GetRequiredService<IWebSocketPublisher>();
return app.Map(new PathString(path), builder => builder.UseMiddleware<WebSocketManagerMiddleware>(wsHandler));
}
}
}
51 changes: 20 additions & 31 deletions src/EasyWebSockets/WebSocketManagerMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace EasyWebSockets
internal class WebSocketManagerMiddleware
{
private readonly RequestDelegate _next;
private WebSocketHandler _webSocketHandler { get; set; }
private readonly WebSocketHandler _webSocketHandler;

public WebSocketManagerMiddleware(RequestDelegate next, WebSocketHandler webSocketHandler)
{
Expand All @@ -24,55 +24,44 @@ public async Task Invoke(HttpContext context)
if (!context.WebSockets.IsWebSocketRequest)
return;

var socket = await context.WebSockets.AcceptWebSocketAsync();
await _webSocketHandler.OnConnected(socket);
WebSocket? socket = await context.WebSockets.AcceptWebSocketAsync();
await _webSocketHandler.OnConnected(socket, context.RequestAborted);
await Receive(socket, async (result, serializedInvocationDescriptor) =>
{
if (result.MessageType == WebSocketMessageType.Text)
switch (result.MessageType)
{
// await _webSocketHandler.ReceiveAsync(socket, result, serializedInvocationDescriptor);
return;
}

else if (result.MessageType == WebSocketMessageType.Close)
{
try
{
await _webSocketHandler.OnDisconnected(socket);
}
case WebSocketMessageType.Text:
// await _webSocketHandler.ReceiveAsync(socket, result, serializedInvocationDescriptor);
return;

catch (WebSocketException)
{
throw; //let's not swallow any exception for now
}

return;
case WebSocketMessageType.Close:
await _webSocketHandler.OnDisconnected(socket, context.RequestAborted);
return;
}
});
}, context.RequestAborted);
}

private async Task Receive(WebSocket socket, Action<WebSocketReceiveResult, string> handleMessage)
private static async Task Receive(WebSocket socket, Action<WebSocketReceiveResult, string> handleMessage, CancellationToken cancellationToken = default)
{
while (socket.State == WebSocketState.Open)
{
ArraySegment<Byte> buffer = new ArraySegment<byte>(new Byte[1024 * 4]);
string serializedInvocationDescriptor = null;
WebSocketReceiveResult result = null;
var buffer = new ArraySegment<byte>(new byte[1024 * 4]);
string serializedInvocationDescriptor;
WebSocketReceiveResult result;

using (var ms = new MemoryStream())
{
do
{
result = await socket.ReceiveAsync(buffer, CancellationToken.None);
ms.Write(buffer.Array, buffer.Offset, result.Count);
result = await socket.ReceiveAsync(buffer, cancellationToken);
await ms.WriteAsync(buffer.Array, buffer.Offset, result.Count, cancellationToken);
}
while (!result.EndOfMessage);

ms.Seek(0, SeekOrigin.Begin);

using (var reader = new StreamReader(ms, Encoding.UTF8))
{
serializedInvocationDescriptor = await reader.ReadToEndAsync();
}
using var reader = new StreamReader(ms, Encoding.UTF8);
serializedInvocationDescriptor = await reader.ReadToEndAsync();
}

handleMessage(result, serializedInvocationDescriptor);
Expand Down