Skip to content

Commit 65943b5

Browse files
committed
refactor: proxy protocol parsers
1 parent af70bc9 commit 65943b5

File tree

4 files changed

+191
-207
lines changed

4 files changed

+191
-207
lines changed

ProxyProtocolSocket/ProxyProtocolSocket.csproj

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
<TargetFramework>net6.0</TargetFramework>
55
<ImplicitUsings>enable</ImplicitUsings>
66
<Nullable>enable</Nullable>
7-
<Version>2.0</Version>
7+
<Version>2.1</Version>
88
<Authors>LaoSparrow</Authors>
9-
<AssemblyVersion>2.0</AssemblyVersion>
10-
<FileVersion>2.0</FileVersion>
9+
<AssemblyVersion>2.1</AssemblyVersion>
10+
<FileVersion>2.1</FileVersion>
1111
</PropertyGroup>
1212

1313
<ItemGroup>

ProxyProtocolSocket/Utils/Net/ProxyProtocol.cs

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,20 @@ public class ProxyProtocol
2727
private static readonly byte[] V1_SIGNATURE = Encoding.ASCII.GetBytes("PROXY");
2828
private static readonly byte[] V2_OR_ABOVE_SIGNATURE = { 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A };
2929
private const int V2_OR_ABOVE_ID_LENGTH = 16;
30-
private const int BUFFER_SIZE = 107;
30+
private const int MAX_BUFFER_SIZE = 232;
3131
#endregion
3232

3333
#region Members
3434
private NetworkStream _stream;
3535
private IPEndPoint _remoteEndpoint;
3636

37-
private byte[] _buffer = new byte[BUFFER_SIZE];
38-
private int _bufferPosition = 0;
37+
private byte[] _buffer = new byte[MAX_BUFFER_SIZE];
38+
// i.e. the end of _buffer
39+
// _buffer[_bufferSize - 1] is the last byte read from the stream
40+
private int _bufferSize = 0;
3941

4042
private bool _isParserCached = false;
41-
private IProxyProtocolParser? _cachedParser = null;
43+
private IProxyProtocolParser? _cachedParser = null;
4244
private ProxyProtocolVersion? _protocolVersion = null;
4345
#endregion
4446

@@ -62,7 +64,7 @@ public async Task Parse()
6264
var parser = await GetParser();
6365
if (parser == null)
6466
return;
65-
Logger.Log($"Calling parser {parser.GetType().Name}");
67+
Logger.Log($"[{_remoteEndpoint}] calling {parser.GetType().Name}.Parse()");
6668
await parser.Parse();
6769
}
6870

@@ -106,22 +108,14 @@ public async Task<ProxyProtocolCommand> GetCommand()
106108
if (_isParserCached)
107109
return _cachedParser;
108110

109-
Logger.Log("Getting parser");
111+
Logger.Log($"[{_remoteEndpoint}] selecting parser...");
110112
// Get parser corresponding to version
111-
switch (await GetVersion())
113+
_cachedParser = await GetVersion() switch
112114
{
113-
case ProxyProtocolVersion.V1:
114-
_cachedParser = new ProxyProtocolParserV1(_stream, _remoteEndpoint, _buffer, ref _bufferPosition);
115-
break;
116-
117-
case ProxyProtocolVersion.V2:
118-
_cachedParser = new ProxyProtocolParserV2(_stream, _remoteEndpoint, _buffer, ref _bufferPosition);
119-
break;
120-
121-
default:
122-
_cachedParser = null;
123-
break;
124-
}
115+
ProxyProtocolVersion.V1 => new ProxyProtocolParserV1(_stream, _remoteEndpoint, _buffer, _bufferSize),
116+
ProxyProtocolVersion.V2 => new ProxyProtocolParserV2(_stream, _remoteEndpoint, _buffer, _bufferSize),
117+
_ => null
118+
};
125119
_isParserCached = true;
126120
return _cachedParser;
127121
}
@@ -132,21 +126,21 @@ public async Task<ProxyProtocolVersion> GetVersion()
132126
if (_protocolVersion != null)
133127
return (ProxyProtocolVersion)_protocolVersion;
134128

135-
Logger.Log("Getting version info");
129+
Logger.Log($"[{_remoteEndpoint}] interpreting protocol version...");
136130

