Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions columnflow/tasks/cmsGhent/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ def run(self):
for var_name in variable_tuple
]
category_insts = [self.config_inst.get_category(c) for c in self.branch_data.categories]
category_insts_leafs = [c.get_leaf_categories() or [c] for c in category_insts]
process_inst = self.config_inst.get_process(self.branch_data.process)
sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)]

# histogram data for process
process_hist = 0
process_hists = {c.name: 0 for c in category_insts}

with self.publish_step(f"plotting {self.branch_data.variable} for {process_inst.name}"):
for dataset, inp in self.input().items():
Expand All @@ -72,48 +73,54 @@ def run(self):

# work on a copy
h = h_in.copy()

# axis selections
h = h[{
"process": [
hist.loc(p.id)
for p in sub_process_insts
if p.id in h.axes["process"]
],
"category": [
hist.loc(c.id)
for c in category_insts
if c.id in h.axes["category"]
],
"shift": [
hist.loc(s.id)
for s in plot_shifts
if s.id in h.axes["shift"]
],
}]

# axis reductions
h = h[{"process": sum}]
# axis selections
for c, lcs in zip(category_insts, category_insts_leafs):
hc = h[{
"category": [
hist.loc(c.id)
for c in lcs
if c.id in h.axes["category"]
],
}]

# add the histogram
process_hist = h + process_hist
# axis reductions
hc = hc[{"category": sum}]

# add the histsogram
process_hists[c.name] = hc + process_hists[c.name]

# there should be hists to plot
if not process_hist:
if not all(process_hists.values()):
raise Exception(
"no histograms found to plot; possible reasons:\n" +
" - requested variable requires columns that were missing during histogramming\n" +
" - selected --processes did not match any value on the process axis of the input histogram",
)

process_hists = OrderedDict(
(cat.name, h[{"category": hist.loc(cat.id)}])
for cat in category_insts
)
# update histograms using custom hooks
hists = self.invoke_hist_hooks(process_hists)

for cat in hists:
if "process" in hists[cat].axes.name:
hists[cat] = hists[cat][{"process": sum}]

# call the plot function
fig, _ = self.call_plot_func(
self.plot_function,
hists=process_hists,
hists=hists,
config_inst=self.config_inst,
category_inst=process_inst.copy_shallow(),
variable_insts=[var_inst.copy_shallow() for var_inst in variable_insts],
Expand Down