|
59 | 59 | import copy |
60 | 60 | import shutil |
61 | 61 | import os |
| 62 | +import re |
62 | 63 | from importlib.metadata import version |
63 | | -from functools import partial |
64 | 64 | from datetime import datetime |
65 | 65 | import itertools as it |
66 | 66 | from tqdm import tqdm |
67 | 67 | from collections import defaultdict |
68 | 68 | from anytree import RenderTree |
69 | | -from anytree.importer import DictImporter |
70 | 69 |
|
71 | 70 | import numpy as np |
72 | 71 | import pandas as pd |
@@ -193,6 +192,79 @@ def setup_video(video_file_path, save_vid, vid_output_path): |
193 | 192 | return cap, out_vid, cam_width, cam_height, fps |
194 | 193 |
|
195 | 194 |
|
| 195 | +def setup_model_class_mode(pose_model, mode, config_dict={}): |
| 196 | + ''' |
| 197 | + |
| 198 | + ''' |
| 199 | + |
| 200 | + if pose_model.upper() in ('HALPE_26', 'BODY_WITH_FEET'): |
| 201 | + model_name = 'HALPE_26' |
| 202 | + ModelClass = BodyWithFeet # 26 keypoints(halpe26) |
| 203 | + logging.info(f"Using HALPE_26 model (body and feet) for pose estimation in {mode} mode.") |
| 204 | + elif pose_model.upper() in ('COCO_133', 'WHOLE_BODY', 'WHOLE_BODY_WRIST'): |
| 205 | + model_name = 'COCO_133' |
| 206 | + ModelClass = Wholebody |
| 207 | + logging.info(f"Using COCO_133 model (body, feet, hands, and face) for pose estimation in {mode} mode.") |
| 208 | + elif pose_model.upper() in ('COCO_17', 'BODY'): |
| 209 | + model_name = 'COCO_17' |
| 210 | + ModelClass = Body |
| 211 | + logging.info(f"Using COCO_17 model (body) for pose estimation in {mode} mode.") |
| 212 | + elif pose_model.upper() =='HAND': |
| 213 | + model_name = 'HAND_21' |
| 214 | + ModelClass = Hand |
| 215 | + logging.info(f"Using HAND_21 model for pose estimation in {mode} mode.") |
| 216 | + elif pose_model.upper() =='FACE': |
| 217 | + model_name = 'FACE_106' |
| 218 | + logging.info(f"Using FACE_106 model for pose estimation in {mode} mode.") |
| 219 | + elif pose_model.upper() =='ANIMAL': |
| 220 | + model_name = 'ANIMAL2D_17' |
| 221 | + logging.info(f"Using ANIMAL2D_17 model for pose estimation in {mode} mode.") |
| 222 | + else: |
| 223 | + model_name = pose_model.upper() |
| 224 | + logging.info(f"Using model {model_name} for pose estimation in {mode} mode.") |
| 225 | + try: |
| 226 | + pose_model = eval(model_name) |
| 227 | + except: |
| 228 | + try: # from Config.toml |
| 229 | + from anytree.importer import DictImporter |
| 230 | + model_name = pose_model.upper() |
| 231 | + pose_model = DictImporter().import_(config_dict.get('pose').get(pose_model)) |
| 232 | + if pose_model.id == 'None': |
| 233 | + pose_model.id = None |
| 234 | + logging.info(f"Using model {model_name} for pose estimation.") |
| 235 | + except: |
| 236 | + raise NameError(f'{pose_model} not found in skeletons.py nor in Config.toml') |
| 237 | + |
| 238 | + # Manually select the models if mode is a dictionary rather than 'lightweight', 'balanced', or 'performance' |
| 239 | + if not mode in ['lightweight', 'balanced', 'performance'] or 'ModelClass' not in locals(): |
| 240 | + try: |
| 241 | + from functools import partial |
| 242 | + try: |
| 243 | + mode = ast.literal_eval(mode) |
| 244 | + except: # if within single quotes instead of double quotes when run with sports2d --mode """{dictionary}""" |
| 245 | + mode = mode.strip("'").replace('\n', '').replace(" ", "").replace(",", '", "').replace(":", '":"').replace("{", '{"').replace("}", '"}').replace('":"/',':/').replace('":"\\',':\\') |
| 246 | + mode = re.sub(r'"\[([^"]+)",\s?"([^"]+)\]"', r'[\1,\2]', mode) # changes "[640", "640]" to [640,640] |
| 247 | + mode = json.loads(mode) |
| 248 | + det_class = mode.get('det_class') |
| 249 | + det = mode.get('det_model') |
| 250 | + det_input_size = mode.get('det_input_size') |
| 251 | + pose_class = mode.get('pose_class') |
| 252 | + pose = mode.get('pose_model') |
| 253 | + pose_input_size = mode.get('pose_input_size') |
| 254 | + |
| 255 | + ModelClass = partial(Custom, |
| 256 | + det_class=det_class, det=det, det_input_size=det_input_size, |
| 257 | + pose_class=pose_class, pose=pose, pose_input_size=pose_input_size) |
| 258 | + logging.info(f"Using model {model_name} with the following custom parameters: {mode}.") |
| 259 | + |
| 260 | + except (json.JSONDecodeError, TypeError): |
| 261 | + logging.warning("Invalid mode. Must be 'lightweight', 'balanced', 'performance', or '''{dictionary}''' of parameters within triple quotes. Make sure input_sizes are within square brackets.") |
| 262 | + logging.warning('Using the default "balanced" mode.') |
| 263 | + mode = 'balanced' |
| 264 | + |
| 265 | + return pose_model, ModelClass, mode |
| 266 | + |
| 267 | + |
196 | 268 | def setup_backend_device(backend='auto', device='auto'): |
197 | 269 | ''' |
198 | 270 | Set up the backend and device for the pose tracker based on the availability of hardware acceleration. |
@@ -1427,74 +1499,15 @@ def process_fun(config_dict, video_file, time_range, frame_rate, result_dir): |
1427 | 1499 | cv2.namedWindow(f'{video_file} Sports2D', cv2.WINDOW_NORMAL + cv2.WINDOW_KEEPRATIO) |
1428 | 1500 | cv2.setWindowProperty(f'{video_file} Sports2D', cv2.WND_PROP_ASPECT_RATIO, cv2.WINDOW_FULLSCREEN) |
1429 | 1501 |
|
| 1502 | + |
1430 | 1503 | # Select the appropriate model based on the model_type |
1431 | 1504 | logging.info('\nEstimating pose...') |
1432 | | - if pose_model.upper() in ('HALPE_26', 'BODY_WITH_FEET'): |
1433 | | - model_name = 'HALPE_26' |
1434 | | - ModelClass = BodyWithFeet # 26 keypoints(halpe26) |
1435 | | - logging.info(f"Using HALPE_26 model (body and feet) for pose estimation.") |
1436 | | - elif pose_model.upper() in ('COCO_133', 'WHOLE_BODY', 'WHOLE_BODY_WRIST'): |
1437 | | - model_name = 'COCO_133' |
1438 | | - ModelClass = Wholebody |
1439 | | - logging.info(f"Using COCO_133 model (body, feet, hands, and face) for pose estimation.") |
1440 | | - elif pose_model.upper() in ('COCO_17', 'BODY'): |
1441 | | - model_name = 'COCO_17' |
1442 | | - ModelClass = Body |
1443 | | - logging.info(f"Using COCO_17 model (body) for pose estimation.") |
1444 | | - elif pose_model.upper() =='HAND': |
1445 | | - model_name = 'HAND_21' |
1446 | | - ModelClass = Hand |
1447 | | - logging.info(f"Using HAND_21 model for pose estimation.") |
1448 | | - elif pose_model.upper() =='FACE': |
1449 | | - model_name = 'FACE_106' |
1450 | | - logging.info(f"Using FACE_106 model for pose estimation.") |
1451 | | - elif pose_model.upper() =='ANIMAL': |
1452 | | - model_name = 'ANIMAL2D_17' |
1453 | | - logging.info(f"Using ANIMAL2D_17 model for pose estimation.") |
1454 | | - else: |
1455 | | - model_name = pose_model.upper() |
1456 | | - logging.info(f"Using model {model_name} for pose estimation.") |
1457 | 1505 | pose_model_name = pose_model |
1458 | | - try: |
1459 | | - pose_model = eval(model_name) |
1460 | | - except: |
1461 | | - try: # from Config.toml |
1462 | | - pose_model = DictImporter().import_(config_dict.get('pose').get(pose_model)) |
1463 | | - if pose_model.id == 'None': |
1464 | | - pose_model.id = None |
1465 | | - except: |
1466 | | - raise NameError(f'{pose_model} not found in skeletons.py nor in Config.toml') |
1467 | | - |
| 1506 | + pose_model, ModelClass, mode = setup_model_class_mode(pose_model, mode, config_dict) |
| 1507 | + |
1468 | 1508 | # Select device and backend |
1469 | 1509 | backend, device = setup_backend_device(backend=backend, device=device) |
1470 | 1510 |
|
1471 | | - # Manually select the models if mode is a dictionary rather than 'lightweight', 'balanced', or 'performance' |
1472 | | - if not mode in ['lightweight', 'balanced', 'performance']: |
1473 | | - try: |
1474 | | - try: |
1475 | | - mode = ast.literal_eval(mode) |
1476 | | - except: # if within single quotes instead of double quotes when run with sports2d --mode """{dictionary}""" |
1477 | | - mode = mode.strip("'").replace('\n', '').replace(" ", "").replace(",", '", "').replace(":", '":"').replace("{", '{"').replace("}", '"}').replace('":"/',':/').replace('":"\\',':\\') |
1478 | | - mode = re.sub(r'"\[([^"]+)",\s?"([^"]+)\]"', r'[\1,\2]', mode) # changes "[640", "640]" to [640,640] |
1479 | | - mode = json.loads(mode) |
1480 | | - det_class = mode.get('det_class') |
1481 | | - det = mode.get('det_model') |
1482 | | - det_input_size = mode.get('det_input_size') |
1483 | | - pose_class = mode.get('pose_class') |
1484 | | - pose = mode.get('pose_model') |
1485 | | - pose_input_size = mode.get('pose_input_size') |
1486 | | - |
1487 | | - ModelClass = partial(Custom, |
1488 | | - det_class=det_class, det=det, det_input_size=det_input_size, |
1489 | | - pose_class=pose_class, pose=pose, pose_input_size=pose_input_size, |
1490 | | - backend=backend, device=device) |
1491 | | - |
1492 | | - except (json.JSONDecodeError, TypeError): |
1493 | | - logging.warning("Invalid mode. Must be 'lightweight', 'balanced', 'performance', or '''{dictionary}''' of parameters within triple quotes. Make sure input_sizes are within square brackets.") |
1494 | | - logging.warning('Using the default "balanced" mode.') |
1495 | | - mode = 'balanced' |
1496 | | - |
1497 | | - |
1498 | 1511 | # Skip pose estimation or set it up: |
1499 | 1512 | if load_trc_px: |
1500 | 1513 | if not '_px' in str(load_trc_px): |
@@ -1722,6 +1735,7 @@ def process_fun(config_dict, video_file, time_range, frame_rate, result_dir): |
1722 | 1735 | all_frames_scores_homog = all_frames_scores_homog[...,new_keypoints_ids] |
1723 | 1736 |
|
1724 | 1737 | frame_range = [0,frame_count] if video_file == 'webcam' else frame_range |
| 1738 | + print(frame_range) |
1725 | 1739 | all_frames_time = pd.Series(np.linspace(frame_range[0]/fps, frame_range[1]/fps, frame_count-frame_range[0]), name='time') |
1726 | 1740 | if load_trc_px: |
1727 | 1741 | selected_persons = [0] |
|
0 commit comments