137131
_protocolVersion = ProxyProtocolVersion.Unknown;
138132
// Check if is version 1
139-
await GetBytesToPosition(V1_SIGNATURE.Length);
133+
await GetBytesTillBufferSize(V1_SIGNATURE.Length);
140134
if (IsVersion1(_buffer))
141135
_protocolVersion = ProxyProtocolVersion.V1;
142136
else
143137
{
144138
// Check if is version 2 or above
145-
await GetBytesToPosition(V2_OR_ABOVE_ID_LENGTH);
139+
await GetBytesTillBufferSize(V2_OR_ABOVE_ID_LENGTH);
146140
if (IsVersion2OrAbove(_buffer))
147141
{
148142
// Check versions
149-
if (IsVersion2(_buffer, true))
143+
if (IsVersion2(_buffer, false))
150144
_protocolVersion = ProxyProtocolVersion.V2;
151145
}
152146
}
@@ -156,35 +150,43 @@ public async Task<ProxyProtocolVersion> GetVersion()
156150
#endregion
157151

158152
#region Private methods
159-
private async Task GetBytesToPosition(int position)
153+
private async Task GetBytesTillBufferSize(int size)
160154
{
161-
if (position <= _bufferPosition)
155+
if (size <= _bufferSize)
162156
return;
163-
await GetBytesFromStream(position - _bufferPosition);
157+
await GetBytesFromStream(size - _bufferSize);
164158
}
165159

166160
private async Task GetBytesFromStream(int length)
167161
{
168-
if ((_bufferPosition + length) > _buffer.Length)
162+
if ((_bufferSize + length) > _buffer.Length)
169163
throw new InternalBufferOverflowException();
170164

171165
while (length > 0)
172166
{
173167
if (!_stream.DataAvailable)
174168
throw new EndOfStreamException();
175169

176-
int count = await _stream.ReadAsync(_buffer, _bufferPosition, length);
170+
var count = await _stream.ReadAsync(_buffer.AsMemory(_bufferSize, length));
177171
length -= count;
178-
_bufferPosition += count;
172+
_bufferSize += count;
179173
}
180174
}
181175
#endregion
182176

183177
#region Public static methods
184-
public static bool IsVersion1(byte[] header) => header.Take(V1_SIGNATURE.Length).SequenceEqual(V1_SIGNATURE);
185-
public static bool IsVersion2OrAbove(byte[] header) => header.Take(V2_OR_ABOVE_SIGNATURE.Length).SequenceEqual(V2_OR_ABOVE_SIGNATURE);
186-
public static bool IsVersion2(byte[] header, bool checkedSignature = false) =>
187-
checkedSignature || IsVersion2OrAbove(header) &&
178+
179+
// ReSharper disable once MemberCanBePrivate.Global
180+
public static bool IsVersion1(ReadOnlySpan<byte> header) =>
181+
header[..V1_SIGNATURE.Length].SequenceEqual(V1_SIGNATURE);
182+
183+
// ReSharper disable once MemberCanBePrivate.Global
184+
public static bool IsVersion2OrAbove(ReadOnlySpan<byte> header) =>
185+
header[..V2_OR_ABOVE_SIGNATURE.Length].SequenceEqual(V2_OR_ABOVE_SIGNATURE);
186+
187+
// ReSharper disable once MemberCanBePrivate.Global
188+
public static bool IsVersion2(ReadOnlySpan<byte> header, bool checkSignature = true) =>
189+
(!checkSignature || IsVersion2OrAbove(header)) &&
188190
(header[12] & 0xF0) == 0x20;
189191
#endregion
190192
}

ProxyProtocolSocket/Utils/Net/ProxyProtocolParserV1.cs

Lines changed: 63 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -9,113 +9,113 @@ public class ProxyProtocolParserV1 : IProxyProtocolParser
99
#region Constants
1010
private const string DELIMITER = "\r\n";
1111
private const char SEPARATOR = ' ';
12+
private const int MAX_HEADER_SIZE = 107;
1213
#endregion
1314

1415
#region Members
1516
private NetworkStream _stream;
1617
private IPEndPoint _remoteEndpoint;
1718
private byte[] _buffer;
18-
private int _bufferPosition;
19+
private int _bufferSize;
1920

20-
private bool _isParsed;
21+
private bool _hasParsed;
2122
private AddressFamily _addressFamily = AddressFamily.Unknown;
2223
private ProxyProtocolCommand _protocolCommand = ProxyProtocolCommand.Unknown;
2324
private IPEndPoint? _sourceEndpoint;
2425
private IPEndPoint? _destEndpoint;
2526
#endregion
2627

