Skip to content

Commit c11bfe9

Browse files
committed
Add instances for vector
1 parent c8b5bae commit c11bfe9

File tree

5 files changed

+115
-9
lines changed

5 files changed

+115
-9
lines changed

inline-python.cabal

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ Library
5959
, text >=2
6060
, bytestring
6161
, exceptions >=0.10
62+
, vector >=0.13
6263
hs-source-dirs: src
6364
include-dirs: include
6465
c-sources: cbits/python.c
@@ -85,11 +86,13 @@ library test
8586
QuasiQuotes
8687
build-depends: base
8788
, inline-python
88-
, tasty >=1.2
89-
, tasty-hunit >=0.10
90-
, tasty-quickcheck >=0.10
89+
, tasty >=1.2
90+
, tasty-hunit >=0.10
91+
, tasty-quickcheck >=0.10
92+
, quickcheck-instances >=0.3.32
9193
, exceptions
9294
, containers
95+
, vector
9396
hs-source-dirs: test
9497
Exposed-modules:
9598
TST.Run

src/Python/Inline/Literal.hs

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE CPP #-}
12
{-# LANGUAGE ForeignFunctionInterface #-}
23
{-# LANGUAGE QuasiQuotes #-}
34
{-# LANGUAGE TemplateHaskell #-}
@@ -19,8 +20,17 @@ import Data.Bits
1920
import Data.Char
2021
import Data.Int
2122
import Data.Word
22-
import Data.Set qualified as Set
23-
import Data.Map.Strict qualified as Map
23+
import Data.Set qualified as Set
24+
import Data.Map.Strict qualified as Map
25+
import Data.Vector.Generic qualified as VG
26+
import Data.Vector.Generic.Mutable qualified as MVG
27+
import Data.Vector qualified as V
28+
#if MIN_VERSION_vector(0,13,2)
29+
import Data.Vector.Strict qualified as VV
30+
#endif
31+
import Data.Vector.Storable qualified as VS
32+
import Data.Vector.Primitive qualified as VP
33+
import Data.Vector.Unboxed qualified as VU
2434
import Foreign.Ptr
2535
import Foreign.C.Types
2636
import Foreign.Storable
@@ -436,6 +446,43 @@ instance (FromPy k, FromPy v, Ord k) => FromPy (Map.Map k v) where
436446
pure $! Map.insert k v m)
437447
Map.empty
438448

449+
-- | Converts to python's list
450+
instance ToPy a => ToPy (V.Vector a) where
451+
basicToPy = vectorToPy
452+
-- | Converts to python's list
453+
instance (ToPy a, VS.Storable a) => ToPy (VS.Vector a) where
454+
basicToPy = vectorToPy
455+
-- | Converts to python's list
456+
instance (ToPy a, VP.Prim a) => ToPy (VP.Vector a) where
457+
basicToPy = vectorToPy
458+
-- | Converts to python's list
459+
instance (ToPy a, VU.Unbox a) => ToPy (VU.Vector a) where
460+
basicToPy = vectorToPy
461+
#if MIN_VERSION_vector(0,13,2)
462+
-- | Converts to python's list
463+
instance (ToPy a) => ToPy (VV.Vector a) where
464+
basicToPy = vectorToPy
465+
#endif
466+
467+
-- | Accepts python's sequence (@len@ and indexing)
468+
instance FromPy a => FromPy (V.Vector a) where
469+
basicFromPy = vectorFromPy
470+
-- | Accepts python's sequence (@len@ and indexing)
471+
instance (FromPy a, VS.Storable a) => FromPy (VS.Vector a) where
472+
basicFromPy = vectorFromPy
473+
-- | Accepts python's sequence (@len@ and indexing)
474+
instance (FromPy a, VP.Prim a) => FromPy (VP.Vector a) where
475+
basicFromPy = vectorFromPy
476+
-- | Accepts python's sequence (@len@ and indexing)
477+
instance (FromPy a, VU.Unbox a) => FromPy (VU.Vector a) where
478+
basicFromPy = vectorFromPy
479+
#if MIN_VERSION_vector(0,13,2)
480+
-- | Accepts python's sequence (@len@ and indexing)
481+
instance FromPy a => FromPy (VV.Vector a) where
482+
basicFromPy = vectorFromPy
483+
#endif
484+
485+
439486
-- | Fold over iterable. Function takes ownership over iterator.
440487
foldPyIterable
441488
:: Ptr PyObject -- ^ Python iterator (not checked)
@@ -450,6 +497,39 @@ foldPyIterable p_iter step a0
450497
p -> loop =<< (step a p `finally` decref p)
451498

452499

500+
vectorFromPy :: (VG.Vector v a, FromPy a) => Ptr PyObject -> Py (v a)
501+
{-# INLINE vectorFromPy #-}
502+
vectorFromPy p_seq = do
503+
len <- Py [CU.exp| long long { PySequence_Size($(PyObject* p_seq)) } |]
504+
when (len < 0) $ do
505+
Py [C.exp| void { PyErr_Clear() } |]
506+
throwM BadPyType
507+
-- Read data into vector
508+
buf <- MVG.generateM (fromIntegral len) $ \i -> do
509+
let i_c = fromIntegral i
510+
Py [CU.exp| PyObject* { PySequence_GetItem($(PyObject* p_seq), $(long long i_c)) } |] >>= \case
511+
NULL -> mustThrowPyError
512+
p -> basicFromPy p `finally` decref p
513+
VG.unsafeFreeze buf
514+
515+
vectorToPy :: (VG.Vector v a, ToPy a) => v a -> Py (Ptr PyObject)
516+
vectorToPy vec = runProgram $ do
517+
p_list <- takeOwnership =<< checkNull (Py [CU.exp| PyObject* { PyList_New($(long long n_c)) } |])
518+
progPy $ do
519+
let loop i
520+
| i >= n = p_list <$ incref p_list
521+
| otherwise = basicToPy (VG.unsafeIndex vec i) >>= \case
522+
NULL -> pure nullPtr
523+
p_a -> do
524+
let i_c = fromIntegral i :: CLLong
525+
-- NOTE: PyList_SET_ITEM steals reference
526+
Py [CU.exp| void { PyList_SET_ITEM($(PyObject* p_list), $(long long i_c), $(PyObject* p_a)) } |]
527+
loop (i+1)
528+
loop 0
529+
where
530+
n = VG.length vec
531+
n_c = fromIntegral n :: CLLong
532+
453533
----------------------------------------------------------------
454534
-- Functions marshalling
455535
----------------------------------------------------------------

src/Python/Internal/Types.hs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{-# LANGUAGE OverloadedStrings #-}
22
{-# LANGUAGE QuasiQuotes #-}
33
{-# LANGUAGE TemplateHaskell #-}
4+
{-# LANGUAGE TypeFamilies #-}
45
-- |
56
-- Definition of data types used by inline-python. They are moved to
67
-- separate module since some are required for @inline-c@'s context
@@ -26,10 +27,11 @@ module Python.Internal.Types
2627

2728
import Control.Monad.IO.Class
2829
import Control.Monad.Catch
30+
import Control.Monad.Primitive (PrimMonad(..),RealWorld)
2931
import Control.Exception
3032
import Data.Coerce
3133
import Data.Int
32-
import Data.Map.Strict qualified as Map
34+
import Data.Map.Strict qualified as Map
3335
import Foreign.Ptr
3436
import Foreign.C.Types
3537
import GHC.ForeignPtr
@@ -106,6 +108,11 @@ pyIO = Py
106108
instance MonadIO Py where
107109
liftIO = Py . interruptible
108110

111+
instance PrimMonad Py where
112+
type PrimState Py = RealWorld
113+
primitive = Py . primitive
114+
{-# INLINE primitive #-}
115+
109116

110117
----------------------------------------------------------------
111118
-- inline-C

test/TST/FromPy.hs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import Test.Tasty
77
import Test.Tasty.HUnit
88
import Python.Inline
99
import Python.Inline.QQ
10-
import Python.Inline.Types
1110

1211
tests :: TestTree
1312
tests = testGroup "FromPy"

test/TST/Roundtrip.hs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{-# LANGUAGE AllowAmbiguousTypes #-}
2+
{-# LANGUAGE CPP #-}
23
-- |
34
module TST.Roundtrip (tests) where
45

@@ -11,9 +12,19 @@ import Foreign.C.Types
1112

1213
import Test.Tasty
1314
import Test.Tasty.QuickCheck
15+
import Test.QuickCheck.Instances.Vector ()
1416
import Python.Inline
1517
import Python.Inline.QQ
1618

19+
import Data.Vector qualified as V
20+
#if MIN_VERSION_vector(0,13,2)
21+
import Data.Vector.Strict qualified as VV
22+
#endif
23+
import Data.Vector.Storable qualified as VS
24+
import Data.Vector.Primitive qualified as VP
25+
import Data.Vector.Unboxed qualified as VU
26+
27+
1728
tests :: TestTree
1829
tests = testGroup "Roundtrip"
1930
[ testGroup "Roundtrip"
@@ -56,6 +67,13 @@ tests = testGroup "Roundtrip"
5667
, testRoundtrip @(Set Int)
5768
, testRoundtrip @(Map Int Int)
5869
-- , testRoundtrip @String -- Trips on zero byte as it should
70+
, testRoundtrip @(V.Vector Int)
71+
, testRoundtrip @(VS.Vector Int)
72+
, testRoundtrip @(VP.Vector Int)
73+
, testRoundtrip @(VU.Vector Int)
74+
#if MIN_VERSION_vector(0,13,2)
75+
-- , testRoundtrip @(VV.Vector Int)
76+
#endif
5977
]
6078
, testGroup "OutOfRange"
6179
[ testOutOfRange @Int8 @Int16
@@ -66,7 +84,7 @@ tests = testGroup "Roundtrip"
6684
, testOutOfRange @Word32 @Word64
6785
]
6886
]
69-
87+
7088
testRoundtrip
7189
:: forall a. (FromPy a, ToPy a, Eq a, Arbitrary a, Show a, Typeable a) => TestTree
7290
testRoundtrip = testProperty (show (typeOf (undefined :: a))) (propRoundtrip @a)
@@ -98,4 +116,3 @@ propOutOfRange wide = ioProperty $ do
98116
a_hs = case fromIntegral wide :: a of
99117
a' | fromIntegral a' == wide -> Just a'
100118
| otherwise -> Nothing
101-

0 commit comments

Comments
 (0)