Skip to content

Commit cdd80ec

Browse files
committed
Created a setup_model_class_mode function
1 parent ba5553a commit cdd80ec

File tree

2 files changed

+116
-65
lines changed

2 files changed

+116
-65
lines changed

Sports2D/Sports2D.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,44 @@
151151
'deepsort_params': """{'max_age':30, 'n_init':3, 'nms_max_overlap':0.8, 'max_cosine_distance':0.3, 'nn_budget':200, 'max_iou_distance':0.8, 'embedder_gpu': True, 'embedder':'torchreid'}""",
152152
'keypoint_likelihood_threshold': 0.3,
153153
'average_likelihood_threshold': 0.5,
154-
'keypoint_number_threshold': 0.3
154+
'keypoint_number_threshold': 0.3,
155+
'CUSTOM': { 'name': 'Hip',
156+
'id': 19,
157+
'children': [{'name': 'RHip',
158+
'id': 12,
159+
'children': [{'name': 'RKnee',
160+
'id': 14,
161+
'children': [{'name': 'RAnkle',
162+
'id': 16,
163+
'children': [{'name': 'RBigToe',
164+
'id': 21,
165+
'children': [{'name': 'RSmallToe', 'id': 23}]},
166+
{'name': 'RHeel', 'id': 25}]}]}]},
167+
{'name': 'LHip',
168+
'id': 11,
169+
'children': [{'name': 'LKnee',
170+
'id': 13,
171+
'children': [{'name': 'LAnkle',
172+
'id': 15,
173+
'children': [{'name': 'LBigToe',
174+
'id': 20,
175+
'children': [{'name': 'LSmallToe', 'id': 22}]},
176+
{'name': 'LHeel', 'id': 24}]}]}]},
177+
{'name': 'Neck',
178+
'id': 18,
179+
'children': [{'name': 'Head',
180+
'id': 17,
181+
'children': [{'name': 'Nose', 'id': 0}]},
182+
{'name': 'RShoulder',
183+
'id': 6,
184+
'children': [{'name': 'RElbow',
185+
'id': 8,
186+
'children': [{'name': 'RWrist', 'id': 10}]}]},
187+
{'name': 'LShoulder',
188+
'id': 5,
189+
'children': [{'name': 'LElbow',
190+
'id': 7,
191+
'children': [{'name': 'LWrist', 'id': 9}]}]}]}]}
155192
},
156193
'px_to_meters_conversion': {
157194
'to_meters': True,

Sports2D/process.py

Lines changed: 78 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,13 @@
5959
import copy
6060
import shutil
6161
import os
62+
import re
6263
from importlib.metadata import version
63-
from functools import partial
6464
from datetime import datetime
6565
import itertools as it
6666
from tqdm import tqdm
6767
from collections import defaultdict
6868
from anytree import RenderTree
69-
from anytree.importer import DictImporter
7069

7170
import numpy as np
7271
import pandas as pd
@@ -193,6 +192,79 @@ def setup_video(video_file_path, save_vid, vid_output_path):
193192
return cap, out_vid, cam_width, cam_height, fps
194193

195194

195+
def setup_model_class_mode(pose_model, mode, config_dict={}):
196+
'''
197+
198+
'''
199+
200+
if pose_model.upper() in ('HALPE_26', 'BODY_WITH_FEET'):
201+
model_name = 'HALPE_26'
202+
ModelClass = BodyWithFeet # 26 keypoints(halpe26)
203+
logging.info(f"Using HALPE_26 model (body and feet) for pose estimation in {mode} mode.")
204+
elif pose_model.upper() in ('COCO_133', 'WHOLE_BODY', 'WHOLE_BODY_WRIST'):
205+
model_name = 'COCO_133'
206+
ModelClass = Wholebody
207+
logging.info(f"Using COCO_133 model (body, feet, hands, and face) for pose estimation in {mode} mode.")
208+
elif pose_model.upper() in ('COCO_17', 'BODY'):
209+
model_name = 'COCO_17'
210+
ModelClass = Body
211+
logging.info(f"Using COCO_17 model (body) for pose estimation in {mode} mode.")
212+
elif pose_model.upper() =='HAND':
213+
model_name = 'HAND_21'
214+
ModelClass = Hand
215+
logging.info(f"Using HAND_21 model for pose estimation in {mode} mode.")
216+
elif pose_model.upper() =='FACE':
217+
model_name = 'FACE_106'
218+
logging.info(f"Using FACE_106 model for pose estimation in {mode} mode.")
219+
elif pose_model.upper() =='ANIMAL':
220+
model_name = 'ANIMAL2D_17'
221+
logging.info(f"Using ANIMAL2D_17 model for pose estimation in {mode} mode.")
222+
else:
223+
model_name = pose_model.upper()
224+
logging.info(f"Using model {model_name} for pose estimation in {mode} mode.")
225+
try:
226+
pose_model = eval(model_name)
227+
except:
228+
try: # from Config.toml
229+
from anytree.importer import DictImporter
230+
model_name = pose_model.upper()
231+
pose_model = DictImporter().import_(config_dict.get('pose').get(pose_model))
232+
if pose_model.id == 'None':
233+
pose_model.id = None
234+
logging.info(f"Using model {model_name} for pose estimation.")
235+
except:
236+
raise NameError(f'{pose_model} not found in skeletons.py nor in Config.toml')
237+
238+
# Manually select the models if mode is a dictionary rather than 'lightweight', 'balanced', or 'performance'
239+
if not mode in ['lightweight', 'balanced', 'performance'] or 'ModelClass' not in locals():
240+
try:
241+
from functools import partial
242+
try:
243+
mode = ast.literal_eval(mode)
244+
except: # if within single quotes instead of double quotes when run with sports2d --mode """{dictionary}"""
245+
mode = mode.strip("'").replace('\n', '').replace(" ", "").replace(",", '", "').replace(":", '":"').replace("{", '{"').replace("}", '"}').replace('":"/',':/').replace('":"\\',':\\')
246+
mode = re.sub(r'"\[([^"]+)",\s?"([^"]+)\]"', r'[\1,\2]', mode) # changes "[640", "640]" to [640,640]
247+
mode = json.loads(mode)
248+
det_class = mode.get('det_class')
249+
det = mode.get('det_model')
250+
det_input_size = mode.get('det_input_size')
251+
pose_class = mode.get('pose_class')
252+
pose = mode.get('pose_model')
253+
pose_input_size = mode.get('pose_input_size')
254+
255+
ModelClass = partial(Custom,
256+
det_class=det_class, det=det, det_input_size=det_input_size,
257+
pose_class=pose_class, pose=pose, pose_input_size=pose_input_size)
258+
logging.info(f"Using model {model_name} with the following custom parameters: {mode}.")
259+
260+
except (json.JSONDecodeError, TypeError):
261+
logging.warning("Invalid mode. Must be 'lightweight', 'balanced', 'performance', or '''{dictionary}''' of parameters within triple quotes. Make sure input_sizes are within square brackets.")
262+
logging.warning('Using the default "balanced" mode.')
263+
mode = 'balanced'
264+
265+
return pose_model, ModelClass, mode
266+
267+
196268
def setup_backend_device(backend='auto', device='auto'):
197269
'''
198270
Set up the backend and device for the pose tracker based on the availability of hardware acceleration.
@@ -1427,74 +1499,15 @@ def process_fun(config_dict, video_file, time_range, frame_rate, result_dir):
14271499
cv2.namedWindow(f'{video_file} Sports2D', cv2.WINDOW_NORMAL + cv2.WINDOW_KEEPRATIO)
14281500
cv2.setWindowProperty(f'{video_file} Sports2D', cv2.WND_PROP_ASPECT_RATIO, cv2.WINDOW_FULLSCREEN)
14291501

1502+
14301503
# Select the appropriate model based on the model_type
14311504
logging.info('\nEstimating pose...')
1432-
if pose_model.upper() in ('HALPE_26', 'BODY_WITH_FEET'):
1433-
model_name = 'HALPE_26'
1434-
ModelClass = BodyWithFeet # 26 keypoints(halpe26)
1435-
logging.info(f"Using HALPE_26 model (body and feet) for pose estimation.")
1436-
elif pose_model.upper() in ('COCO_133', 'WHOLE_BODY', 'WHOLE_BODY_WRIST'):
1437-
model_name = 'COCO_133'
1438-
ModelClass = Wholebody
1439-
logging.info(f"Using COCO_133 model (body, feet, hands, and face) for pose estimation.")
1440-
elif pose_model.upper() in ('COCO_17', 'BODY'):
1441-
model_name = 'COCO_17'
1442-
ModelClass = Body
1443-
logging.info(f"Using COCO_17 model (body) for pose estimation.")
1444-
elif pose_model.upper() =='HAND':
1445-
model_name = 'HAND_21'
1446-
ModelClass = Hand
1447-
logging.info(f"Using HAND_21 model for pose estimation.")
1448-
elif pose_model.upper() =='FACE':
1449-
model_name = 'FACE_106'
1450-
logging.info(f"Using FACE_106 model for pose estimation.")
1451-
elif pose_model.upper() =='ANIMAL':
1452-
model_name = 'ANIMAL2D_17'
1453-
logging.info(f"Using ANIMAL2D_17 model for pose estimation.")
1454-
else:
1455-
model_name = pose_model.upper()
1456-
logging.info(f"Using model {model_name} for pose estimation.")
14571505
pose_model_name = pose_model
1458-
try:
1459-
pose_model = eval(model_name)
1460-
except:
1461-
try: # from Config.toml
1462-
pose_model = DictImporter().import_(config_dict.get('pose').get(pose_model))
1463-
if pose_model.id == 'None':
1464-
pose_model.id = None
1465-
except:
1466-
raise NameError(f'{pose_model} not found in skeletons.py nor in Config.toml')
1467-
1506+
pose_model, ModelClass, mode = setup_model_class_mode(pose_model, mode, config_dict)
1507+
14681508
# Select device and backend
14691509
backend, device = setup_backend_device(backend=backend, device=device)
14701510

1471-
# Manually select the models if mode is a dictionary rather than 'lightweight', 'balanced', or 'performance'
1472-
if not mode in ['lightweight', 'balanced', 'performance']:
1473-
try:
1474-
try:
1475-
mode = ast.literal_eval(mode)
1476-
except: # if within single quotes instead of double quotes when run with sports2d --mode """{dictionary}"""
1477-
mode = mode.strip("'").replace('\n', '').replace(" ", "").replace(",", '", "').replace(":", '":"').replace("{", '{"').replace("}", '"}').replace('":"/',':/').replace('":"\\',':\\')
1478-
mode = re.sub(r'"\[([^"]+)",\s?"([^"]+)\]"', r'[\1,\2]', mode) # changes "[640", "640]" to [640,640]
1479-
mode = json.loads(mode)
1480-
det_class = mode.get('det_class')
1481-
det = mode.get('det_model')
1482-
det_input_size = mode.get('det_input_size')
1483-
pose_class = mode.get('pose_class')
1484-
pose = mode.get('pose_model')
1485-
pose_input_size = mode.get('pose_input_size')
1486-
1487-
ModelClass = partial(Custom,
1488-
det_class=det_class, det=det, det_input_size=det_input_size,
1489-
pose_class=pose_class, pose=pose, pose_input_size=pose_input_size,
1490-
backend=backend, device=device)
1491-
1492-
except (json.JSONDecodeError, TypeError):
1493-
logging.warning("Invalid mode. Must be 'lightweight', 'balanced', 'performance', or '''{dictionary}''' of parameters within triple quotes. Make sure input_sizes are within square brackets.")
1494-
logging.warning('Using the default "balanced" mode.')
1495-
mode = 'balanced'
1496-
1497-
14981511
# Skip pose estimation or set it up:
14991512
if load_trc_px:
15001513
if not '_px' in str(load_trc_px):
@@ -1722,6 +1735,7 @@ def process_fun(config_dict, video_file, time_range, frame_rate, result_dir):
17221735
all_frames_scores_homog = all_frames_scores_homog[...,new_keypoints_ids]
17231736

17241737
frame_range = [0,frame_count] if video_file == 'webcam' else frame_range
1738+
print(frame_range)
17251739
all_frames_time = pd.Series(np.linspace(frame_range[0]/fps, frame_range[1]/fps, frame_count-frame_range[0]), name='time')
17261740
if load_trc_px:
17271741
selected_persons = [0]

0 commit comments

Comments
 (0)