19 import sys
20 import types
21 import itertools
22 import warnings
23 import decimal
24 import datetime
25 import keyword
27 from array import array
28 from operator import itemgetter
29
30 from pyspark.rdd import RDD, PipelinedRDD
31 from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
32
33 from itertools import chain, ifilter, imap
34
35 from py4j.protocol import Py4JError
36 from py4j.java_collections import ListConverter, MapConverter
37
38
39 __all__ = [
40 "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
41 "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
42 "ShortType", "ArrayType", "MapType", "StructField", "StructType",
43 "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext",
44 "SchemaRDD", "Row"]
48
49 """Spark SQL DataType"""
50
52 return self.__class__.__name__
53
55 return hash(str(self))
56
58 return (isinstance(other, self.__class__) and
59 self.__dict__ == other.__dict__)
60
62 return not self.__eq__(other)
63
66
67 """Metaclass for PrimitiveType"""
68
69 _instances = {}
70
75
78
79 """Spark SQL PrimitiveType"""
80
81 __metaclass__ = PrimitiveTypeSingleton
82
84
85 return self is other
86
89
90 """Spark SQL StringType
91
92 The data type representing string values.
93 """
94
97
98 """Spark SQL BinaryType
99
100 The data type representing bytearray values.
101 """
102
105
106 """Spark SQL BooleanType
107
108 The data type representing bool values.
109 """
110
113
114 """Spark SQL TimestampType
115
116 The data type representing datetime.datetime values.
117 """
118
121
122 """Spark SQL DecimalType
123
124 The data type representing decimal.Decimal values.
125 """
126
129
130 """Spark SQL DoubleType
131
132 The data type representing float values.
133 """
134
137
138 """Spark SQL FloatType
139
140 The data type representing single precision floating-point values.
141 """
142
145
146 """Spark SQL ByteType
147
148 The data type representing int values stored in a single signed byte.
149 """
150
153
154 """Spark SQL IntegerType
155
156 The data type representing int values.
157 """
158
161
162 """Spark SQL LongType
163
164 The data type representing long values. If any value is
165 beyond the range of [-9223372036854775808, 9223372036854775807],
166 please use DecimalType.
167 """
168
171
172 """Spark SQL ShortType
173
174 The data type representing int values with 2 signed bytes.
175 """
176
179
180 """Spark SQL ArrayType
181
182 The data type representing list values. An ArrayType object
183 comprises two fields, elementType (a DataType) and containsNull (a bool).
184 The field of elementType is used to specify the type of array elements.
185 The field of containsNull is used to specify if the array can contain None values.
186
187 """
188
189 def __init__(self, elementType, containsNull=True):
190 """Creates an ArrayType
191
192 :param elementType: the data type of elements.
193 :param containsNull: indicates whether the list contains None values.
194
195 >>> ArrayType(StringType) == ArrayType(StringType, True)
196 True
197 >>> ArrayType(StringType, False) == ArrayType(StringType)
198 False
199 """
200 self.elementType = elementType
201 self.containsNull = containsNull
202
204 return "ArrayType(%s,%s)" % (self.elementType,
205 str(self.containsNull).lower())
206
209
210 """Spark SQL MapType
211
212 The data type representing dict values. A MapType object comprises
213 three fields, keyType (a DataType), valueType (a DataType) and
214 valueContainsNull (a bool).
215
216 The field of keyType is used to specify the type of keys in the map.
217 The field of valueType is used to specify the type of values in the map.
218 The field of valueContainsNull is used to specify if values of this
219 map can contain None values.
220
221 For values of a MapType column, keys are not allowed to have None values.
222
223 """
224
225 def __init__(self, keyType, valueType, valueContainsNull=True):
226 """Creates a MapType
227 :param keyType: the data type of keys.
228 :param valueType: the data type of values.
229 :param valueContainsNull: indicates whether values can contain
230 null values.
231
232 >>> (MapType(StringType, IntegerType)
233 ... == MapType(StringType, IntegerType, True))
234 True
235 >>> (MapType(StringType, IntegerType, False)
236 ... == MapType(StringType, FloatType))
237 False
238 """
239 self.keyType = keyType
240 self.valueType = valueType
241 self.valueContainsNull = valueContainsNull
242
244 return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
245 str(self.valueContainsNull).lower())
246
249
250 """Spark SQL StructField
251
252 Represents a field in a StructType.
253 A StructField object comprises three fields, name (a string),
254 dataType (a DataType) and nullable (a bool). The field of name
255 is the name of a StructField. The field of dataType specifies
256 the data type of a StructField.
257
258 The field of nullable specifies if values of a StructField can
259 contain None values.
260
261 """
262
263 def __init__(self, name, dataType, nullable):
264 """Creates a StructField
265 :param name: the name of this field.
266 :param dataType: the data type of this field.
267 :param nullable: indicates whether values of this field
268 can be null.
269
270 >>> (StructField("f1", StringType, True)
271 ... == StructField("f1", StringType, True))
272 True
273 >>> (StructField("f1", StringType, True)
274 ... == StructField("f2", StringType, True))
275 False
276 """
277 self.name = name
278 self.dataType = dataType
279 self.nullable = nullable
280
282 return "StructField(%s,%s,%s)" % (self.name, self.dataType,
283 str(self.nullable).lower())
284
287
288 """Spark SQL StructType
289
290 The data type representing rows.
291 A StructType object comprises a list of L{StructField}s.
292
293 """
294
296 """Creates a StructType
297
298 >>> struct1 = StructType([StructField("f1", StringType, True)])
299 >>> struct2 = StructType([StructField("f1", StringType, True)])
300 >>> struct1 == struct2
301 True
302 >>> struct1 = StructType([StructField("f1", StringType, True)])
303 >>> struct2 = StructType([StructField("f1", StringType, True),
304 ... [StructField("f2", IntegerType, False)]])
305 >>> struct1 == struct2
306 False
307 """
308 self.fields = fields
309
311 return ("StructType(List(%s))" %
312 ",".join(str(field) for field in self.fields))
313
316 """Parses a list of comma separated data types."""
317 index = 0
318 datatype_list = []
319 start = 0
320 depth = 0
321 while index < len(datatype_list_string):
322 if depth == 0 and datatype_list_string[index] == ",":
323 datatype_string = datatype_list_string[start:index].strip()
324 datatype_list.append(_parse_datatype_string(datatype_string))
325 start = index + 1
326 elif datatype_list_string[index] == "(":
327 depth += 1
328 elif datatype_list_string[index] == ")":
329 depth -= 1
330
331 index += 1
332
333
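# Handle the last data type after the loop (there is no trailing comma).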
334 datatype_string = datatype_list_string[start:index].strip()
335 datatype_list.append(_parse_datatype_string(datatype_string))
336 return datatype_list
337
338
339 _all_primitive_types = dict((k, v) for k, v in globals().iteritems()
340 if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType)
344 """Parses the given data type string.
345
346 >>> def check_datatype(datatype):
347 ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype))
348 ... python_datatype = _parse_datatype_string(
349 ... scala_datatype.toString())
350 ... return datatype == python_datatype
351 >>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
352 True
353 >>> # Simple ArrayType.
354 >>> simple_arraytype = ArrayType(StringType(), True)
355 >>> check_datatype(simple_arraytype)
356 True
357 >>> # Simple MapType.
358 >>> simple_maptype = MapType(StringType(), LongType())
359 >>> check_datatype(simple_maptype)
360 True
361 >>> # Simple StructType.
362 >>> simple_structtype = StructType([
363 ... StructField("a", DecimalType(), False),
364 ... StructField("b", BooleanType(), True),
365 ... StructField("c", LongType(), True),
366 ... StructField("d", BinaryType(), False)])
367 >>> check_datatype(simple_structtype)
368 True
369 >>> # Complex StructType.
370 >>> complex_structtype = StructType([
371 ... StructField("simpleArray", simple_arraytype, True),
372 ... StructField("simpleMap", simple_maptype, True),
373 ... StructField("simpleStruct", simple_structtype, True),
374 ... StructField("boolean", BooleanType(), False)])
375 >>> check_datatype(complex_structtype)
376 True
377 >>> # Complex ArrayType.
378 >>> complex_arraytype = ArrayType(complex_structtype, True)
379 >>> check_datatype(complex_arraytype)
380 True
381 >>> # Complex MapType.
382 >>> complex_maptype = MapType(complex_structtype,
383 ... complex_arraytype, False)
384 >>> check_datatype(complex_maptype)
385 True
386 """
387 index = datatype_string.find("(")
388 if index == -1:
389
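# No parenthesis in the string: it is a primitive type name.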
390 index = len(datatype_string)
391 type_or_field = datatype_string[:index]
392 rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip()
393
394 if type_or_field in _all_primitive_types:
395 return _all_primitive_types[type_or_field]()
396
397 elif type_or_field == "ArrayType":
398 last_comma_index = rest_part.rfind(",")
399 containsNull = True
400 if rest_part[last_comma_index + 1:].strip().lower() == "false":
401 containsNull = False
402 elementType = _parse_datatype_string(
403 rest_part[:last_comma_index].strip())
404 return ArrayType(elementType, containsNull)
405
406 elif type_or_field == "MapType":
407 last_comma_index = rest_part.rfind(",")
408 valueContainsNull = True
409 if rest_part[last_comma_index + 1:].strip().lower() == "false":
410 valueContainsNull = False
411 keyType, valueType = _parse_datatype_list(
412 rest_part[:last_comma_index].strip())
413 return MapType(keyType, valueType, valueContainsNull)
414
415 elif type_or_field == "StructField":
416 first_comma_index = rest_part.find(",")
417 name = rest_part[:first_comma_index].strip()
418 last_comma_index = rest_part.rfind(",")
419 nullable = True
420 if rest_part[last_comma_index + 1:].strip().lower() == "false":
421 nullable = False
422 dataType = _parse_datatype_string(
423 rest_part[first_comma_index + 1:last_comma_index].strip())
424 return StructField(name, dataType, nullable)
425
426 elif type_or_field == "StructType":
427
428
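# rest_part looks like List(StructField(name,type,nullable),...), so strip
# the enclosing List( ) before parsing the individual fields.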
429 field_list_string = rest_part[rest_part.find("(") + 1:-1]
430 fields = _parse_datatype_list(field_list_string)
431 return StructType(fields)
432
433
434
435 _type_mappings = {
436 bool: BooleanType,
437 int: IntegerType,
438 long: LongType,
439 float: DoubleType,
440 str: StringType,
441 unicode: StringType,
442 decimal.Decimal: DecimalType,
443 datetime.datetime: TimestampType,
444 datetime.date: TimestampType,
445 datetime.time: TimestampType,
446 }
450 """Infer the DataType from obj"""
451 if obj is None:
452 raise ValueError("Can not infer type for None")
453
454 dataType = _type_mappings.get(type(obj))
455 if dataType is not None:
456 return dataType()
457
458 if isinstance(obj, dict):
459 if not obj:
460 raise ValueError("Can not infer type for empty dict")
461 key, value = obj.iteritems().next()
462 return MapType(_infer_type(key), _infer_type(value), True)
463 elif isinstance(obj, (list, array)):
464 if not obj:
465 raise ValueError("Can not infer type for empty list/array")
466 return ArrayType(_infer_type(obj[0]), True)
467 else:
468 try:
469 return _infer_schema(obj)
470 except ValueError:
471 raise ValueError("not supported type: %s" % type(obj))
472
475 """Infer the schema from dict/namedtuple/object"""
476 if isinstance(row, dict):
477 items = sorted(row.items())
478
479 elif isinstance(row, tuple):
480 if hasattr(row, "_fields"):
481 items = zip(row._fields, tuple(row))
482 elif hasattr(row, "__FIELDS__"):
483 items = zip(row.__FIELDS__, tuple(row))
484 elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
485 items = row
486 else:
487 raise ValueError("Can't infer schema from tuple")
488
489 elif hasattr(row, "__dict__"):
490 items = sorted(row.__dict__.items())
491
492 else:
493 raise ValueError("Can not infer schema for type: %s" % type(row))
494
495 fields = [StructField(k, _infer_type(v), True) for k, v in items]
496 return StructType(fields)
497
500 """Create an converter to drop the names of fields in obj """
501 if isinstance(dataType, ArrayType):
502 conv = _create_converter(obj[0], dataType.elementType)
503 return lambda row: map(conv, row)
504
505 elif isinstance(dataType, MapType):
506 value = obj.values()[0]
507 conv = _create_converter(value, dataType.valueType)
508 return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
509
510 elif not isinstance(dataType, StructType):
511 return lambda x: x
512
513
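# dataType is a StructType from here on; build a converter over its fields.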
514 names = [f.name for f in dataType.fields]
515
516 if isinstance(obj, dict):
517 conv = lambda o: tuple(o.get(n) for n in names)
518
519 elif isinstance(obj, tuple):
520 if hasattr(obj, "_fields"):
521 conv = tuple
522 elif hasattr(obj, "__FIELDS__"):
523 conv = tuple
524 elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
525 conv = lambda o: tuple(v for k, v in o)
526 else:
527 raise ValueError("unexpected tuple")
528
529 elif hasattr(obj, "__dict__"):
530 conv = lambda o: [o.__dict__.get(n, None) for n in names]
531
532 if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields):
533 return conv
534
535 row = conv(obj)
536 convs = [_create_converter(v, f.dataType)
537 for v, f in zip(row, dataType.fields)]
538
539 def nested_conv(row):
540 return tuple(f(v) for f, v in zip(convs, conv(row)))
541
542 return nested_conv
543
546 """ all the names of fields, becoming tuples"""
547 iterator = iter(rows)
548 row = iterator.next()
549 converter = _create_converter(row, schema)
550 yield converter(row)
551 for i in iterator:
552 yield converter(i)
553
554
555 _BRACKETS = {'(': ')', '[': ']', '{': '}'}
559 """
560 split the schema abstract into fields
561
562 >>> _split_schema_abstract("a b c")
563 ['a', 'b', 'c']
564 >>> _split_schema_abstract("a(a b)")
565 ['a(a b)']
566 >>> _split_schema_abstract("a b[] c{a b}")
567 ['a', 'b[]', 'c{a b}']
568 >>> _split_schema_abstract(" ")
569 []
570 """
571
572 r = []
573 w = ''
574 brackets = []
575 for c in s:
576 if c == ' ' and not brackets:
577 if w:
578 r.append(w)
579 w = ''
580 else:
581 w += c
582 if c in _BRACKETS:
583 brackets.append(c)
584 elif c in _BRACKETS.values():
585 if not brackets or c != _BRACKETS[brackets.pop()]:
586 raise ValueError("unexpected " + c)
587
588 if brackets:
589 raise ValueError("brackets not closed: %s" % brackets)
590 if w:
591 r.append(w)
592 return r
593
596 """
597 Parse a field in schema abstract
598
599 >>> _parse_field_abstract("a")
600 StructField(a,None,true)
601 >>> _parse_field_abstract("b(c d)")
602 StructField(b,StructType(...c,None,true),StructField(d...
603 >>> _parse_field_abstract("a[]")
604 StructField(a,ArrayType(None,true),true)
605 >>> _parse_field_abstract("a{[]}")
606 StructField(a,MapType(None,ArrayType(None,true),true),true)
607 """
608 if set(_BRACKETS.keys()) & set(s):
609 idx = min((s.index(c) for c in _BRACKETS if c in s))
610 name = s[:idx]
611 return StructField(name, _parse_schema_abstract(s[idx:]), True)
612 else:
613 return StructField(s, None, True)
614
617 """
618 parse abstract into schema
619
620 >>> _parse_schema_abstract("a b c")
621 StructType...a...b...c...
622 >>> _parse_schema_abstract("a[b c] b{}")
623 StructType...a,ArrayType...b...c...b,MapType...
624 >>> _parse_schema_abstract("c{} d{a b}")
625 StructType...c,MapType...d,MapType...a...b...
626 >>> _parse_schema_abstract("a b(t)").fields[1]
627 StructField(b,StructType(List(StructField(t,None,true))),true)
628 """
629 s = s.strip()
630 if not s:
631 return
632
633 elif s.startswith('('):
634 return _parse_schema_abstract(s[1:-1])
635
636 elif s.startswith('['):
637 return ArrayType(_parse_schema_abstract(s[1:-1]), True)
638
639 elif s.startswith('{'):
640 return MapType(None, _parse_schema_abstract(s[1:-1]))
641
642 parts = _split_schema_abstract(s)
643 fields = [_parse_field_abstract(p) for p in parts]
644 return StructType(fields)
645
648 """
649 Fill the dataType with types inferred from obj
650
651 >>> schema = _parse_schema_abstract("a b c")
652 >>> row = (1, 1.0, "str")
653 >>> _infer_schema_type(row, schema)
654 StructType...IntegerType...DoubleType...StringType...
655 >>> row = [[1], {"key": (1, 2.0)}]
656 >>> schema = _parse_schema_abstract("a[] b{c d}")
657 >>> _infer_schema_type(row, schema)
658 StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType...
659 """
660 if dataType is None:
661 return _infer_type(obj)
662
663 if not obj:
664 raise ValueError("Can not infer type from empty value")
665
666 if isinstance(dataType, ArrayType):
667 eType = _infer_schema_type(obj[0], dataType.elementType)
668 return ArrayType(eType, True)
669
670 elif isinstance(dataType, MapType):
671 k, v = obj.iteritems().next()
672 return MapType(_infer_type(k),
673 _infer_schema_type(v, dataType.valueType))
674
675 elif isinstance(dataType, StructType):
676 fs = dataType.fields
677 assert len(fs) == len(obj), \
678 "Obj(%s) have different length with fields(%s)" % (obj, fs)
679 fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True)
680 for o, f in zip(obj, fs)]
681 return StructType(fields)
682
683 else:
684 raise ValueError("Unexpected dataType: %s" % dataType)
685
686
687 _acceptable_types = {
688 BooleanType: (bool,),
689 ByteType: (int, long),
690 ShortType: (int, long),
691 IntegerType: (int, long),
692 LongType: (long,),
693 FloatType: (float,),
694 DoubleType: (float,),
695 DecimalType: (decimal.Decimal,),
696 StringType: (str, unicode),
697 TimestampType: (datetime.datetime,),
698 ArrayType: (list, tuple, array),
699 MapType: (dict,),
700 StructType: (tuple, list),
701 }
705 """
706 Verify the type of obj against dataType, raise an exception if
707 they do not match.
708
709 >>> _verify_type(None, StructType([]))
710 >>> _verify_type("", StringType())
711 >>> _verify_type(0, IntegerType())
712 >>> _verify_type(range(3), ArrayType(ShortType()))
713 >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
714 Traceback (most recent call last):
715 ...
716 TypeError:...
717 >>> _verify_type({}, MapType(StringType(), IntegerType()))
718 >>> _verify_type((), StructType([]))
719 >>> _verify_type([], StructType([]))
720 >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL
721 Traceback (most recent call last):
722 ...
723 ValueError:...
724 """
725
726 if obj is None:
727 return
728
729 _type = type(dataType)
730 if _type not in _acceptable_types:
731 return
732
733 if type(obj) not in _acceptable_types[_type]:
734 raise TypeError("%s can not accept object of type %s"
735 % (dataType, type(obj)))
736
737 if isinstance(dataType, ArrayType):
738 for i in obj:
739 _verify_type(i, dataType.elementType)
740
741 elif isinstance(dataType, MapType):
742 for k, v in obj.iteritems():
743 _verify_type(k, dataType.keyType)
744 _verify_type(v, dataType.valueType)
745
746 elif isinstance(dataType, StructType):
747 if len(obj) != len(dataType.fields):
748 raise ValueError("Length of object (%d) does not match with "
749 "length of fields (%d)" % (len(obj), len(dataType.fields)))
750 for v, f in zip(obj, dataType.fields):
751 _verify_type(v, f.dataType)
752
753
754 _cached_cls = {}
758 """ Restore object during unpickling. """
759
760
761
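# Use id(dataType) as the cache key to speed up lookups; because of
# batched pickling, dataType is usually the same object across rows.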
762 k = id(dataType)
763 cls = _cached_cls.get(k)
764 if cls is None:
765
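# Fall back to the dataType itself as the key, to avoid creating duplicate classes.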
766 cls = _cached_cls.get(dataType)
767 if cls is None:
768 cls = _create_cls(dataType)
769 _cached_cls[dataType] = cls
770 _cached_cls[k] = cls
771 return cls(obj)
772
775 """ Create an customized object with class `cls`. """
776 return cls(v) if v is not None else v
777
780 """ Create a getter for item `i` with schema """
781 cls = _create_cls(dt)
782
783 def getter(self):
784 return _create_object(cls, self[i])
785
786 return getter
787
790 """Return whether `dt` is or has StructType in it"""
791 if isinstance(dt, StructType):
792 return True
793 elif isinstance(dt, ArrayType):
794 return _has_struct(dt.elementType)
795 elif isinstance(dt, MapType):
796 return _has_struct(dt.valueType)
797 return False
798
801 """Create properties according to fields"""
802 ps = {}
803 for i, f in enumerate(fields):
804 name = f.name
805 if (name.startswith("__") and name.endswith("__")
806 or keyword.iskeyword(name)):
807 warnings.warn("field name %s can not be accessed in Python, "
808 "use position to access it instead" % name)
809 if _has_struct(f.dataType):
810
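# Delay creating the nested Row object until the field is accessed.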
811 getter = _create_getter(f.dataType, i)
812 else:
813 getter = itemgetter(i)
814 ps[name] = property(getter)
815 return ps
816
819 """
820 Create a class from a dataType
821
822 The created class is similar to namedtuple, but can have a nested schema.
823
824 >>> schema = _parse_schema_abstract("a b c")
825 >>> row = (1, 1.0, "str")
826 >>> schema = _infer_schema_type(row, schema)
827 >>> obj = _create_cls(schema)(row)
828 >>> import pickle
829 >>> pickle.loads(pickle.dumps(obj))
830 Row(a=1, b=1.0, c='str')
831
832 >>> row = [[1], {"key": (1, 2.0)}]
833 >>> schema = _parse_schema_abstract("a[] b{c d}")
834 >>> schema = _infer_schema_type(row, schema)
835 >>> obj = _create_cls(schema)(row)
836 >>> pickle.loads(pickle.dumps(obj))
837 Row(a=[1], b={'key': Row(c=1, d=2.0)})
838 """
839
840 if isinstance(dataType, ArrayType):
841 cls = _create_cls(dataType.elementType)
842
843 class List(list):
844
845 def __getitem__(self, i):
846
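# Wrap the raw element in its generated class lazily, on access.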
847 return _create_object(cls, list.__getitem__(self, i))
848
849 def __repr__(self):
850
851 return "[%s]" % (", ".join(repr(self[i])
852 for i in range(len(self))))
853
854 def __reduce__(self):
855 return list.__reduce__(self)
856
857 return List
858
859 elif isinstance(dataType, MapType):
860 vcls = _create_cls(dataType.valueType)
861
862 class Dict(dict):
863
864 def __getitem__(self, k):
865
866 return _create_object(vcls, dict.__getitem__(self, k))
867
868 def __repr__(self):
869
870 return "{%s}" % (", ".join("%r: %r" % (k, self[k])
871 for k in self))
872
873 def __reduce__(self):
874 return dict.__reduce__(self)
875
876 return Dict
877
878 elif not isinstance(dataType, StructType):
879 raise Exception("unexpected data type: %s" % dataType)
880
881 class Row(tuple):
882
883 """ Row in SchemaRDD """
884 __DATATYPE__ = dataType
885 __FIELDS__ = tuple(f.name for f in dataType.fields)
886 __slots__ = ()
887
888
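# Add a property for each field so it can be read as an attribute.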
889 locals().update(_create_properties(dataType.fields))
890
891 def __repr__(self):
892
893 return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
894 for n in self.__FIELDS__))
895
896 def __reduce__(self):
897 return (_restore_object, (self.__DATATYPE__, tuple(self)))
898
899 return Row
900
901
902 class SQLContext:
903
904 """Main entry point for SparkSQL functionality.
905
906 A SQLContext can be used to create L{SchemaRDD}s, register L{SchemaRDD}s as
907 tables, execute SQL over tables, cache tables, and read parquet files.
908 """
909
910 def __init__(self, sparkContext, sqlContext=None):
911 """Create a new SQLContext.
912
913 @param sparkContext: The SparkContext to wrap.
914 @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
915 SQLContext in the JVM; instead we make all calls to this object.
916
917 >>> srdd = sqlCtx.inferSchema(rdd)
918 >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
919 Traceback (most recent call last):
920 ...
921 TypeError:...
922
923 >>> bad_rdd = sc.parallelize([1,2,3])
924 >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
925 Traceback (most recent call last):
926 ...
927 ValueError:...
928
929 >>> from datetime import datetime
930 >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
931 ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
932 ... time=datetime(2014, 8, 1, 14, 1, 5))])
933 >>> srdd = sqlCtx.inferSchema(allTypes)
934 >>> srdd.registerTempTable("allTypes")
935 >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
936 ... 'from allTypes where b and i > 0').collect()
937 [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
938 >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
939 ... x.row.a, x.list)).collect()
940 [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
941 """
942 self._sc = sparkContext
943 self._jsc = self._sc._jsc
944 self._jvm = self._sc._jvm
945 self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray
946
947 if sqlContext:
948 self._scala_SQLContext = sqlContext
949
950 @property
951 def _ssql_ctx(self):
952 """Accessor for the JVM SparkSQL context.
953
954 Subclasses can override this property to provide their own
955 JVM Contexts.
956 """
957 if not hasattr(self, '_scala_SQLContext'):
958 self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
959 return self._scala_SQLContext
960
961 def registerFunction(self, name, f, returnType=StringType()):
962 """Registers a lambda function as a UDF so it can be used in SQL statements.
963
964 In addition to a name and the function itself, the return type can be optionally specified.
965 When the return type is not given, it defaults to a string and conversion will automatically
966 be done. For any other return type, the produced object must match the specified type.
967
968 >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
969 >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
970 [Row(c0=u'4')]
971 >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
972 >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
973 [Row(c0=4)]
974 >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
975 >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
976 [Row(c0=5)]
977 """
978 func = lambda _, it: imap(lambda x: f(*x), it)
979 command = (func,
980 BatchedSerializer(PickleSerializer(), 1024),
981 BatchedSerializer(PickleSerializer(), 1024))
982 env = MapConverter().convert(self._sc.environment,
983 self._sc._gateway._gateway_client)
984 includes = ListConverter().convert(self._sc._python_includes,
985 self._sc._gateway._gateway_client)
986 self._ssql_ctx.registerPython(name,
987 bytearray(CloudPickleSerializer().dumps(command)),
988 env,
989 includes,
990 self._sc.pythonExec,
991 self._sc._javaAccumulator,
992 str(returnType))
993
994 def inferSchema(self, rdd):
995 """Infer and apply a schema to an RDD of L{Row}s.
996
997 We peek at the first row of the RDD to determine the fields' names
998 and types. Nested collections are supported, including array,
999 dict, list, Row, tuple, namedtuple, and object.
1000
1001 All the rows in `rdd` should have the same type as the first one,
1002 or it will cause runtime exceptions.
1003
1004 Each row can be a L{pyspark.sql.Row} object, a namedtuple, or an object;
1005 using a dict is deprecated.
1006
1007 >>> rdd = sc.parallelize(
1008 ... [Row(field1=1, field2="row1"),
1009 ... Row(field1=2, field2="row2"),
1010 ... Row(field1=3, field2="row3")])
1011 >>> srdd = sqlCtx.inferSchema(rdd)
1012 >>> srdd.collect()[0]
1013 Row(field1=1, field2=u'row1')
1014
1015 >>> NestedRow = Row("f1", "f2")
1016 >>> nestedRdd1 = sc.parallelize([
1017 ... NestedRow(array('i', [1, 2]), {"row1": 1.0}),
1018 ... NestedRow(array('i', [2, 3]), {"row2": 2.0})])
1019 >>> srdd = sqlCtx.inferSchema(nestedRdd1)
1020 >>> srdd.collect()
1021 [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
1022
1023 >>> nestedRdd2 = sc.parallelize([
1024 ... NestedRow([[1, 2], [2, 3]], [1, 2]),
1025 ... NestedRow([[2, 3], [3, 4]], [2, 3])])
1026 >>> srdd = sqlCtx.inferSchema(nestedRdd2)
1027 >>> srdd.collect()
1028 [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
1029 """
1030
1031 if isinstance(rdd, SchemaRDD):
1032 raise TypeError("Cannot apply schema to SchemaRDD")
1033
1034 first = rdd.first()
1035 if not first:
1036 raise ValueError("The first row in RDD is empty, "
1037 "can not infer schema")
1038 if type(first) is dict:
1039 warnings.warn("Using RDD of dict to inferSchema is deprecated, "
1040 "please use pyspark.Row instead")
1041
1042 schema = _infer_schema(first)
1043 rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))
1044 return self.applySchema(rdd, schema)
1045
1046 def applySchema(self, rdd, schema):
1047 """
1048 Applies the given schema to the given RDD of L{tuple}s or L{list}s.
1049
1050 These tuples or lists can contain complex nested structures like
1051 lists, maps or nested rows.
1052
1053 The schema should be a StructType.
1054
1055 It is important that the schema matches the types of the objects
1056 in each row or exceptions could be thrown at runtime.
1057
1058 >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
1059 >>> schema = StructType([StructField("field1", IntegerType(), False),
1060 ... StructField("field2", StringType(), False)])
1061 >>> srdd = sqlCtx.applySchema(rdd2, schema)
1062 >>> sqlCtx.registerRDDAsTable(srdd, "table1")
1063 >>> srdd2 = sqlCtx.sql("SELECT * from table1")
1064 >>> srdd2.collect()
1065 [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
1066
1067 >>> from datetime import datetime
1068 >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
1069 ... datetime(2010, 1, 1, 1, 1, 1),
1070 ... {"a": 1}, (2,), [1, 2, 3], None)])
1071 >>> schema = StructType([
1072 ... StructField("byte1", ByteType(), False),
1073 ... StructField("byte2", ByteType(), False),
1074 ... StructField("short1", ShortType(), False),
1075 ... StructField("short2", ShortType(), False),
1076 ... StructField("int", IntegerType(), False),
1077 ... StructField("float", FloatType(), False),
1078 ... StructField("time", TimestampType(), False),
1079 ... StructField("map",
1080 ... MapType(StringType(), IntegerType(), False), False),
1081 ... StructField("struct",
1082 ... StructType([StructField("b", ShortType(), False)]), False),
1083 ... StructField("list", ArrayType(ByteType(), False), False),
1084 ... StructField("null", DoubleType(), True)])
1085 >>> srdd = sqlCtx.applySchema(rdd, schema)
1086 >>> results = srdd.map(
1087 ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time,
1088 ... x.map["a"], x.struct.b, x.list, x.null))
1089 >>> results.collect()[0]
1090 (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
1091
1092 >>> srdd.registerTempTable("table2")
1093 >>> sqlCtx.sql(
1094 ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
1095 ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
1096 ... "float + 1.5 as float FROM table2").collect()
1097 [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)]
1098
1099 >>> rdd = sc.parallelize([(127, -32768, 1.0,
1100 ... datetime(2010, 1, 1, 1, 1, 1),
1101 ... {"a": 1}, (2,), [1, 2, 3])])
1102 >>> abstract = "byte short float time map{} struct(b) list[]"
1103 >>> schema = _parse_schema_abstract(abstract)
1104 >>> typedSchema = _infer_schema_type(rdd.first(), schema)
1105 >>> srdd = sqlCtx.applySchema(rdd, typedSchema)
1106 >>> srdd.collect()
1107 [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
1108 """
1109
1110 if isinstance(rdd, SchemaRDD):
1111 raise TypeError("Cannot apply schema to SchemaRDD")
1112
1113 if not isinstance(schema, StructType):
1114 raise TypeError("schema should be StructType")
1115
1116
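# Take the first few rows to verify the schema before applying it to the whole RDD.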
1117 rows = rdd.take(10)
1118 for row in rows:
1119 _verify_type(row, schema)
1120
1121 batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
1122 jrdd = self._pythonToJava(rdd._jrdd, batched)
1123 srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
1124 return SchemaRDD(srdd.toJavaSchemaRDD(), self)
1125
1126 def registerRDDAsTable(self, rdd, tableName):
1127 """Registers the given RDD as a temporary table in the catalog.
1128
1129 Temporary tables exist only during the lifetime of this instance of
1130 SQLContext.
1131
1132 >>> srdd = sqlCtx.inferSchema(rdd)
1133 >>> sqlCtx.registerRDDAsTable(srdd, "table1")
1134 """
1135 if (rdd.__class__ is SchemaRDD):
1136 srdd = rdd._jschema_rdd.baseSchemaRDD()
1137 self._ssql_ctx.registerRDDAsTable(srdd, tableName)
1138 else:
1139 raise ValueError("Can only register SchemaRDD as table")
1140
1141 def parquetFile(self, path):
1142 """Loads a Parquet file, returning the result as a L{SchemaRDD}.
1143
1144 >>> import tempfile, shutil
1145 >>> parquetFile = tempfile.mkdtemp()
1146 >>> shutil.rmtree(parquetFile)
1147 >>> srdd = sqlCtx.inferSchema(rdd)
1148 >>> srdd.saveAsParquetFile(parquetFile)
1149 >>> srdd2 = sqlCtx.parquetFile(parquetFile)
1150 >>> sorted(srdd.collect()) == sorted(srdd2.collect())
1151 True
1152 """
1153 jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
1154 return SchemaRDD(jschema_rdd, self)
1155
1156 def jsonFile(self, path, schema=None):
1157 """
1158 Loads a text file storing one JSON object per line as a
1159 L{SchemaRDD}.
1160
1161 If the schema is provided, applies the given schema to this
1162 JSON dataset.
1163
1164 Otherwise, it goes through the entire dataset once to determine
1165 the schema.
1166
1167 >>> import tempfile, shutil
1168 >>> jsonFile = tempfile.mkdtemp()
1169 >>> shutil.rmtree(jsonFile)
1170 >>> ofn = open(jsonFile, 'w')
1171 >>> for json in jsonStrings:
1172 ... print>>ofn, json
1173 >>> ofn.close()
1174 >>> srdd1 = sqlCtx.jsonFile(jsonFile)
1175 >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
1176 >>> srdd2 = sqlCtx.sql(
1177 ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
1178 ... "field6 as f4 from table1")
1179 >>> for r in srdd2.collect():
1180 ... print r
1181 Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
1182 Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
1183 Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
1184 >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema())
1185 >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
1186 >>> srdd4 = sqlCtx.sql(
1187 ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
1188 ... "field6 as f4 from table2")
1189 >>> for r in srdd4.collect():
1190 ... print r
1191 Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
1192 Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
1193 Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
1194 >>> schema = StructType([
1195 ... StructField("field2", StringType(), True),
1196 ... StructField("field3",
1197 ... StructType([
1198 ... StructField("field5",
1199 ... ArrayType(IntegerType(), False), True)]), False)])
1200 >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema)
1201 >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
1202 >>> srdd6 = sqlCtx.sql(
1203 ... "SELECT field2 AS f1, field3.field5 as f2, "
1204 ... "field3.field5[0] as f3 from table3")
1205 >>> srdd6.collect()
1206 [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
1207 """
1208 if schema is None:
1209 srdd = self._ssql_ctx.jsonFile(path)
1210 else:
1211 scala_datatype = self._ssql_ctx.parseDataType(str(schema))
1212 srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
1213 return SchemaRDD(srdd.toJavaSchemaRDD(), self)
1214
1215 def jsonRDD(self, rdd, schema=None):
1216 """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
1217
1218 If the schema is provided, applies the given schema to this
1219 JSON dataset.
1220
1221 Otherwise, it goes through the entire dataset once to determine
1222 the schema.
1223
1224 >>> srdd1 = sqlCtx.jsonRDD(json)
1225 >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
1226 >>> srdd2 = sqlCtx.sql(
1227 ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
1228 ... "field6 as f4 from table1")
1229 >>> for r in srdd2.collect():
1230 ... print r
1231 Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
1232 Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
1233 Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
1234 >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema())
1235 >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
1236 >>> srdd4 = sqlCtx.sql(
1237 ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
1238 ... "field6 as f4 from table2")
1239 >>> for r in srdd4.collect():
1240 ... print r
1241 Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
1242 Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
1243 Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
1244 >>> schema = StructType([
1245 ... StructField("field2", StringType(), True),
1246 ... StructField("field3",
1247 ... StructType([
1248 ... StructField("field5",
1249 ... ArrayType(IntegerType(), False), True)]), False)])
1250 >>> srdd5 = sqlCtx.jsonRDD(json, schema)
1251 >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
1252 >>> srdd6 = sqlCtx.sql(
1253 ... "SELECT field2 AS f1, field3.field5 as f2, "
1254 ... "field3.field5[0] as f3 from table3")
1255 >>> srdd6.collect()
1256 [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
1257
1258 >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
1259 ... '{"key0": {"key1": "value1"}}'])).collect()
1260 [Row(key0=None), Row(key0=Row(key1=u'value1'))]
1261 >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
1262 ... '{"key0": {"key1": "value1"}}'])).collect()
1263 [Row(key0=None), Row(key0=Row(key1=u'value1'))]
1264 """
1265
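# Ensure each element is a UTF-8 encoded string before it is handed to the JVM JSON parser.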
1266 def func(iterator):
1267 for x in iterator:
1268 if not isinstance(x, basestring):
1269 x = unicode(x)
1270 if isinstance(x, unicode):
1271 x = x.encode("utf-8")
1272 yield x
1273 keyed = rdd.mapPartitions(func)
1274 keyed._bypass_serializer = True
1275 jrdd = keyed._jrdd.map(self._jvm.BytesToString())
1276 if schema is None:
1277 srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
1278 else:
1279 scala_datatype = self._ssql_ctx.parseDataType(str(schema))
1280 srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
1281 return SchemaRDD(srdd.toJavaSchemaRDD(), self)
1282
1283 def sql(self, sqlQuery):
1284 """Return a L{SchemaRDD} representing the result of the given query.
1285
1286 >>> srdd = sqlCtx.inferSchema(rdd)
1287 >>> sqlCtx.registerRDDAsTable(srdd, "table1")
1288 >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
1289 >>> srdd2.collect()
1290 [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
1291 """
1292 return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self)
1293
1294 def table(self, tableName):
1295 """Returns the specified table as a L{SchemaRDD}.
1296
1297 >>> srdd = sqlCtx.inferSchema(rdd)
1298 >>> sqlCtx.registerRDDAsTable(srdd, "table1")
1299 >>> srdd2 = sqlCtx.table("table1")
1300 >>> sorted(srdd.collect()) == sorted(srdd2.collect())
1301 True
1302 """
1303 return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self)
1304
1305 def cacheTable(self, tableName):
1306 """Caches the specified table in-memory."""
1307 self._ssql_ctx.cacheTable(tableName)
1308
1309 def uncacheTable(self, tableName):
1310 """Removes the specified table from the in-memory cache."""
1311 self._ssql_ctx.uncacheTable(tableName)
1312
1313
1314 class HiveContext(SQLContext):
1315
1316 """A variant of Spark SQL that integrates with data stored in Hive.
1317
1318 Configuration for Hive is read from hive-site.xml on the classpath.
1319 It supports running both SQL and HiveQL commands.
1320 """
1321
1322 def __init__(self, sparkContext, hiveContext=None):
1323 """Create a new HiveContext.
1324
1325 @param sparkContext: The SparkContext to wrap.
1326 @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
1327 HiveContext in the JVM; instead we make all calls to this object.
1328 """
1329 SQLContext.__init__(self, sparkContext)
1330
1331 if hiveContext:
1332 self._scala_HiveContext = hiveContext
1333
1334 @property
1335 def _ssql_ctx(self):
1336 try:
1337 if not hasattr(self, '_scala_HiveContext'):
1338 self._scala_HiveContext = self._get_hive_ctx()
1339 return self._scala_HiveContext
1340 except Py4JError as e:
1341 raise Exception("You must build Spark with Hive. "
1342 "Export 'SPARK_HIVE=true' and run "
1343 "sbt/sbt assembly", e)
1344
1345 def _get_hive_ctx(self):
1346 return self._jvm.HiveContext(self._jsc.sc())
1347
1348 def hiveql(self, hqlQuery):
1349 """
1350 DEPRECATED: Use sql()
1351 """
1352 warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by " +
1353 "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
1354 DeprecationWarning)
1355 return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)
1356
1357 def hql(self, hqlQuery):
1358 """
1359 DEPRECATED: Use sql()
1360 """
1361 warnings.warn("hql() is deprecated as the sql function now parses using HiveQL by " +
1362 "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
1363 DeprecationWarning)
1364 return self.hiveql(hqlQuery)
1365
1366
1367 class LocalHiveContext(HiveContext):
1368
1369 """Starts up an instance of hive where metadata is stored locally.
1370
1371 An in-process metadata data is created with data stored in ./metadata.
1372 Warehouse data is stored in in ./warehouse.
1373
1374 >>> import os
1375 >>> hiveCtx = LocalHiveContext(sc)
1376 >>> try:
1377 ... supress = hiveCtx.sql("DROP TABLE src")
1378 ... except Exception:
1379 ... pass
1380 >>> kv1 = os.path.join(os.environ["SPARK_HOME"],
1381 ... 'examples/src/main/resources/kv1.txt')
1382 >>> supress = hiveCtx.sql(
1383 ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
1384 >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src"
1385 ... % kv1)
1386 >>> results = hiveCtx.sql("FROM src SELECT value"
1387 ... ).map(lambda r: int(r.value.split('_')[1]))
1388 >>> num = results.count()
1389 >>> reduce_sum = results.reduce(lambda x, y: x + y)
1390 >>> num
1391 500
1392 >>> reduce_sum
1393 130091
1394 """
1395
1396 def __init__(self, sparkContext, sqlContext=None):
1397 HiveContext.__init__(self, sparkContext, sqlContext)
1398 warnings.warn("LocalHiveContext is deprecated. "
1399 "Use HiveContext instead.", DeprecationWarning)
1400
1401 def _get_hive_ctx(self):
1402 return self._jvm.LocalHiveContext(self._jsc.sc())
1403
1404
1405 class TestHiveContext(HiveContext):
1406
1407 def _get_hive_ctx(self):
1408 return self._jvm.TestHiveContext(self._jsc.sc())
1409
1412 row = Row(*values)
1413 row.__FIELDS__ = fields
1414 return row
1415
1416
1417 class Row(tuple):
1418
1419 """
1420 A row in L{SchemaRDD}. The fields in it can be accessed like attributes.
1421
1422 Row can be used to create a row object by using named arguments;
1423 the fields will be sorted by name.
1424
1425 >>> row = Row(name="Alice", age=11)
1426 >>> row
1427 Row(age=11, name='Alice')
1428 >>> row.name, row.age
1429 ('Alice', 11)
1430
1431 Row can also be used to create another Row-like class, which can then
1432 be used to create Row objects, such as
1433
1434 >>> Person = Row("name", "age")
1435 >>> Person
1436 <Row(name, age)>
1437 >>> Person("Alice", 11)
1438 Row(name='Alice', age=11)
1439 """
1440
1441 def __new__(self, *args, **kwargs):
1442 if args and kwargs:
1443 raise ValueError("Can not use both args "
1444 "and kwargs to create Row")
1445 if args:
1446
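# Positional arguments: create a plain row of values (or a row class from field names).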
1447 return tuple.__new__(self, args)
1448
1449 elif kwargs:
1450
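# Keyword arguments: sort the field names and remember them on the row.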
1451 names = sorted(kwargs.keys())
1452 values = tuple(kwargs[n] for n in names)
1453 row = tuple.__new__(self, values)
1454 row.__FIELDS__ = names
1455 return row
1456
1457 else:
1458 raise ValueError("No args or kwargs")
1459
1460
1462 """create new Row object"""
1463 return _create_row(self, args)
1464
1466 if item.startswith("__"):
1467 raise AttributeError(item)
1468 try:
1469
1470
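# Linear search over __FIELDS__; slow for very wide rows, but fine for
# the common case of a handful of fields.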
1471 idx = self.__FIELDS__.index(item)
1472 return self[idx]
1473 except (IndexError, ValueError):
1474 raise AttributeError(item)
1475
1477 if hasattr(self, "__FIELDS__"):
1478 return (_create_row, (self.__FIELDS__, tuple(self)))
1479 else:
1480 return tuple.__reduce__(self)
1481
1483 if hasattr(self, "__FIELDS__"):
1484 return "Row(%s)" % ", ".join("%s=%r" % (k, v)
1485 for k, v in zip(self.__FIELDS__, self))
1486 else:
1487 return "<Row(%s)>" % ", ".join(self)
1488
1491
1492 """An RDD of L{Row} objects that has an associated schema.
1493
1494 The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
1495 utilize the relational query api exposed by SparkSQL.
1496
1497 For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
1498 L{SchemaRDD} is not operated on directly, as its underlying
1499 implementation is an RDD composed of Java objects. Instead it is
1500 converted to a PythonRDD in the JVM, on which Python operations can
1501 be done.
1502
1503 This class receives raw tuples from Java but assigns a class to them in
1504 all its data-collection methods (mapPartitionsWithIndex, collect, take,
1505 etc) so that PySpark sees them as Row objects with named fields.
1506 """
1507
1508 def __init__(self, jschema_rdd, sql_ctx):
1509 self.sql_ctx = sql_ctx
1510 self._sc = sql_ctx._sc
1511 clsName = jschema_rdd.getClass().getName()
1512 assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD"
1513 self._jschema_rdd = jschema_rdd
1514
1515 self.is_cached = False
1516 self.is_checkpointed = False
1517 self.ctx = self.sql_ctx._sc
1518
1519 self._jrdd_deserializer = BatchedSerializer(PickleSerializer())
1520
1521 @property
1523 """Lazy evaluation of PythonRDD object.
1524
1525 Only done when a user calls methods defined by the
1526 L{pyspark.rdd.RDD} super class (map, filter, etc.).
1527 """
1528 if not hasattr(self, '_lazy_jrdd'):
1529 self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
1530 return self._lazy_jrdd
1531
1532 @property
1534 return self._jrdd.id()
1535
1537 """Save the contents as a Parquet file, preserving the schema.
1538
1539 Files that are written out using this method can be read back in as
1540 a SchemaRDD using the L{SQLContext.parquetFile} method.
1541
1542 >>> import tempfile, shutil
1543 >>> parquetFile = tempfile.mkdtemp()
1544 >>> shutil.rmtree(parquetFile)
1545 >>> srdd = sqlCtx.inferSchema(rdd)
1546 >>> srdd.saveAsParquetFile(parquetFile)
1547 >>> srdd2 = sqlCtx.parquetFile(parquetFile)
1548 >>> sorted(srdd2.collect()) == sorted(srdd.collect())
1549 True
1550 """
1551 self._jschema_rdd.saveAsParquetFile(path)
1552
1554 """Registers this RDD as a temporary table using the given name.
1555
1556 The lifetime of this temporary table is tied to the L{SQLContext}
1557 that was used to create this SchemaRDD.
1558
1559 >>> srdd = sqlCtx.inferSchema(rdd)
1560 >>> srdd.registerTempTable("test")
1561 >>> srdd2 = sqlCtx.sql("select * from test")
1562 >>> sorted(srdd.collect()) == sorted(srdd2.collect())
1563 True
1564 """
1565 self._jschema_rdd.registerTempTable(name)
1566
1568 warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning)
1569 self.registerTempTable(name)
1570
1571 def insertInto(self, tableName, overwrite=False):
1572 """Inserts the contents of this SchemaRDD into the specified table.
1573
1574 Optionally overwrites any existing data.
1575 """
1576 self._jschema_rdd.insertInto(tableName, overwrite)
1577
1579 """Creates a new table with the contents of this SchemaRDD."""
1580 self._jschema_rdd.saveAsTable(tableName)
1581
1583 """Returns the schema of this SchemaRDD (represented by
1584 a L{StructType})."""
1585 return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
1586
1588 """Returns the output schema in the tree format."""
1589 return self._jschema_rdd.schemaString()
1590
1592 """Prints out the schema in the tree format."""
1593 print self.schemaString()
1594
1596 """Return the number of elements in this RDD.
1597
1598 Unlike the base RDD implementation of count, this implementation
1599 leverages the query optimizer to compute the count on the SchemaRDD,
1600 which supports features such as filter pushdown.
1601
1602 >>> srdd = sqlCtx.inferSchema(rdd)
1603 >>> srdd.count()
1604 3L
1605 >>> srdd.count() == srdd.map(lambda x: x).count()
1606 True
1607 """
1608 return self._jschema_rdd.count()
1609
1611 """
1612 Return a list that contains all of the rows in this RDD.
1613
1614 Each object in the list is a Row; the fields can be accessed as
1615 attributes.
1616 """
1617 rows = RDD.collect(self)
1618 cls = _create_cls(self.schema())
1619 return map(cls, rows)
1620
1621
1622
1624 """
1625 Return a new RDD by applying a function to each partition of this RDD,
1626 while tracking the index of the original partition.
1627
1628 >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
1629 >>> def f(splitIndex, iterator): yield splitIndex
1630 >>> rdd.mapPartitionsWithIndex(f).sum()
1631 6
1632 """
1633 rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)
1634
1635 schema = self.schema()
1636
1637 def applySchema(_, it):
1638 cls = _create_cls(schema)
1639 return itertools.imap(cls, it)
1640
1641 objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning)
1642 return objrdd.mapPartitionsWithIndex(f, preservesPartitioning)
1643
1644
1645
1646
1648 self.is_cached = True
1649 self._jschema_rdd.cache()
1650 return self
1651
1653 self.is_cached = True
1654 javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
1655 self._jschema_rdd.persist(javaStorageLevel)
1656 return self
1657
1659 self.is_cached = False
1660 self._jschema_rdd.unpersist(blocking)
1661 return self
1662
1664 self.is_checkpointed = True
1665 self._jschema_rdd.checkpoint()
1666
1669
1671 checkpointFile = self._jschema_rdd.getCheckpointFile()
1672 if checkpointFile.isPresent():
1673 return checkpointFile.get()
1674
1675 def coalesce(self, numPartitions, shuffle=False):
1678
1682
1684 if (other.__class__ is SchemaRDD):
1685 rdd = self._jschema_rdd.intersection(other._jschema_rdd)
1686 return SchemaRDD(rdd, self.sql_ctx)
1687 else:
1688 raise ValueError("Can only intersect with another SchemaRDD")
1689
1693
1694 def subtract(self, other, numPartitions=None):
1695 if (other.__class__ is SchemaRDD):
1696 if numPartitions is None:
1697 rdd = self._jschema_rdd.subtract(other._jschema_rdd)
1698 else:
1699 rdd = self._jschema_rdd.subtract(other._jschema_rdd,
1700 numPartitions)
1701 return SchemaRDD(rdd, self.sql_ctx)
1702 else:
1703 raise ValueError("Can only subtract another SchemaRDD")
1704
1707 import doctest
1708 from array import array
1709 from pyspark.context import SparkContext
1710
1711 import pyspark.sql
1712 from pyspark.sql import Row, SQLContext
1713 globs = pyspark.sql.__dict__.copy()
1714
1715
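# The small batch size here ensures that we see multiple batches,
# even in these small test examples: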
1716 sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
1717 globs['sc'] = sc
1718 globs['sqlCtx'] = SQLContext(sc)
1719 globs['rdd'] = sc.parallelize(
1720 [Row(field1=1, field2="row1"),
1721 Row(field1=2, field2="row2"),
1722 Row(field1=3, field2="row3")]
1723 )
1724 jsonStrings = [
1725 '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
1726 '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
1727 '"field6":[{"field7": "row2"}]}',
1728 '{"field1" : null, "field2": "row3", '
1729 '"field3":{"field4":33, "field5": []}}'
1730 ]
1731 globs['jsonStrings'] = jsonStrings
1732 globs['json'] = sc.parallelize(jsonStrings)
1733 (failure_count, test_count) = doctest.testmod(
1734 pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS)
1735 globs['sc'].stop()
1736 if failure_count:
1737 exit(-1)
1738
1739
1740 if __name__ == "__main__":
1741 _test()
1742