27-
public ProxyProtocolParserV1(NetworkStream stream, IPEndPoint remoteEndpoint, byte[] buffer, ref int bufferPosition)
28+
public ProxyProtocolParserV1(NetworkStream stream, IPEndPoint remoteEndpoint, byte[] buffer, int bufferSize)
2829
{
2930
#region Args checking
30-
if (stream == null) throw new ArgumentNullException("argument 'stream' cannot be null");
31-
if (stream.CanRead != true) throw new ArgumentException("argument 'stream' is unreadable");
32-
if (remoteEndpoint == null) throw new ArgumentNullException("argument 'remoteEndpoint' cannot be null");
33-
if (buffer == null) throw new ArgumentNullException("argument 'buffer' cannot be null");
34-
if (bufferPosition > buffer.Length) throw new ArgumentException("argument 'bufferPosition' is larger than 'buffer.Length'");
31+
if (stream == null) throw new ArgumentNullException(nameof(stream));
32+
if (stream.CanRead != true) throw new ArgumentException($"argument 'stream' is unreadable");
33+
if (remoteEndpoint == null) throw new ArgumentNullException(nameof(remoteEndpoint));
34+
if (buffer == null) throw new ArgumentNullException(nameof(buffer));
35+
if (bufferSize > buffer.Length) throw new ArgumentException($"argument '{nameof(bufferSize)}' is larger than '{nameof(buffer)}.Length'");
3536
#endregion
3637

3738
#region Filling members
3839
_stream = stream;
3940
_remoteEndpoint = remoteEndpoint;
4041
_buffer = buffer;
41-
_bufferPosition = bufferPosition;
42+
_bufferSize = bufferSize;
4243
#endregion
4344
}
4445

4546
#region Public methods
4647
public async Task Parse()
4748
{
48-
if (_isParsed)
49+
if (_hasParsed)
4950
return;
50-
_isParsed = true;
51-
Logger.Log("Parsing header");
51+
_hasParsed = true;
52+
Logger.Log($"[{_remoteEndpoint}] parsing header...");
5253

5354
#region Getting full header and do first check
55+
5456
await GetFullHeader();
55-
if (_bufferPosition < 2 || _buffer[_bufferPosition - 2] != '\r')
56-
throw new Exception("Header must end with CRLF");
57-
58-
string[] tokens = Encoding.ASCII.GetString(_buffer.Take(_bufferPosition - 2).ToArray()).Split(SEPARATOR);
57+
if (ProxyProtocolSocketPlugin.Config.Settings.LogLevel == LogLevel.Debug)
58+
Logger.Log($"[{_remoteEndpoint}] header content: {Convert.ToHexString(_buffer[.._bufferSize])}");
59+
var tokens = Encoding.ASCII.GetString(_buffer[..(_bufferSize - 2)]).Split(SEPARATOR);
5960
if (tokens.Length < 2)
6061
throw new Exception("Unable to read AddressFamily and protocol");
62+
6163
#endregion
6264

6365
#region Parse address family
64-
AddressFamily addressFamily;
65-
switch (tokens[1])
66+
67+
Logger.Log($"[{_remoteEndpoint}] parsing address family...");
68+
var addressFamily = tokens[1] switch
6669
{
67-
case "TCP4":
68-
addressFamily = AddressFamily.InterNetwork;
69-
break;
70-
71-
case "TCP6":
72-
addressFamily = AddressFamily.InterNetworkV6;
73-
break;
74-
75-
case "UNKNOWN":
76-
addressFamily = AddressFamily.Unspecified;
77-
break;
70+
"TCP4" => AddressFamily.InterNetwork,
71+
"TCP6" => AddressFamily.InterNetworkV6,
72+
"UNKNOWN" => AddressFamily.Unspecified,
73+
_ => throw new Exception("Invalid address family")
74+
};
7875

79-
default:
80-
throw new Exception("Invalid address family");
81-
}
8276
#endregion
8377

8478
#region Do second check
79+
8580
if (addressFamily == AddressFamily.Unspecified)
8681
{
8782
_protocolCommand = ProxyProtocolCommand.Local;
8883
_sourceEndpoint = _remoteEndpoint;
89-
_isParsed = true;
84+
_hasParsed = true;
9085
return;
9186
}
92-
else if (tokens.Length < 6)
93-
throw new Exception("Unable to read ipaddresses and ports");
87+
88+
if (tokens.Length < 6)
89+
throw new Exception("Impossible to read ip addresses and ports as the number of tokens is less than 6");
90+
9491
#endregion
9592

