Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/you_get/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,7 @@ def url_save(
headers=None, timeout=None, **kwargs
):
tmp_headers = headers.copy() if headers is not None else {}
url = url[0] if type(url) is list and len(url) == 1 else url
# When a referer specified with param refer,
# the key must be 'Referer' for the hack here
if refer is not None:
Expand Down Expand Up @@ -807,9 +808,9 @@ def numreturn(a):
except socket.timeout:
pass
if not buffer:
if is_chunked and received_chunk == range_length:
if is_chunked and (received_chunk == range_length or range_length == float('inf')):
break
elif not is_chunked and received == file_size: # Download finished
elif not is_chunked and (received == file_size or range_length == float('inf')): # Download finished
break
# Unexpected termination. Retry request
tmp_headers['Range'] = 'bytes=' + str(received - chunk_start) + '-'
Expand All @@ -827,10 +828,11 @@ def numreturn(a):
received, os.path.getsize(temp_filepath), temp_filepath
)

if os.access(filepath, os.W_OK):
# on Windows rename could fail if destination filepath exists
os.remove(filepath)
os.rename(temp_filepath, filepath)
if temp_filepath != filepath:
if os.access(filepath, os.W_OK):
# on Windows rename could fail if destination filepath exists
os.remove(filepath)
os.rename(temp_filepath, filepath)


class SimpleProgressBar:
Expand Down
76 changes: 75 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,85 @@
#!/usr/bin/env python

import http.server
import socketserver
import tempfile
import threading
import unittest

from you_get.common import *


class TestCommon(unittest.TestCase):

def test_match1(self):
self.assertEqual(match1('http://youtu.be/1234567890A', r'youtu.be/([^/]+)'), '1234567890A')
self.assertEqual(match1('http://youtu.be/1234567890A', r'youtu.be/([^/]+)', r'youtu.(\w+)'), ['1234567890A', 'be'])


class TestDownloadUrlWithoutContentLength(unittest.TestCase):
def setUp(self):
self.server = ChunkedTestServer()
self.port = self.server.start()

def tearDown(self):
self.server.stop()

def test_server_response(self):
response = request.urlopen(f'http://localhost:{self.port}')
self.assertEqual(response.status, 200)
self.assertNotIn('Content-Length', response.headers)

expected_data = b'First chunk of data\nSecond chunk of data\nLast chunk of data'
self.assertEqual(response.read(), expected_data)

def test_url_save(self):
with tempfile.NamedTemporaryFile() as temp_file:
temp_path = temp_file.name

try:
url_save([f'http://localhost:{self.port}'], temp_path, None)

with open(temp_path, "r") as f:
expected_data = 'First chunk of data\nSecond chunk of data\nLast chunk of data'
self.assertEqual(f.read(), expected_data)
finally:
if os.path.exists(temp_path):
os.remove(temp_path)


class ChunkedHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
def do_GET(self):
self.send_response(200)
self.send_header('Transfer-Encoding', 'chunked')
self.end_headers()

# Send data in chunks
chunks = [b"First chunk of data\n",
b"Second chunk of data\n",
b"Last chunk of data"]

for chunk in chunks:
self.wfile.write(f"{len(chunk):x}\r\n".encode())
self.wfile.write(chunk)
self.wfile.write(b"\r\n")

# Write the final chunk (zero-length chunk to indicate the end)
self.wfile.write(b"0\r\n\r\n")


class ChunkedTestServer:
def __init__(self, port=0):
self.port = port
self.server = socketserver.TCPServer(('localhost', port), ChunkedHTTPRequestHandler)
self.server_thread = None

def start(self):
self.server_thread = threading.Thread(target=self.server.serve_forever)
self.server_thread.daemon = True
self.server_thread.start()
self.port = self.server.server_address[1]
return self.port

def stop(self):
self.server.shutdown()
self.server.server_close()
self.server_thread.join()