#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark import SparkContext
from pyspark.mllib._common import \
    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
    _serialize_double_matrix, _deserialize_double_matrix, \
    _serialize_double_vector, _deserialize_double_vector, \
    _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
    _serialize_tuple, RatingDeserializer
from pyspark.rdd import RDD
class MatrixFactorizationModel(object):

    """A matrix factorisation model trained by regularized alternating
    least-squares.

    >>> r1 = (1, 1, 1.0)
    >>> r2 = (1, 2, 2.0)
    >>> r3 = (2, 1, 2.0)
    >>> ratings = sc.parallelize([r1, r2, r3])
    >>> model = ALS.trainImplicit(ratings, 1)
    >>> model.predict(2,2) is not None
    True
    >>> testset = sc.parallelize([(1, 2), (1, 1)])
    >>> model.predictAll(testset).count() == 2
    True
    """

    def __init__(self, sc, java_model):
        # Keep the SparkContext and the handle to the JVM-side model;
        # all predictions are delegated to the Java object over Py4J.
        self._context = sc
        self._java_model = java_model

    def __del__(self):
        # Detach the Py4J reference so the JVM-side model can be
        # garbage-collected once this Python wrapper goes away.
        self._context._gateway.detach(self._java_model)

    def predict(self, user, product):
        """Return the predicted rating of ``product`` by ``user``.

        Delegates directly to the JVM model's ``predict`` method.
        """
        return self._java_model.predict(user, product)

    def predictAll(self, usersProducts):
        """Predict ratings for an RDD of (user, product) pairs.

        ``usersProducts`` is serialized tuple-by-tuple, handed to the
        JVM model, and the resulting ratings come back as an RDD
        deserialized with :class:`RatingDeserializer`.
        """
        usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
        return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
                   self._context, RatingDeserializer())
59
60
class ALS(object):

    """Alternating Least Squares matrix factorization.

    SPARK-3990: In Spark 1.1.x, we use Kryo serialization by default in
    PySpark. ALS does not work under this default setting. You can switch
    back to the default Java serialization by setting:

    spark.serializer=org.apache.spark.serializer.JavaSerializer

    Please go to http://spark.apache.org/docs/latest/configuration.html
    for instructions on how to configure Spark.
    """

    @classmethod
    def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
        """Train a matrix factorization model from an RDD of explicit
        ratings (user, product, rating) tuples.

        ``blocks=-1`` lets the JVM side auto-configure the number of
        computation blocks.  Returns a :class:`MatrixFactorizationModel`.
        """
        sc = ratings.context
        ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
        mod = sc._jvm.PythonMLLibAPI().trainALSModel(
            ratingBytes._jrdd, rank, iterations, lambda_, blocks)
        return MatrixFactorizationModel(sc, mod)

    @classmethod
    def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
        """Train a matrix factorization model from implicit-feedback
        ratings, with confidence parameter ``alpha``.

        Same contract as :meth:`train`, delegating to the JVM's
        ``trainImplicitALSModel``.  Returns a
        :class:`MatrixFactorizationModel`.
        """
        sc = ratings.context
        ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
        mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(
            ratingBytes._jrdd, rank, iterations, lambda_, blocks, alpha)
        return MatrixFactorizationModel(sc, mod)
90
def _test():
    """Run this module's doctests against a local SparkContext.

    Exits with a non-zero status if any doctest fails, so CI can detect
    the failure.
    """
    import doctest
    globs = globals().copy()
    # The doctests above reference ``sc``; give them a local 4-core context.
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()
104