Use CompositeType

Some functions return composite types. This example shows how to declare such a type and access the individual fields of a function's composite result.

  8 import pytest
  9 from pkg_resources import parse_version
 10 from sqlalchemy import Column
 11 from sqlalchemy import Float
 12 from sqlalchemy import Integer
 13 from sqlalchemy import MetaData
 14 from sqlalchemy import __version__ as SA_VERSION
 15 from sqlalchemy.ext.declarative import declarative_base
 16
 17 from geoalchemy2 import Raster
 18 from geoalchemy2 import WKTElement
 19 from geoalchemy2.functions import GenericFunction
 20 from geoalchemy2.types import CompositeType
 21
 22 # Tests imports
 23 from tests import select
 24 from tests import test_only_with_dialects
 25
 26
 27 class SummaryStatsCustomType(CompositeType):
 28     """Define the composite type returned by the function ST_SummaryStatsAgg."""
 29     typemap = {
 30         'count': Integer,
 31         'sum': Float,
 32         'mean': Float,
 33         'stddev': Float,
 34         'min': Float,
 35         'max': Float,
 36     }
 37
 38     cache_ok = True
 39
 40
 41 class ST_SummaryStatsAgg(GenericFunction):
 42     type = SummaryStatsCustomType
 43     # Set a specific identifier to not override the actual ST_SummaryStatsAgg function
 44     identifier = "ST_SummaryStatsAgg_custom"
 45
 46     inherit_cache = True
 47
 48
 49 metadata = MetaData()
 50 Base = declarative_base(metadata=metadata)
 51
 52
 53 class Ocean(Base):
 54     __tablename__ = 'ocean'
 55     id = Column(Integer, primary_key=True)
 56     rast = Column(Raster)
 57
 58     def __init__(self, rast):
 59         self.rast = rast
 60
 61
 62 @test_only_with_dialects("postgresql")
 63 class TestSTSummaryStatsAgg():
 64
 65     @pytest.mark.skipif(
 66         parse_version(SA_VERSION) < parse_version("1.4"),
 67         reason="requires SQLAlchely>1.4",
 68     )
 69     def test_st_summary_stats_agg(self, session, conn):
 70         metadata.drop_all(conn, checkfirst=True)
 71         metadata.create_all(conn)
 72
 73         # Create a new raster
 74         polygon = WKTElement('POLYGON((0 0,1 1,0 1,0 0))', srid=4326)
 75         o = Ocean(polygon.ST_AsRaster(5, 6))
 76         session.add(o)
 77         session.flush()
 78
 79         # Define the query to compute stats
 80         stats_agg = select([
 81             Ocean.rast.ST_SummaryStatsAgg_custom(1, True, 1).label("stats")
 82         ])
 83         stats_agg_alias = stats_agg.alias("stats_agg")
 84
 85         # Use these stats
 86         query = select([
 87             stats_agg_alias.c.stats.count.label("count"),
 88             stats_agg_alias.c.stats.sum.label("sum"),
 89             stats_agg_alias.c.stats.mean.label("mean"),
 90             stats_agg_alias.c.stats.stddev.label("stddev"),
 91             stats_agg_alias.c.stats.min.label("min"),
 92             stats_agg_alias.c.stats.max.label("max")
 93         ])
 94
 95         # Check the query
 96         assert str(query.compile(dialect=session.bind.dialect)) == (
 97             "SELECT "
 98             "(stats_agg.stats).count AS count, "
 99             "(stats_agg.stats).sum AS sum, "
100             "(stats_agg.stats).mean AS mean, "
101             "(stats_agg.stats).stddev AS stddev, "
102             "(stats_agg.stats).min AS min, "
103             "(stats_agg.stats).max AS max \n"
104             "FROM ("
105             "SELECT "
106             "ST_SummaryStatsAgg("
107             "ocean.rast, "
108             "%(ST_SummaryStatsAgg_1)s, %(ST_SummaryStatsAgg_2)s, %(ST_SummaryStatsAgg_3)s"
109             ") AS stats \n"
110             "FROM ocean) AS stats_agg"
111         )
112
113         # Execute the query
114         res = session.execute(query).fetchall()
115
116         # Check the result
117         assert res == [(15, 15.0, 1.0, 0.0, 1.0, 1.0)]

Gallery generated by Sphinx-Gallery