from sqlalchemy import Column, func, literal, select
from sqlalchemy.orm import column_property
from sqlalchemy.orm.interfaces import AttributeExtension
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.sql import expression
from sqlalchemy.types import TypeEngine

# Python datatypes


class GisElement(object):
    """Represents a geometry value.

    Subclasses store the textual form of the geometry in ``self.desc``;
    ``str()`` returns it and the Geometry type binds it as the parameter
    value.
    """

    @property
    def wkt(self):
        """SQL expression ``AsText(<this value>)``.

        Nothing is sent to the database until the expression is executed,
        e.g. ``session.scalar(element.wkt)``.
        """
        return func.AsText(literal(self, Geometry))

    @property
    def wkb(self):
        """SQL expression ``AsBinary(<this value>)`` (executed lazily, as for ``wkt``)."""
        return func.AsBinary(literal(self, Geometry))

    def __str__(self):
        return self.desc

    def __repr__(self):
        return "<%s at 0x%x; %r>" % (self.__class__.__name__, id(self), self.desc)


class PersistentGisElement(GisElement):
    """Represents a Geometry value as loaded from the database.

    ``desc`` is the raw column value handed to Geometry's result processor.
    """

    def __init__(self, desc):
        self.desc = desc


class TextualGisElement(GisElement, expression.Function):
    """Represents a Geometry value as expressed within application code;
    i.e. in WKT format.

    Extends expression.Function so that the value is interpreted as
    GeomFromText(value, srid) in a SQL expression context.

    :param desc: the WKT string.
    :param srid: spatial reference id passed to GeomFromText (default -1).
    :raises TypeError: if ``desc`` is not a string.
    """

    def __init__(self, desc, srid=-1):
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not isinstance(desc, basestring):
            raise TypeError("TextualGisElement expects a WKT string, got %r" % (desc,))
        self.desc = desc
        expression.Function.__init__(self, "GeomFromText", desc, srid)


# SQL datatypes.


class Geometry(TypeEngine):
    """Base PostGIS Geometry column type.

    Converts bind/result values to/from a PersistentGisElement.

    :param dimension: coordinate dimension, used when the column is created
        via AddGeometryColumn.
    :param srid: spatial reference id, also used by AddGeometryColumn.
    """

    name = 'GEOMETRY'

    def __init__(self, dimension=None, srid=-1):
        self.dimension = dimension
        self.srid = srid

    def bind_processor(self, dialect):
        def process(value):
            # GIS elements are bound by their textual form. Anything else
            # (None, or a raw string already in a database-acceptable form)
            # is passed through instead of failing with AttributeError.
            if isinstance(value, GisElement):
                return value.desc
            return value
        return process

    def result_processor(self, dialect, coltype):
        def process(value):
            # Wrap non-NULL database values so they expose .wkt/.wkb.
            if value is not None:
                return PersistentGisElement(value)
            else:
                return value
        return process


# other datatypes can be added as needed, which
# currently only affect DDL statements.


class Point(Geometry):
    name = 'POINT'


class Curve(Geometry):
    name = 'CURVE'


class LineString(Curve):
    name = 'LINESTRING'

# ... etc.
# DDL integration


class GISDDL(object):
    """A DDL extension which integrates SQLAlchemy table create/drop methods
    with PostGIS' AddGeometryColumn/DropGeometryColumn functions.

    Usage::

        sometable = Table('sometable', metadata, ...)

        GISDDL(sometable)

        sometable.create()

    Mechanism:

    - In the "before" events, Geometry-typed columns are temporarily removed
      from ``table._columns``, so the plain CREATE/DROP TABLE statements do
      not mention them.
    - The full column collection is restored in the matching "after" event.
    - Geometry columns are added after CREATE with AddGeometryColumn() and
      removed before DROP with DropGeometryColumn().
    """

    def __init__(self, table):
        for event in ('before-create', 'after-create', 'before-drop', 'after-drop'):
            table.ddl_listeners[event].append(self)
        # Column collections saved by the "before" events and restored by
        # the matching "after" events.
        self._stack = []

    def __call__(self, event, table, bind):
        # Imported here so the listener works when this module is imported,
        # not only when it is run as a script.
        from sqlalchemy import func, select

        if event in ('before-create', 'before-drop'):
            regular_cols = [c for c in table.c if not isinstance(c.type, Geometry)]
            # Ordered list (not a set) so DROP statements are issued in a
            # deterministic, column-definition order.
            gis_cols = [c for c in table.c if isinstance(c.type, Geometry)]
            self._stack.append(table.c)
            table._columns = expression.ColumnCollection(*regular_cols)
            if event == 'before-drop':
                # Honour an explicit Table schema; keep 'public' as the
                # fallback for tables without one.
                schema = table.schema or 'public'
                for c in gis_cols:
                    bind.execute(select([func.DropGeometryColumn(schema, table.name, c.name)],
                                        autocommit=True))
        elif event == 'after-create':
            table._columns = self._stack.pop()
            for c in table.c:
                if isinstance(c.type, Geometry):
                    args = (table.name, c.name, c.type.srid, c.type.name, c.type.dimension)
                    if table.schema:
                        # Schema-qualified form of AddGeometryColumn.
                        args = (table.schema,) + args
                    bind.execute(select([func.AddGeometryColumn(*args)], autocommit=True))
        elif event == 'after-drop':
            table._columns = self._stack.pop()


# ORM integration


