From 8eb21889635535da3b525ff1f32b059256025226 Mon Sep 17 00:00:00 2001
From: Yannick Ulrich <yannick.ulrich@psi.ch>
Date: Sun, 9 Feb 2020 12:06:15 +0100
Subject: [PATCH] 14: Generic combineplot function in re #17

---
 pymule/__init__.py   |  3 ++-
 pymule/errortools.py | 23 +++++++++++++++++++----
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/pymule/__init__.py b/pymule/__init__.py
index 01195ed..add9279 100644
--- a/pymule/__init__.py
+++ b/pymule/__init__.py
@@ -3,7 +3,8 @@ import matplotlib.pyplot as plt
 
 from vegas import importvegas
 from errortools import mergenumbers, plusnumbers, dividenumbers, timesnumbers,\
-                       mergeplots, addplots, combineNplots, scaleplot,        \
+                       mergeplots, addplots, divideplots, scaleplot,          \
+                       combineplots, combineNplots,                           \
                        integratehistogram, mergebins
 from loader import importreg, pattern, setup, sigma,                          \
                    mergeset, mergeseeds, mergefks,                            \
diff --git a/pymule/errortools.py b/pymule/errortools.py
index acfb162..36f6434 100644
--- a/pymule/errortools.py
+++ b/pymule/errortools.py
@@ -82,7 +82,7 @@ def mergeplots(ps, returnchi=False):
         return out
 
 
-def addplots(a, b, sa=1., sb=1.):
+def combineplots(a, b, yfunc, efunc):
     # There must be a better way of doing this
     maskA = [False]*len(a)
     maskB = [False]*len(b)
@@ -93,12 +93,27 @@ def addplots(a, b, sa=1., sb=1.):
                 maskB[ib] = True
 
     x = a[maskA, 0]
-    y = sa*a[maskA, 1] + sb*b[maskB, 1]
-    e = np.sqrt(sa**2 * a[maskA, 2]**2 + sb**2 * b[maskB, 2]**2)
-
+    y = yfunc(a[maskA, 1], b[maskB, 1])
+    e = efunc(a[maskA, 1], a[maskA, 2], b[maskB, 1], b[maskB, 2])
     return np.column_stack((x, y, e))
 
 
+def addplots(a, b, sa=1., sb=1.):
+    return combineplots(
+        a, b,
+        lambda y1, y2: y1 + y2,
+        lambda y1, e1, y2, e2: np.sqrt(e1**2 + e2**2)
+    )
+
+
+def divideplots(a, b):
+    return combineplots(
+        a, b,
+        lambda y1, y2: y1 / y2,
+        lambda y1, e1, y2, e2: np.sqrt(e2**2 * y1**2 / y2**4 + e1**2 / y2**2)
+    )
+
+
 def combineNplots(func, plots):
     accum = plots[0]
     for plot in plots[1:]:
-- 
GitLab