9693
#region Parse source and dest end point
97-
IPEndPoint sourceEP;
98-
IPEndPoint destEP;
94+
95+
Logger.Log($"[{_remoteEndpoint}] parsing endpoints...");
96+
IPEndPoint sourceEp;
97+
IPEndPoint destEp;
9998
try
10099
{
101100
// TODO: IP format validation
102-
IPAddress sourceAddr = IPAddress.Parse(tokens[2]);
103-
IPAddress destAddr = IPAddress.Parse(tokens[3]);
104-
int sourcePort = Convert.ToInt32(tokens[4]);
105-
int destPort = Convert.ToInt32(tokens[5]);
106-
sourceEP = new IPEndPoint(sourceAddr, sourcePort);
107-
destEP = new IPEndPoint(destAddr, destPort);
101+
var sourceAddr = IPAddress.Parse(tokens[2]);
102+
var destAddr = IPAddress.Parse(tokens[3]);
103+
var sourcePort = Convert.ToInt32(tokens[4]);
104+
var destPort = Convert.ToInt32(tokens[5]);
105+
sourceEp = new IPEndPoint(sourceAddr, sourcePort);
106+
destEp = new IPEndPoint(destAddr, destPort);
108107
}
109108
catch (Exception ex)
110109
{
111110
throw new Exception("Unable to parse ip addresses and ports", ex);
112111
}
112+
113113
#endregion
114114

115115
_addressFamily = addressFamily;
116116
_protocolCommand = ProxyProtocolCommand.Proxy;
117-
_sourceEndpoint = sourceEP;
118-
_destEndpoint = destEP;
117+
_sourceEndpoint = sourceEp;
118+
_destEndpoint = destEp;
119119
}
120120

121121
public async Task<IPEndPoint?> GetSourceEndpoint()
@@ -146,58 +146,47 @@ public async Task<ProxyProtocolCommand> GetCommand()
146146
#region Private methods
147147
private async Task GetFullHeader()
148148
{
149-
Logger.Log($"Getting full header");
150-
for (int i = 1; ; i++)
149+
Logger.Log($"[{_remoteEndpoint}] getting full header");
150+
for (var i = 7; i < MAX_HEADER_SIZE; i++) // Search after "PROXY" signature
151151
{
152-
if (await GetOneByteOfPosition(i) == '\n')
153-
break;
154-
if (i >= _buffer.Length)
155-
throw new Exception("Reaching the end of buffer without reaching the delimiter of version 1");
152+
if (await GetOneByteAtPosition(i) != DELIMITER[1])
153+
continue;
154+
if (await GetOneByteAtPosition(i - 1) != DELIMITER[0])
155+
throw new Exception("Header must end with CRLF");
156+
return;
156157
}
158+
throw new Exception("Failed to find any delimiter within the maximum header size of version 1");
157159
}
158160

159-
private async Task GetBytesToPosition(int position)
161+
private async Task GetBytesTillBufferSize(int size)
160162
{
161-
if (position <= _bufferPosition)
163+
if (size <= _bufferSize)
162164
return;
163-
await GetBytesFromStream(position - _bufferPosition);
165+
await GetBytesFromStream(size - _bufferSize);
164166
}
165167

166168
private async Task GetBytesFromStream(int length)
167169
{
168-
if ((_bufferPosition + length) > _buffer.Length)
170+
if (_bufferSize + length > _buffer.Length)
169171
throw new InternalBufferOverflowException();
170172

171173
while (length > 0)
172174
{
173175
if (!_stream.DataAvailable)
174176
throw new EndOfStreamException();
175177

176-
int count = await _stream.ReadAsync(_buffer, _bufferPosition, length);
178+
var count = await _stream.ReadAsync(_buffer.AsMemory(_bufferSize, length));
177179
length -= count;
178-
_bufferPosition += count;
180+
_bufferSize += count;
179181
}
180182
}
181183

182-
private async Task<byte> GetOneByteOfPosition(int position)
184+
private async Task<byte> GetOneByteAtPosition(int position)
183185
{
184-
await GetBytesToPosition(position);
185-
return _buffer[position - 1];
186-
}
187-
188-
private byte GetOneByteFromStream()
189-
{
190-
if ((_bufferPosition + 1) > _buffer.Length)
191-
throw new InternalBufferOverflowException();
192-
193-
int readState = _stream.ReadByte();
194-
if (readState < 0)
195-
throw new EndOfStreamException();
196-
197-
_buffer[_bufferPosition] = (byte)readState;
198-
_bufferPosition++;
199-
return (byte)readState;
186+
await GetBytesTillBufferSize(position + 1);
187+
return _buffer[position];
200188
}
189+
201190
#endregion
202191
}
203192
}

0 commit comments

Comments
 (0)