def _to_postgis(value):
    """Interpret a value as a GIS-compatible construct.

    - Objects with ``__clause_element__()`` (e.g. mapped attributes) are
      unwrapped to their clause element.
    - SQL expressions and GisElement instances are returned unchanged.
    - Strings are treated as WKT and wrapped in a TextualGisElement.
    - None is returned as None.

    :raises TypeError: for any other value. TypeError subclasses
        Exception, so handlers written for the former generic Exception
        still catch it.
    """
    if hasattr(value, '__clause_element__'):
        return value.__clause_element__()
    elif isinstance(value, (expression.ClauseElement, GisElement)):
        return value
    elif isinstance(value, basestring):
        return TextualGisElement(value)
    elif value is None:
        return None
    else:
        raise TypeError("Invalid type for a GIS value: %r" % (type(value),))


class GisAttribute(AttributeExtension):
    """Intercepts 'set' events on a mapped instance attribute and converts
    the incoming value to a GIS expression.
    """

    def set(self, state, value, oldvalue, initiator):
        return _to_postgis(value)


class GisComparator(ColumnProperty.ColumnComparator):
    """Intercepts standard Column operators on mapped class attributes
    and overrides their behavior.
    """

    # override the __eq__() operator: ``attr == value`` renders as the
    # PostGIS '~=' operator instead of '='.
    def __eq__(self, other):
        return self.__clause_element__().op('~=')(_to_postgis(other))

    # add a custom operator: ``attr.intersects(value)`` renders as '&&'.
    # NOTE(review): PostGIS documents '&&' as a bounding-box overlap test;
    # confirm that is the intended meaning of "intersects".
    def intersects(self, other):
        return self.__clause_element__().op('&&')(_to_postgis(other))

    # any number of GIS operators can be overridden/added here
    # using the techniques above.


def GISColumn(*args, **kw):
    """Define a declarative column property with GIS behavior.

    This just produces orm.column_property() with the appropriate
    extension and comparator_factory arguments. The given arguments
    are passed through to Column. The declarative module extracts
    the Column for inclusion in the mapped table.
    """
    # Imported here so GISColumn works when this module is imported,
    # not only when it is run as a script.
    from sqlalchemy import Column
    from sqlalchemy.orm import column_property

    return column_property(
        Column(*args, **kw),
        extension=GisAttribute(),
        comparator_factory=GisComparator,
    )


# illustrate usage
if __name__ == '__main__':
    from sqlalchemy import (create_engine, MetaData, Column, Integer, String,
                            func, literal, select)
    from sqlalchemy.orm import sessionmaker, column_property
    from sqlalchemy.ext.declarative import declarative_base

    engine = create_engine('postgresql://scott:tiger@localhost/gistest', echo=True)
    metadata = MetaData(engine)
    Base = declarative_base(metadata=metadata)

    class Road(Base):
        __tablename__ = 'roads'

        road_id = Column(Integer, primary_key=True)
        road_name = Column(String)
        road_geom = GISColumn(Geometry(2))

    # enable the DDL extension, which allows CREATE/DROP operations
    # to work correctly. This is not needed if working with externally
    # defined tables.
    GISDDL(Road.__table__)

    metadata.drop_all()
    metadata.create_all()

    session = sessionmaker(bind=engine)()

    # Add objects. We can use strings...
    session.add_all([
        Road(road_name='Jeff Rd', road_geom='LINESTRING(191232 243118,191108 243242)'),
        Road(road_name='Geordie Rd', road_geom='LINESTRING(189141 244158,189265 244817)'),
        Road(road_name='Paul St', road_geom='LINESTRING(192783 228138,192612 229814)'),
        Road(road_name='Graeme Ave', road_geom='LINESTRING(189412 252431,189631 259122)'),
        Road(road_name='Phil Tce', road_geom='LINESTRING(190131 224148,190871 228134)'),
    ])

    # or use an explicit TextualGisElement (similar to saying func.GeomFromText())
    r = Road(road_name='Dave Cres',
             road_geom=TextualGisElement('LINESTRING(198231 263418,198213 268322)', -1))
    session.add(r)

    # pre flush, the TextualGisElement represents the string we sent.
    assert str(r.road_geom) == 'LINESTRING(198231 263418,198213 268322)'
    assert session.scalar(r.road_geom.wkt) == 'LINESTRING(198231 263418,198213 268322)'

    session.commit()

    # after flush and/or commit, all the TextualGisElements become PersistentGisElements.
    assert str(r.road_geom) == "01020000000200000000000000B832084100000000E813104100000000283208410000000088601041"

    r1 = session.query(Road).filter(Road.road_name == 'Graeme Ave').one()

    # illustrate the overridden __eq__() operator.

    # strings come in as TextualGisElements
    r2 = session.query(Road).filter(Road.road_geom == 'LINESTRING(189412 252431,189631 259122)').one()

    # PersistentGisElements work directly
    r3 = session.query(Road).filter(Road.road_geom == r1.road_geom).one()

    assert r1 is r2 is r3

    # illustrate the "intersects" operator
    # (single-argument print() behaves the same on Python 2 and 3)
    print(session.query(Road).filter(Road.road_geom.intersects(r1.road_geom)).all())

    # illustrate usage of the "wkt" accessor. this requires a DB
    # execution to call the AsText() function so we keep this explicit.
    assert session.scalar(r1.road_geom.wkt) == 'LINESTRING(189412 252431,189631 259122)'

    session.rollback()

    metadata.drop_all()