
egglog in Python

In this talk I will try to cover how egglog in Python can:

  • Enable decentralized collaboration in the Python data science ecosystem.

  • Provide a faithful authoring environment for egglog.


Saul Shanabrook @ EGRAPHS Workshop - PLDI ’23

Open Source Data Science Ecosystem in Python

The term “ecosystem” is often used to describe the modern open-source scientific software. In biology, the term “ecosystem” is defined as a biological community of interacting organisms and their physical environment. Modern open-source scientific software development occurs in a similarly interconnected and interoperable fashion.

from Jupyter Meets the Earth: Ecosystem

Aims

  • How can the tools we build foster greater resiliency, collaboration, and interdependence in this ecosystem?

  • How can they help it stay flexible enough to adapt to the changing computational landscape to empower users and authors?

What role could egglog play?

  • Bring the PL community closer to this space, providing theoretical frameworks for thinking about composition and language.

  • A constrained type system could support decentralized interoperability and composition between data science libraries.

from __future__ import annotations
from egglog import *

Other Python EGraph Libraries

TODO: Put this first; say it’s for library authors

Semantics of Python and egglog

  • Started with snake-egg

  • Didn’t want to reinvent the wheel; wanted to stay abreast of recent developments and research

  • Second piece that interests me

    • Unlike egg, egglog has some built-in sorts, and user-defined sorts can be built on top of them

    • No host language conditions or data structures

    • Helps with optimization; more constrained

    • -> De-centers algorithms based on values and moves toward algorithms based on types. Everything becomes an interface.

    • Social dynamics: the goal is the ability to innovate and experiment while still supporting existing use cases

      • A new dataframe library comes out, supporting custom hardware. How do we use it without rewriting code?

      • How do we keep a healthy ecosystem within these tools? Power

      • If it’s too hard, that encourages centralized, monopolistic actors to step in and provide one-stop-shop solutions for users.

    • This is an active problem in the community, e.g. efforts to standardize on interop.

    • Before getting too abstract, let’s go to an example

A story about Arrays

  • This is one path through a huge maze of use cases.

  • Does not represent one killer example, but is an area I am familiar with from my previous work

1. Someone makes an NDArray library…

ndarray_mod = Module()
...
@ndarray_mod.class_
class Value(Expr):
    def __init__(self, v: i64Like) -> None:
        ...

    def __mul__(self, other: Value) -> Value:
        ...

    def __add__(self, other: Value) -> Value:
        ...


i, j = vars_("i j", i64)
ndarray_mod.register(
    rewrite(Value(i) * Value(j)).to(Value(i * j)),
    rewrite(Value(i) + Value(j)).to(Value(i + j)),
)


@ndarray_mod.class_
class Values(Expr):
    def __init__(self, v: Vec[Value]) -> None:
        ...

    def __getitem__(self, idx: Value) -> Value:
        ...

    def length(self) -> Value:
        ...

    def concat(self, other: Values) -> Values:
        ...


@ndarray_mod.register
def _values(vs: Vec[Value], other: Vec[Value]):
    yield rewrite(Values(vs)[Value(i)]).to(vs[i])
    yield rewrite(Values(vs).length()).to(Value(vs.length()))
    yield rewrite(Values(vs).concat(Values(other))).to(Values(vs.append(other)))


@ndarray_mod.class_
class NDArray(Expr):
    def __getitem__(self, idx: Values) -> Value:
        ...

    def shape(self) -> Values:
        ...


@ndarray_mod.function
def arange(n: Value) -> NDArray:
    ...

  • Basic

  • One function, arange; you can get an array’s shape and index into it

  • Very different from existing paradigms in Python: inheritance, multiple dispatch, dunder protocols.

    • Entirely open protocol.

    • Anyone else could define ways to create arrays

    • Really, it is about the mathematical definition. This is based on:

Restifo Mullin, Lenore Marie, “A mathematics of arrays” (1988). Electrical Engineering and Computer Science - Dissertations. 249.

@ndarray_mod.register
def _(n: Value, idx: Values, a: NDArray):
    yield rewrite(arange(n).shape()).to(Values(Vec(n)))
    yield rewrite(arange(n)[idx]).to(idx[Value(0)])

  • Rules to compute the shape of arange and to index into it.

egraph = EGraph([ndarray_mod])
ten = egraph.let("ten", arange(Value(10)))
ten_shape = ten.shape()
egraph.register(ten_shape)

egraph.run(20)
egraph.display()
egraph.extract(ten_shape)
[Figure: e-graph rendered by egraph.display()]
Values(Vec.empty().push(Value(10)))
ten_indexed = ten[Values(Vec(Value(7)))]
egraph.register(ten_indexed)

egraph.run(20)

egraph.display()
egraph.extract(ten_indexed)
[Figure: e-graph rendered by egraph.display()]
Value(7)
  • Any user can try it now
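To make the rules above concrete without an e-graph, here is a plain-Python sketch that applies the same rewrites greedily, innermost first, to terms encoded as nested tuples. The tuple encoding and the helpers (`Value`, `Values`, `arange`, `rewrite_once`, `simplify`) are hypothetical illustrations, not egglog's API. Unlike an e-graph, a greedy rewriter keeps only one term at a time; that happens to suffice here because every rule simplifies.

```python
from typing import Optional, Tuple

# Hypothetical tuple encoding of the ndarray module's terms:
#   ("Value", int), ("Values", (term, ...)), ("arange", n), ("shape", arr),
#   ("getitem", container, idx), ("length", vs), ("concat", a, b), ("mul"/"add", a, b)
Term = Tuple


def Value(i: int) -> Term:
    return ("Value", i)


def Values(*vs: Term) -> Term:
    return ("Values", tuple(vs))


def arange(n: Term) -> Term:
    return ("arange", n)


def rewrite_once(t: Term) -> Optional[Term]:
    """Try each rule at the root of t; return the rewritten term, or None if none applies."""
    head = t[0]
    if head in ("mul", "add") and t[1][0] == "Value" and t[2][0] == "Value":
        i, j = t[1][1], t[2][1]
        return Value(i * j if head == "mul" else i + j)
    if head == "getitem" and t[1][0] == "Values" and t[2][0] == "Value":
        return t[1][1][t[2][1]]  # Values(vs)[Value(i)] -> vs[i]
    if head == "length" and t[1][0] == "Values":
        return Value(len(t[1][1]))
    if head == "concat" and t[1][0] == "Values" and t[2][0] == "Values":
        return ("Values", t[1][1] + t[2][1])
    if head == "shape" and t[1][0] == "arange":
        return Values(t[1][1])  # arange(n).shape() -> Values(Vec(n))
    if head == "getitem" and t[1][0] == "arange":
        return ("getitem", t[2], Value(0))  # arange(n)[idx] -> idx[Value(0)]
    return None


def simplify(t: Term) -> Term:
    """Rewrite children first, then the root, until no rule applies."""
    if t[0] == "Value":
        return t
    if t[0] == "Values":
        t = ("Values", tuple(simplify(v) for v in t[1]))
    else:
        t = (t[0],) + tuple(simplify(c) for c in t[1:])
    new = rewrite_once(t)
    return t if new is None else simplify(new)


print(simplify(("shape", arange(Value(10)))))  # ('Values', (('Value', 10),))
print(simplify(("getitem", arange(Value(10)), Values(Value(7)))))  # ('Value', 7)
```

Both results correspond to the extracted `Values(Vec.empty().push(Value(10)))` and `Value(7)` above.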

2. Someone else decides to implement a cross product library

cross_mod = Module([ndarray_mod])


@cross_mod.function
def cross(l: NDArray, r: NDArray) -> NDArray:
    ...


@cross_mod.register
def _cross(l: NDArray, r: NDArray, idx: Values):
    yield rewrite(cross(l, r).shape()).to(l.shape().concat(r.shape()))
    yield rewrite(cross(l, r)[idx]).to(l[idx] * r[idx])

  • Someone decides to add some functionality

  • Multiplicative cross product

  • Shape is the concatenation of the shapes; indexing is the product of each array at that index

  • Mathematical definition
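Reading `cross` as an outer product (the later `np.multiply.outer` rule suggests this reading), the shape rule says the result's shape is the concatenation of the input shapes. A small plain-Python check, with nested lists standing in for arrays; `arange`, `outer`, and `shape` here are hypothetical helpers, not the egglog functions.

```python
from typing import Any, List, Tuple


def arange(n: int) -> List[int]:
    """A 1-D 'array' as a plain list: [0, 1, ..., n - 1]."""
    return list(range(n))


def outer(l: Any, r: Any) -> Any:
    """Outer product of nested lists: every element of l times every element of r."""
    if isinstance(l, list):
        return [outer(x, r) for x in l]
    if isinstance(r, list):
        return [outer(l, y) for y in r]
    return l * r


def shape(a: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(a, list):
        dims.append(len(a))
        a = a[0] if a else None
    return tuple(dims)


# The shape rule: shape(cross(l, r)) is shape(l) concatenated with shape(r)
print(shape(outer(arange(10), arange(11))))  # (10, 11)
print(outer(arange(3), arange(2)))  # [[0, 0], [0, 1], [0, 2]]
```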

egraph = EGraph([cross_mod])
egraph.simplify(cross(arange(Value(10)), arange(Value(11))).shape(), 10)
Values(Vec.empty().push(Value(11)).push(Value(10)))

3. I write my wonderful data science application using it

def my_special_app(x: Value) -> Value:
    return cross(arange(x), arange(x))[Values(Vec(x))]


egraph = EGraph([cross_mod])

egraph.simplify(my_special_app(Value(10)), 10)
Value(100)

  • A different person installs the cross module

  • Implements an application using their complicated algorithm

… but it’s too slow…

for i in range(100):
    egraph.simplify(my_special_app(Value(i)), 10)

  • Too slow in an inner loop

  • Is there a way we could optimize it?

4. Someone else writes a library for delayed execution

py_mod = Module([ndarray_mod])


@py_mod.function
def py_value(s: StringLike) -> Value:
    ...


...

  • Meanwhile, someone else, building on the original module, writes a different execution semantics

  • Builds up an expression string instead of trying to evaluate eagerly

@py_mod.register
def _py_value(l: String, r: String):
    yield rewrite(py_value(l) + py_value(r)).to(py_value(join(l, " + ", r)))
    yield rewrite(py_value(l) * py_value(r)).to(py_value(join(l, " * ", r)))


@py_mod.function
def py_values(s: StringLike) -> Values:
    ...


@py_mod.register
def _py_values(l: String, r: String):
    yield rewrite(py_values(l)[py_value(r)]).to(py_value(join(l, "[", r, "]")))
    yield rewrite(py_values(l).length()).to(py_value(join("len(", l, ")")))
    yield rewrite(py_values(l).concat(py_values(r))).to(py_values(join(l, " + ", r)))


@py_mod.function
def py_ndarray(s: StringLike) -> NDArray:
    ...


@py_mod.register
def _py_ndarray(l: String, r: String):
    yield rewrite(py_ndarray(l)[py_values(r)]).to(py_value(join(l, "[", r, "]")))
    yield rewrite(py_ndarray(l).shape()).to(py_values(join(l, ".shape")))
    yield rewrite(arange(py_value(l))).to(py_ndarray(join("np.arange(", l, ")")))
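The string-building semantics of `py_mod` resembles a familiar Python trick: a proxy object whose operators return new proxies carrying source text instead of values. Below is a plain-Python sketch mirroring a few of the `join` rules above; `PyExpr` and `np_arange` are hypothetical names. Like the rules, it adds no parentheses, so it is only safe when operators of different precedence are not mixed.

```python
class PyExpr:
    """Stands for a Python value by recording the source text that would compute it."""

    def __init__(self, src: str) -> None:
        self.src = src

    def __add__(self, other: "PyExpr") -> "PyExpr":
        # py_value(l) + py_value(r) -> py_value(join(l, " + ", r))
        return PyExpr(f"{self.src} + {other.src}")

    def __mul__(self, other: "PyExpr") -> "PyExpr":
        # py_value(l) * py_value(r) -> py_value(join(l, " * ", r))
        return PyExpr(f"{self.src} * {other.src}")

    def __getitem__(self, idx: "PyExpr") -> "PyExpr":
        # py_ndarray(l)[py_values(r)] -> py_value(join(l, "[", r, "]"))
        return PyExpr(f"{self.src}[{idx.src}]")

    def shape(self) -> "PyExpr":
        # py_ndarray(l).shape() -> py_values(join(l, ".shape"))
        return PyExpr(f"{self.src}.shape")


def np_arange(n: PyExpr) -> PyExpr:
    # arange(py_value(l)) -> py_ndarray(join("np.arange(", l, ")"))
    return PyExpr(f"np.arange({n.src})")


x = PyExpr("x")
print((x * x).src)  # x * x
print(np_arange(x).shape().src)  # np.arange(x).shape
```

The difference is that plain Python applies each operator exactly once, whereas in the e-graph these rules coexist with every other module's rules over the same terms.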

5. I can use it to JIT compile my application!

egraph = EGraph([cross_mod, py_mod])
egraph.simplify(my_special_app(py_value("x")), 10)
py_value("x * x")

  • I pull in the third-party library

  • Add it to my e-graph

  • Now I can compile lazily

  • py_mod never needed to know about the cross product, yet it works with it

… and add support for JIT compilation of the other library I am using, without changing either library:

@egraph.register
def _(l: String, r: String):
    yield rewrite(cross(py_ndarray(l), py_ndarray(r))).to(py_ndarray(join("np.multiply.outer(", l, ", ", r, ")")))
egraph.run(20)
egraph.graphviz().render(outfile="big_graph.svg", format="svg")
'big_graph.svg'
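Extraction yields source text such as `"x * x"`. To pay off in the loop from step 3, that text has to become a function that is built once and called many times. A minimal sketch using Python's built-in `compile` and `eval`; `compile_expr` is a hypothetical helper, not part of egglog, and `eval` should only ever see trusted, generated source.

```python
from typing import Any, Callable, Dict, Optional


def compile_expr(src: str, arg: str = "x",
                 namespace: Optional[Dict[str, Any]] = None) -> Callable[[Any], Any]:
    """Turn generated source such as "x * x" into a Python function of `arg`.

    `namespace` supplies the globals the generated code refers to, e.g. {"np": numpy}.
    """
    code = compile(f"lambda {arg}: {src}", "<generated>", "eval")
    # eval adds __builtins__ to the fresh globals dict automatically
    return eval(code, dict(namespace or {}))


# Compile once, outside the loop, then call the resulting function many times.
app = compile_expr("x * x")
results = [app(i) for i in range(100)]
print(results[10])  # 100
```

For array code, passing `namespace={"np": numpy}` lets generated `np.` calls resolve.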

Takeaways…

…from this totally realistic example.

  • The declarative nature of egglog could facilitate decentralized library collaboration and experimentation.

    • Focus on types over values for library authors encourages interoperability.

  • Pushing power down, empowering users and library authors

  • Could allow greater collaboration between the PL community and the Python data science library community

How does it work?

  • Will show how some examples translate

Sorts, expressions, and functions

%%egglog graph
(datatype Math
  (Num i64)
  (Var String)
  (Add Math Math)
  (Mul Math Math))

(define expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))
(define expr2 (Add (Num 6) (Mul (Num 2) (Var "x"))))

  • User-defined sorts

  • Expressions

    • expr1 and expr2 are each in their own e-class; we haven’t run any rules yet

  • The %%egglog magic lets you write egglog in a notebook, with graphs and output shown inline.
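For readers coming from Python, a plain-Python analogue of the `Math` datatype may help: each constructor becomes a frozen dataclass, and a small evaluator shows that `expr1` and `expr2` agree numerically for sample values of `x`. This is only an analogy (`evaluate` is a hypothetical helper): checking sample points is not a proof, whereas the rewrite rules later in the talk establish the equality for every `x`.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass(frozen=True)
class Num:
    value: int


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Add:
    left: "Math"
    right: "Math"


@dataclass(frozen=True)
class Mul:
    left: "Math"
    right: "Math"


Math = Union[Num, Var, Add, Mul]


def evaluate(e: Math, env: Dict[str, int]) -> int:
    """Compute the integer value of a Math term, looking variables up in env."""
    if isinstance(e, Num):
        return e.value
    if isinstance(e, Var):
        return env[e.name]
    if isinstance(e, Add):
        return evaluate(e.left, env) + evaluate(e.right, env)
    if isinstance(e, Mul):
        return evaluate(e.left, env) * evaluate(e.right, env)
    raise TypeError(f"not a Math term: {e!r}")


expr1 = Mul(Num(2), Add(Var("x"), Num(3)))
expr2 = Add(Num(6), Mul(Num(2), Var("x")))

# Structurally distinct terms, just as expr1 and expr2 start in separate e-classes
print(expr1 == expr2)  # False
print(evaluate(expr1, {"x": 5}), evaluate(expr2, {"x": 5}))  # 16 16
```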

egraph = EGraph()


@egraph.class_
class Num(Expr):
    @classmethod
    def var(cls, name: StringLike) -> Num:
        ...

    def __init__(self, value: i64Like) -> None:
        ...

    def __add__(self, other: Num) -> Num:
        ...

    def __mul__(self, other: Num) -> Num:
        ...


expr1 = egraph.let("expr1", Num(2) * (Num.var("x") + Num(3)))
expr2 = egraph.let("expr2", Num(6) + Num(2) * Num.var("x"))
egraph
[Figure: e-graph containing expr1 and expr2]

  • Reuse existing Python classes and functions

    • Both humans and computers can understand the typing semantics

    • Humans read __init__ and __add__.

    • Static type checkers know that Num("String") won’t work.

    • Static type checking drives much of the API design of the library

  • Operator overloading supports infix operators

  • Function names are generated from class names

    • The same operator on different types compiles to a different function with a different signature
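A sketch of the idea behind deriving names and signatures from classes: read each method's annotations with `inspect.signature` and prefix the method name with its class name, so `+` on `Num` and `+` on `Value` become two distinct functions with different signatures. The helper `egglog_signatures` is hypothetical and not how egglog is actually implemented; the string annotation `"i64"` merely mirrors egglog's sort name.

```python
import inspect
from typing import Any, Dict, Tuple


class Num:
    def __init__(self, value: "i64") -> None: ...
    def __add__(self, other: "Num") -> "Num": ...
    def __mul__(self, other: "Num") -> "Num": ...


class Value:
    def __init__(self, v: "i64") -> None: ...
    def __add__(self, other: "Value") -> "Value": ...


def _type_name(annotation: Any) -> str:
    # String annotations are used as-is; classes contribute their __name__
    return annotation if isinstance(annotation, str) else getattr(annotation, "__name__", repr(annotation))


def egglog_signatures(cls: type) -> Dict[str, Tuple[Tuple[str, ...], str]]:
    """Map 'Class.method' to (argument sort names, return sort name)."""
    sigs = {}
    for name, member in vars(cls).items():
        if not inspect.isfunction(member):
            continue
        sig = inspect.signature(member)
        params = list(sig.parameters.values())[1:]  # drop `self`
        args = tuple(_type_name(p.annotation) for p in params)
        # A constructor (__init__) produces an instance of the class itself
        ret = cls.__name__ if name == "__init__" else _type_name(sig.return_annotation)
        sigs[f"{cls.__name__}.{name}"] = (args, ret)
    return sigs


print(egglog_signatures(Num)["Num.__add__"])  # (('Num',), 'Num')
print(egglog_signatures(Value)["Value.__add__"])  # (('Value',), 'Value')
```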

Rewrite rules and checks

%%egglog graph continue
(rewrite (Add a b)
         (Add b a))
(rewrite (Mul a (Add b c))
         (Add (Mul a b) (Mul a c)))
(rewrite (Add (Num a) (Num b))
         (Num (+ a b)))
(rewrite (Mul (Num a) (Num b))
         (Num (* a b)))

(run 10)
(check (= expr1 expr2))

  • We can see they are equivalent: they are now in the same e-class

@egraph.register
def _(a: Num, b: Num, c: Num, i: i64, j: i64):
    yield rewrite(a + b).to(b + a)
    yield rewrite(a * (b + c)).to((a * b) + (a * c))
    yield rewrite(Num(i) + Num(j)).to(Num(i + j))
    yield rewrite(Num(i) * Num(j)).to(Num(i * j))


egraph.run(10)
egraph.check(eq(expr1).to(expr2))
egraph
[Figure: e-graph after running the rewrites]

  • Similar in Python: rewrite rules, run, check

  • Notice that all variables need explicit types, unlike in egglog, where they are inferred

    • Both so static type checkers can verify the rules

    • And so the runtime knows which methods to use
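Why the runtime needs the types: Python resolves `a + b` by looking up `__add__` on the class of `a`, so a rule variable must already be an instance of the right class before the rule body can even be built. A minimal sketch with hypothetical `Num`, `Value`, and `var` definitions (not egglog's) that record which function each `+` resolves to:

```python
class Num:
    def __init__(self, term: tuple) -> None:
        self.term = term

    def __add__(self, other: "Num") -> "Num":
        return Num(("Num.__add__", self.term, other.term))


class Value:
    def __init__(self, term: tuple) -> None:
        self.term = term

    def __add__(self, other: "Value") -> "Value":
        return Value(("Value.__add__", self.term, other.term))


def var(name: str, cls: type):
    """A pattern variable: a named hole wrapped in the class that supplies its methods."""
    return cls(("var", name))


a, b = var("a", Num), var("b", Num)
i, j = var("i", Value), var("j", Value)
print((a + b).term)  # ('Num.__add__', ('var', 'a'), ('var', 'b'))
print((i + j).term[0])  # Value.__add__
```

The same `+` yields two different function names purely because of the variables' classes; presumably this is the information the typed parameters of the rule function (`a: Num`, `i: i64`) supply.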

Extracting the lowest-cost expression

%%egglog continue output
(extract expr1)
Extracted with cost 8: (Add (Mul (Num 2) (Var "x")) (Num 6))

  • Extract the lowest-cost expression

egraph.extract(expr1)
(Num(2) * Num.var("x")) + Num(6)

  • Get back an expression object

  • Its str representation is Python syntax
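Extraction can be sketched as a fixpoint over e-classes: a class's cost is the cheapest of its e-nodes, where an e-node costs 1 plus the costs of its children. The hand-built e-graph below is an assumption: a small stand-in that holds only a few e-nodes of the saturated graph above. The code is a simple bottom-up extractor, not egglog's own implementation.

```python
import math
from typing import Dict, List, Tuple

ENode = Tuple[str, ...]  # (operator, child e-class ids...)

# e-class id -> equivalent e-nodes; literal leaves are e-nodes with no children
egraph: Dict[str, List[ENode]] = {
    "2": [("2",)], "3": [("3",)], "6": [("6",)], "x": [('"x"',)],
    "num2": [("Num", "2")],
    "num3": [("Num", "3")],
    "num6": [("Num", "6"), ("Mul", "num2", "num3")],  # 2 * 3 was folded into 6
    "var_x": [("Var", "x")],
    "x_plus_3": [("Add", "var_x", "num3"), ("Add", "num3", "var_x")],
    "two_x": [("Mul", "num2", "var_x")],
    "expr1": [
        ("Add", "two_x", "num6"),
        ("Mul", "num2", "x_plus_3"),
        ("Add", "num6", "two_x"),
    ],
}


def extract(egraph: Dict[str, List[ENode]], root: str) -> Tuple[float, str]:
    """Return the minimum AST-size cost of `root` and one term achieving it."""
    cost = {c: math.inf for c in egraph}
    best: Dict[str, ENode] = {}
    changed = True
    while changed:  # iterate until no class gets cheaper
        changed = False
        for cls, nodes in egraph.items():
            for node in nodes:
                c = 1 + sum(cost[child] for child in node[1:])
                if c < cost[cls]:  # strict: ties keep the first e-node found
                    cost[cls], best[cls] = c, node
                    changed = True

    def build(cls: str) -> str:
        op, *children = best[cls]
        return op if not children else f"({op} {' '.join(build(ch) for ch in children)})"

    return cost[root], build(root)


print(extract(egraph, "expr1"))  # (8, '(Add (Mul (Num 2) (Var "x")) (Num 6))')
```

Under this cost, the original `Mul` form of `expr1` also costs 8, so which term comes back is decided by tie-breaking (in this sketch, the first e-node listed).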

Multipart Rules

%%egglog graph
(function fib (i64) i64)

(set (fib 0) 0)
(set (fib 1) 1)
(rule ((= f0 (fib x))
       (= f1 (fib (+ x 1))))
      ((set (fib (+ x 2)) (+ f0 f1))))

(run 7)
(check (= (fib 7) 13))

  • Rules that depend on facts and execute actions

fib_egraph = EGraph()


@fib_egraph.function
def fib(x: i64Like) -> i64:
    ...


@fib_egraph.register
def _(f0: i64, f1: i64, x: i64):
    yield set_(fib(0)).to(i64(0))
    yield set_(fib(1)).to(i64(1))
    yield rule(
        eq(f0).to(fib(x)),
        eq(f1).to(fib(x + 1)),
    ).then(set_(fib(x + 2)).to(f0 + f1))


fib_egraph.run(7)
fib_egraph.check(eq(fib(7)).to(i64(13)))

  • set_ and eq are both type safe, which required a builder syntax
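The effect of `run(7)` on the Fibonacci rule can be approximated in plain Python: each iteration matches the rule against the table as it stood at the start of the iteration, then applies all of the resulting `set` actions. This is a hypothetical sketch of the iteration semantics (naive evaluation over a dict), not egglog's engine; it starts from fib(0) = 0 and fib(1) = 1, as in the egglog program.

```python
from typing import Dict


def run_fib_rule(table: Dict[int, int], iterations: int) -> Dict[int, int]:
    """Apply the rule fib(x), fib(x + 1) => fib(x + 2) = fib(x) + fib(x + 1)."""
    for _ in range(iterations):
        # Match against the table as it was at the start of this iteration
        updates = {x + 2: table[x] + table[x + 1] for x in table if x + 1 in table}
        table.update(updates)
    return table


fib = run_fib_rule({0: 0, 1: 1}, iterations=7)
print(fib[7])  # 13
```

In this sketch each iteration adds exactly one new entry, so 7 iterations reach fib(8), while 5 iterations would stop at fib(6) and the check on fib(7) would fail.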

Include & Modules

%%writefile path.egg
(relation path (i64 i64))
(relation edge (i64 i64))

(rule ((edge x y))
      ((path x y)))

(rule ((path x y) (edge y z))
      ((path x z)))
%%egglog
(include "path.egg")
(edge 1 2)
(edge 2 3)
(edge 3 4)
(run 3)
(check (path 1 3))

  • Include another file for reusability
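The two `path` rules compute the transitive closure of `edge`. A plain-Python sketch of what three iterations derive, matching each iteration against the facts present at its start; `run_path_rules` is a hypothetical helper, and naive re-evaluation is used for simplicity rather than whatever strategy egglog itself uses.

```python
from typing import Set, Tuple

Edge = Tuple[int, int]


def run_path_rules(edges: Set[Edge], iterations: int) -> Set[Edge]:
    path: Set[Edge] = set()
    for _ in range(iterations):
        new = set(edges)  # rule 1: edge(x, y) => path(x, y)
        # rule 2: path(x, y), edge(y, z) => path(x, z)
        new |= {(x, z) for (x, y) in path for (y2, z) in edges if y == y2}
        path |= new
    return path


paths = run_path_rules({(1, 2), (2, 3), (3, 4)}, iterations=3)
print(sorted(paths))  # [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
```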

mod = Module()
path = mod.relation("path", i64, i64)
edge = mod.relation("edge", i64, i64)


@mod.register
def _(x: i64, y: i64, z: i64):
    yield rule(edge(x, y)).then(path(x, y))
    yield rule(path(x, y), edge(y, z)).then(path(x, z))

  • Modules work the same way in Python

  • A module supports defining rules, etc., but doesn’t actually run them; it just builds up commands

egraph = EGraph([mod])
egraph.register(edge(1, 2), edge(2, 3), edge(3, 4))
egraph.run(3)
egraph.check(path(1, 3))

  • Then, when an e-graph depends on a module, it runs that module’s commands first.

  • This lets code be distributed and reused by others through existing Python import mechanisms.
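The module mechanism can be sketched as follows: a module only records commands, and an e-graph replays the commands of every module it depends on (dependencies first) before its own. The `Module` and `EGraph` classes below are hypothetical stand-ins that log command strings; they are not egglog's classes. In this sketch a shared dependency is loaded only once; whether egglog deduplicates shared dependencies the same way is an assumption.

```python
from typing import List, Sequence


class Module:
    def __init__(self, deps: Sequence["Module"] = ()) -> None:
        self.deps = list(deps)
        self.commands: List[str] = []

    def register(self, *commands: str) -> None:
        # Record only; nothing runs until an e-graph uses the module
        self.commands.extend(commands)


class EGraph(Module):
    def __init__(self, deps: Sequence[Module] = ()) -> None:
        super().__init__(deps)
        self.log: List[str] = []  # commands "executed", in order
        seen = set()

        def load(mod: Module) -> None:
            if id(mod) in seen:
                return
            seen.add(id(mod))
            for dep in mod.deps:
                load(dep)  # dependencies first
            self.log.extend(mod.commands)

        for dep in self.deps:
            load(dep)

    def register(self, *commands: str) -> None:
        # An e-graph executes commands immediately (here: appends them to the log)
        super().register(*commands)
        self.log.extend(commands)


ndarray = Module()
ndarray.register("ndarray rules")
cross = Module([ndarray])
cross.register("cross rules")
py = Module([ndarray])
py.register("py rules")
eg = EGraph([cross, py])
eg.register("my rule")
print(eg.log)  # ['ndarray rules', 'cross rules', 'py rules', 'my rule']
```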

Possible next steps?

  • Try getting a toehold in an existing library (like Ibis) to see whether the constrained egglog approach can still be useful.

    • Add support for Python objects as a built-in sort.

  • Upstream egglog improvements that could help with reuse

    • First-class functions (would help with implementing things like reductions and mapping)

    • User-defined generic sorts (e.g. an array type agnostic to its inner values)

Thank you!

pip install egglog
from egglog import *

egraph = EGraph()
...

Come say hello at github.com/egraphs-good/egglog-python!