Skip to content

Commit 081e241

Browse files
Upgrade redisvl and improve resource management (#26)
* update redisvl * update search index init * async best practices * add missing files * clean up * remove import * add lifespan manager for tests
1 parent 3b8ce52 commit 081e241

File tree

15 files changed

+374
-173
lines changed

15 files changed

+374
-173
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ product_metadata.json
88
product_vectors.json
99
data/
1010
!backend/data
11-
.env
11+
.env
12+
.python-version

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Much inspiration taken from [tiangelo/full-stack-fastapi-template](https://githu
5757
product.py # primary API logic lives here
5858
/db
5959
load.py # seeds Redis DB
60-
redis_helpers.py # redis util
60+
utils.py # redis util
6161
/schema
6262
# pydantic models for serialization/validation from API
6363
/tests

backend/poetry.lock

Lines changed: 255 additions & 28 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/productsearch/api/routes/product.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,21 @@
22

33
import numpy as np
44
from fastapi import APIRouter, Depends
5-
from redis.commands.search.document import Document
6-
from redis.commands.search.query import Query
75
from redisvl.index import AsyncSearchIndex
8-
from redisvl.query import FilterQuery, VectorQuery
9-
from redisvl.query.filter import FilterExpression, Tag
6+
from redisvl.query import CountQuery, FilterQuery, VectorQuery
7+
from redisvl.query.filter import Tag
108

119
from productsearch import config
1210
from productsearch.api.schema.product import (
1311
ProductSearchResponse,
1412
ProductVectorSearchResponse,
1513
SimilarityRequest,
1614
)
17-
from productsearch.db import redis_helpers
18-
15+
from productsearch.db import utils
1916

2017
router = APIRouter()
2118

2219

23-
def create_count_query(filter_expression: FilterExpression) -> Query:
24-
"""
25-
Create a "count" query where simply want to know how many records
26-
match a particular filter expression
27-
28-
Args:
29-
filter_expression (FilterExpression): The filter expression for the query.
30-
31-
Returns:
32-
Query: The Redis query object.
33-
"""
34-
return Query(str(filter_expression)).no_content().dialect(2)
35-
36-
3720
@router.get(
3821
"/",
3922
response_model=ProductSearchResponse,
@@ -45,7 +28,7 @@ async def get_products(
4528
skip: int = 0,
4629
gender: str = "",
4730
category: str = "",
48-
index: AsyncSearchIndex = Depends(redis_helpers.get_async_index),
31+
index: AsyncSearchIndex = Depends(utils.get_async_index),
4932
) -> ProductSearchResponse:
5033
"""Fetch and return products based on gender and category fields
5134
@@ -76,7 +59,7 @@ async def get_products(
7659
)
7760
async def find_products_by_image(
7861
similarity_request: SimilarityRequest,
79-
index: AsyncSearchIndex = Depends(redis_helpers.get_async_index),
62+
index: AsyncSearchIndex = Depends(utils.get_async_index),
8063
) -> ProductVectorSearchResponse:
8164
"""Fetch and return products based on image similarity
8265
@@ -116,14 +99,14 @@ async def find_products_by_image(
11699
return_fields=config.RETURN_FIELDS,
117100
filter_expression=filter_expression,
118101
)
119-
count_query = create_count_query(filter_expression)
102+
count_query = CountQuery(filter_expression)
120103

121104
# Execute search
122105
count, result_papers = await asyncio.gather(
123-
index.search(count_query), index.query(paper_similarity_query)
106+
index.query(count_query), index.query(paper_similarity_query)
124107
)
125108
# Get Paper records of those results
126-
return ProductVectorSearchResponse(total=count.total, products=result_papers)
109+
return ProductVectorSearchResponse(total=count, products=result_papers)
127110

128111

129112
@router.post(
@@ -134,7 +117,7 @@ async def find_products_by_image(
134117
)
135118
async def find_products_by_text(
136119
similarity_request: SimilarityRequest,
137-
index: AsyncSearchIndex = Depends(redis_helpers.get_async_index),
120+
index: AsyncSearchIndex = Depends(utils.get_async_index),
138121
) -> ProductVectorSearchResponse:
139122
"""Fetch and return products based on image similarity
140123
@@ -174,11 +157,11 @@ async def find_products_by_text(
174157
return_fields=config.RETURN_FIELDS,
175158
filter_expression=filter_expression,
176159
)
177-
count_query = create_count_query(filter_expression)
160+
count_query = CountQuery(filter_expression)
178161

179162
# Execute search
180163
count, result_papers = await asyncio.gather(
181-
index.search(count_query), index.query(paper_similarity_query)
164+
index.query(count_query), index.query(paper_similarity_query)
182165
)
183166
# Get Paper records of those results
184-
return ProductVectorSearchResponse(total=count.total, products=result_papers)
167+
return ProductVectorSearchResponse(total=count, products=result_papers)

backend/productsearch/db/load.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#!/usr/bin/env python3
22
import asyncio
33
import json
4-
import os
54
from typing import List
65

76
import numpy as np
87
import requests
9-
from productsearch import config
108
from redisvl.index import AsyncSearchIndex
119

10+
from productsearch import config
11+
from productsearch.db.utils import get_schema
12+
1213

1314
def read_from_s3():
1415
res = requests.get(config.S3_DATA_URL)
@@ -58,10 +59,8 @@ def preprocess(product: dict) -> dict:
5859

5960

6061
async def load_data():
61-
index = AsyncSearchIndex.from_yaml(
62-
os.path.join("./productsearch/db/schema", "products.yml")
63-
)
64-
index.connect(config.REDIS_URL)
62+
schema = get_schema()
63+
index = AsyncSearchIndex(schema, redis_url=config.REDIS_URL)
6564

6665
# Check if index exists
6766
if await index.exists() and len((await index.search("*")).docs) > 0:

backend/productsearch/db/redis_helpers.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

backend/productsearch/db/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import logging
2+
import os
3+
4+
from redisvl.index import AsyncSearchIndex
5+
from redisvl.schema import IndexSchema
6+
7+
from productsearch import config
8+
9+
logger = logging.getLogger(__name__)
10+
11+
# global search index
12+
_global_index = None
13+
14+
15+
def get_schema() -> IndexSchema:
16+
dir_path = os.path.dirname(os.path.realpath(__file__)) + "/schema"
17+
file_path = os.path.join(dir_path, "products.yml")
18+
return IndexSchema.from_yaml(file_path)
19+
20+
21+
async def get_async_index():
22+
global _global_index
23+
if not _global_index:
24+
_global_index = AsyncSearchIndex(get_schema(), redis_url=config.REDIS_URL)
25+
return _global_index

backend/productsearch/main.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from contextlib import asynccontextmanager
12
from pathlib import Path
23

34
import uvicorn
@@ -7,10 +8,22 @@
78

89
from productsearch import config
910
from productsearch.api.main import api_router
11+
from productsearch.db.utils import get_async_index
1012
from productsearch.spa import SinglePageApplication
1113

14+
15+
@asynccontextmanager
16+
async def lifespan(app: FastAPI):
17+
index = await get_async_index()
18+
async with index:
19+
yield
20+
21+
1222
app = FastAPI(
13-
title=config.PROJECT_NAME, docs_url=config.API_DOCS, openapi_url=config.OPENAPI_DOCS
23+
title=config.PROJECT_NAME,
24+
docs_url=config.API_DOCS,
25+
openapi_url=config.OPENAPI_DOCS,
26+
lifespan=lifespan,
1427
)
1528

1629
app.add_middleware(

backend/productsearch/tests/api/routes/test_product.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,29 @@
22
from httpx import AsyncClient
33

44
from productsearch.api.schema.product import SimilarityRequest
5-
from productsearch.main import app
65

76

8-
@pytest.fixture
9-
def gender(products):
10-
return products[0]["gender"]
7+
@pytest.fixture(scope="module")
8+
def gender(test_data):
9+
return test_data[0]["gender"]
1110

1211

13-
@pytest.fixture
14-
def category(products):
15-
return products[0]["category"]
12+
@pytest.fixture(scope="module")
13+
def category(test_data):
14+
return test_data[0]["category"]
1615

1716

18-
@pytest.fixture
17+
@pytest.fixture(scope="module")
1918
def bad_req_json():
2019
return {"not": "valid"}
2120

2221

23-
@pytest.fixture
24-
def product_req(gender, category, products):
22+
@pytest.fixture(scope="module")
23+
def product_req(gender, category, test_data):
2524
return SimilarityRequest(
2625
gender=gender,
2726
category=category,
28-
product_id=products[0]["product_id"],
27+
product_id=test_data[0]["product_id"],
2928
)
3029

3130

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,59 @@
1-
from asyncio import get_event_loop
2-
from typing import Generator
1+
import json
2+
import os
33

4+
import httpx
5+
import numpy as np
46
import pytest
57
import pytest_asyncio
8+
from asgi_lifespan import LifespanManager
69
from httpx import AsyncClient
7-
from redis.asyncio import Redis
10+
from redisvl.index import SearchIndex
811

912
from productsearch import config
13+
from productsearch.db.utils import get_schema
1014
from productsearch.main import app
11-
from productsearch.tests.utils.seed import seed_test_db
1215

1316

14-
@pytest.fixture(scope="module")
15-
def products():
16-
products = seed_test_db()
17-
return products
17+
@pytest.fixture(scope="session")
18+
def index():
19+
index = SearchIndex(schema=get_schema(), redis_url=config.REDIS_URL)
20+
index.create()
21+
yield index
22+
index.disconnect()
1823

1924

20-
@pytest.fixture
21-
async def client():
22-
client = await Redis.from_url(config.REDIS_URL)
23-
yield client
24-
try:
25-
await client.aclose()
26-
except RuntimeError as e:
27-
if "Event loop is closed" not in str(e):
28-
raise
25+
@pytest.fixture(scope="session", autouse=True)
26+
def test_data(index):
27+
cwd = os.getcwd()
28+
with open(f"{cwd}/productsearch/tests/test_vectors.json", "r") as f:
29+
products = json.load(f)
2930

31+
parsed_products = []
3032

31-
@pytest_asyncio.fixture(scope="session")
32-
async def async_client():
33+
# convert to bytes
34+
for product in products:
35+
parsed = {}
36+
parsed["text_vector"] = np.array(
37+
product["text_vector"], dtype=np.float32
38+
).tobytes()
39+
parsed["img_vector"] = np.array(
40+
product["img_vector"], dtype=np.float32
41+
).tobytes()
42+
parsed["category"] = product["product_metadata"]["master_category"]
43+
parsed["img_url"] = product["product_metadata"]["img_url"]
44+
parsed["name"] = product["product_metadata"]["name"]
45+
parsed["gender"] = product["product_metadata"]["gender"]
46+
parsed["product_id"] = product["product_id"]
47+
parsed_products.append(parsed)
48+
49+
_ = index.load(data=parsed_products, id_field="product_id")
50+
return parsed_products
3351

34-
async with AsyncClient(app=app, base_url="http://test/api/v1/") as client:
3552

36-
yield client
53+
@pytest_asyncio.fixture(scope="session")
54+
async def async_client():
55+
async with LifespanManager(app=app) as lifespan:
56+
async with AsyncClient(
57+
transport=httpx.ASGITransport(app=app), base_url="http://test/api/v1/" # type: ignore
58+
) as client:
59+
yield client

0 commit comments

Comments
 (0)