Skip to content

Commit 4011656

Browse files
committed
Refactor institute_list model
- Moved functions related to `InstituteList` into the class
1 parent 192ece2 commit 4011656

File tree

6 files changed

+66
-67
lines changed

6 files changed

+66
-67
lines changed

process_report/institute_list_models.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from typing import Annotated
22
import datetime
33
import functools
4+
import logging
45

56
import pydantic
67
import validators
78

89

10+
logger = logging.getLogger(__name__)
11+
logging.basicConfig(level=logging.INFO)
12+
13+
914
def parse_date(v: str) -> str:
1015
try:
1116
datetime.datetime.strptime(v, "%Y-%m")
@@ -70,3 +75,27 @@ def nonbillable_course_list(self) -> list[str]:
7075
if institute_info.courses_nonbillable:
7176
institute_list.append(institute_info.display_name)
7277
return institute_list
78+
79+
@functools.cached_property
80+
def domain_institute_mapping(self) -> dict[str, str]:
81+
"""Dict mapping web domains to institution display names"""
82+
institute_map = dict()
83+
for institute_info in self.root:
84+
for domain in institute_info.domains:
85+
institute_map[domain] = institute_info.display_name
86+
87+
return institute_map
88+
89+
def get_institution_from_pi(self, pi_email) -> str:
90+
institution_domain = pi_email.split("@")[-1]
91+
for i in range(institution_domain.count(".") + 1):
92+
if institution_name := self.domain_institute_mapping.get(
93+
institution_domain, ""
94+
):
95+
break
96+
institution_domain = institution_domain[institution_domain.find(".") + 1 :]
97+
98+
if institution_name == "":
99+
logger.warning(f"PI name {pi_email} does not match any institution!")
100+
101+
return institution_name

process_report/processors/add_institution_processor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@ def _add_institution(self):
2727
The list of mappings are defined in `institute_map.json`.
2828
"""
2929
institute_list = util.load_institute_list()
30-
institute_map = util.get_institute_mapping(institute_list)
3130
self.data = self.data.astype({invoice.INSTITUTION_FIELD: "str"})
3231
for i, row in self.data.iterrows():
3332
pi_name = row[invoice.PI_FIELD]
3433
if pandas.isna(pi_name):
3534
logger.info(f"Project {row[invoice.PROJECT_FIELD]} has no PI")
3635
else:
3736
self.data.at[i, invoice.INSTITUTION_FIELD] = (
38-
util.get_institution_from_pi(institute_map, pi_name)
37+
institute_list.get_institution_from_pi(pi_name)
3938
)
4039

4140
def _process(self):

process_report/processors/prepayment_processor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,10 @@ def _get_prepay_group_dict(self):
150150
def _add_prepay_info(self):
151151
"""Populate prepaid group name, institute, and MGHPCC managed field"""
152152
institute_list = util.load_institute_list()
153-
institute_map = util.get_institute_mapping(institute_list)
154153

155154
for group_name, group_dict in self.group_info_dict.items():
156-
group_institute = util.get_institution_from_pi(
157-
institute_map, group_dict[invoice.PREPAY_GROUP_CONTACT_FIELD]
155+
group_institute = institute_list.get_institution_from_pi(
156+
group_dict[invoice.PREPAY_GROUP_CONTACT_FIELD]
158157
)
159158

160159
# Prepay projects are identified by project name, not project - allocation name

process_report/tests/unit/processors/test_add_institution_processor.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

process_report/tests/unit/test_institute_list_validate.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from process_report.institute_list_validate import main
5+
from process_report.institute_list_models import InstituteList
56
from process_report.tests.base import BaseTestCaseWithTempDir
67

78

@@ -42,3 +43,36 @@ def test_invalid_institute_list(self):
4243
f.flush()
4344
with pytest.raises(SystemExit):
4445
main(["--github", str(test_file)])
46+
47+
def test_get_pi_institution(self):
48+
domain_map = {
49+
"harvard.edu": "Harvard University",
50+
"bu.edu": "Boston University",
51+
"bentley.edu": "Bentley",
52+
"mclean.harvard.edu": "McLean Hospital",
53+
"northeastern.edu": "Northeastern University",
54+
"childrens.harvard.edu": "Boston Children's Hospital",
55+
"meei.harvard.edu": "Massachusetts Eye & Ear",
56+
"dfci.harvard.edu": "Dana-Farber Cancer Institute",
57+
"bwh.harvard.edu": "Brigham and Women's Hospital",
58+
"bidmc.harvard.edu": "Beth Israel Deaconess Medical Center",
59+
}
60+
test_institute_list = InstituteList([])
61+
test_institute_list.domain_institute_mapping = domain_map
62+
63+
answers = {
64+
"[email protected]": "Boston University",
65+
"[email protected]": "McLean Hospital",
66+
"[email protected]": "Harvard University",
67+
"e@edu": "",
68+
"[email protected]": "Northeastern University",
69+
"[email protected]": "Harvard University",
70+
"[email protected]": "Boston Children's Hospital",
71+
"[email protected]": "Massachusetts Eye & Ear",
72+
73+
"[email protected]": "Brigham and Women's Hospital",
74+
"[email protected]": "Beth Israel Deaconess Medical Center",
75+
}
76+
77+
for pi_email, answer in answers.items():
78+
assert test_institute_list.get_institution_from_pi(pi_email) == answer

process_report/util.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
import os
22
import datetime
33
import yaml
4-
import logging
54
import functools
65

76
import boto3
87

98
from process_report.institute_list_models import InstituteList
109

11-
logger = logging.getLogger(__name__)
12-
logging.basicConfig(level=logging.INFO)
13-
1410

1511
DEFAULT_INSTITUTE_LIST = "process_report/institute_list.yaml"
1612

@@ -39,28 +35,6 @@ def load_institute_list() -> InstituteList:
3935
return InstituteList.model_validate(institute_list)
4036

4137

42-
def get_institute_mapping(institute_list: InstituteList):
43-
institute_map = dict()
44-
for institute_info in institute_list.root:
45-
for domain in institute_info.domains:
46-
institute_map[domain] = institute_info.display_name
47-
48-
return institute_map
49-
50-
51-
def get_institution_from_pi(institute_map, pi_uname):
52-
institution_domain = pi_uname.split("@")[-1]
53-
for i in range(institution_domain.count(".") + 1):
54-
if institution_name := institute_map.get(institution_domain, ""):
55-
break
56-
institution_domain = institution_domain[institution_domain.find(".") + 1 :]
57-
58-
if institution_name == "":
59-
logger.warning(f"PI name {pi_uname} does not match any institution!")
60-
61-
return institution_name
62-
63-
6438
def get_iso8601_time():
6539
return datetime.datetime.now().strftime("%Y%m%dT%H%M%SZ")
6640

0 commit comments

Comments
 (0)