diff --git a/resspect/feature_extractors/light_curve.py b/resspect/feature_extractors/light_curve.py index 431f9f77..d947cf05 100644 --- a/resspect/feature_extractors/light_curve.py +++ b/resspect/feature_extractors/light_curve.py @@ -18,6 +18,7 @@ import logging from typing import Tuple import warnings +import os import numpy as np import pandas as pd @@ -30,7 +31,6 @@ from resspect.lightcurves_utils import load_plasticc_photometry_df from resspect.lightcurves_utils import get_snpcc_sntype - warnings.filterwarnings("ignore", category=RuntimeWarning) logging.basicConfig(level=logging.INFO) @@ -102,7 +102,15 @@ class LightCurve: """ - def __init__(self): + def __init__(self, lc=None): + if lc is None: + self._non_copy_constructor() + else: + if not isinstance(lc, LightCurve): + raise RuntimeError("argument is not a LightCurve object: "+type(lc)) + self._copy_constructor(lc) + + def _non_copy_constructor(self): self.queryable = None self.features = [] #self.features_names = ['p1', 'p2', 'p3', 'time_shift', 'max_flux'] @@ -121,6 +129,55 @@ def __init__(self): self.sncode = 0 self.sntype = ' ' + def _copy_constructor(self, lc): + self.queryable = lc.queryable + self.features = lc.features + self.dataset_name = lc.dataset_name + self.exp_time = lc.exp_time + self.filters = lc.filters + self.full_photometry = lc.full_photometry + self.id = lc.id + self.id_name = lc.id_name + self.last_mag = lc.last_mag + self.photometry = lc.photometry + self.redshift = lc.redshift + self.sample = ls.sample + self.sim_peakmag = lc.sim_peakmag + self.sim_pkmjd = lc.sim_pkmjd + self.sncode = lc.sncode + self.sntype = lc.sntype + + @staticmethod + def from_file(filename: str) -> list: + light_curves = [] + with open(filename, 'r') as f: + for ff in f.readlines(): + survey, path = ff.split() + if not os.path.exists(path): + raise FileNotFoundError('File Not found: '+path) + survey = survey.strip().upper() + lc = LightCurve() + if survey == "SNPCC": + lc.load_snpcc_lc(path) + elif survey == "PLASTICC": + lc.load_plasticc_lc(path) + else: + raise NameError("survey argument not recognized: "+survey) + light_curves.append(lc) + return light_curves + + @staticmethod + def compute_feature(feature_type: str, allLC: list) -> list: + feature_list = [] + if feature_type == 'bazin': + extractor = BazinFeatureExtractor + elif feature_extractor == 'bump': + extractor = BumpFeatureExtractor + for lc in allLC: + ex = extractor(lc) + ex.fit_all() + feature_list.appen(ex) + def _get_snpcc_photometry_raw_and_header( self, lc_data: np.ndarray, sntype_test_value: str = "-9") -> Tuple[np.ndarray, list]: