package com.sparkTutorial.sparkSql;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

import static org.apache.spark.sql.functions.avg;
import static org.apache.spark.sql.functions.max;

public class TypedDataset {
    private static final String AGE_MIDPOINT = "ageMidPoint";
    private static final String SALARY_MIDPOINT = "salaryMidPoint";
    private static final String SALARY_MIDPOINT_BUCKET = "salaryMidPointBucket";
    private static final float NULL_VALUE = -1.0f;
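    // Matches a comma only when it is followed by an even number of double
    // quotes, i.e. a comma outside quoted CSV fields, so a quoted value such
    // as "New York, NY" is not split apart.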
    private static final String COMMA_DELIMITER = ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)";

    public static void main(String[] args) throws Exception {

        Logger.getLogger("org").setLevel(Level.ERROR);

        SparkSession session = SparkSession.builder()
                .appName("StackOverFlowSurvey")
                .master("local[1]")
                .getOrCreate();
        JavaSparkContext sc = new JavaSparkContext(session.sparkContext());

        JavaRDD<String> lines = sc.textFile("in/2016-stack-overflow-survey-responses.csv");

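        // The raw file still includes its header row, recognizable by the
        // literal column name "country" in field 2; drop it before parsing.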
        JavaRDD<Response> responseRDD = lines
                .filter(line -> !line.split(COMMA_DELIMITER, -1)[2].equals("country"))
                .map(line -> {
                    String[] splits = line.split(COMMA_DELIMITER, -1);
                    return new Response(splits[2], convertStringToFloat(splits[6]), splits[9], convertStringToFloat(splits[14]));
                });
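
        // Response is a separate JavaBean in this package (not shown in this
        // commit). Encoders.bean needs a public class with a no-arg constructor
        // and getters/setters for each column; a minimal sketch, with fields
        // assumed from the getters used in this file:
        //
        //   public class Response {
        //       private String country;
        //       private float ageMidPoint;
        //       private String occupation;
        //       private float salaryMidPoint;
        //       // plus a no-arg constructor, the 4-arg constructor used above,
        //       // and getters/setters for every field
        //   }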
        Dataset<Response> responseDataset = session.createDataset(responseRDD.rdd(), Encoders.bean(Response.class));

        System.out.println("=== Print out schema ===");
        responseDataset.printSchema();

        System.out.println("=== Print 20 records of responses table ===");
        responseDataset.show(20);

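        // filter below takes a typed FilterFunction over Response objects; the
        // untyped equivalent would be filter(functions.col("country").equalTo("Afghanistan")).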
        System.out.println("=== Print records where the response is from Afghanistan ===");
        responseDataset.filter((FilterFunction<Response>) response -> response.getCountry().equals("Afghanistan")).show();

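        // Roughly the SQL: SELECT occupation, COUNT(*) FROM responses GROUP BY occupation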
        System.out.println("=== Print the count of occupations ===");
        responseDataset.groupBy(responseDataset.col("occupation")).count().show();

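        // Rows whose age field was empty carry the NULL_VALUE sentinel and must
        // be excluded before the numeric comparison.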
        System.out.println("=== Print records with age midpoint less than 20 ===");
        responseDataset.filter((FilterFunction<Response>) response -> response.getAgeMidPoint() != NULL_VALUE && response.getAgeMidPoint() < 20).show();

        System.out.println("=== Print the result with salary middle point in descending order ===");
        responseDataset.orderBy(responseDataset.col(SALARY_MIDPOINT).desc()).show();

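        // Roughly the SQL, ignoring rows with a missing salary:
        //   SELECT country, AVG(salaryMidPoint), MAX(ageMidPoint) FROM responses GROUP BY country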
        System.out.println("=== Group by country and aggregate by average salary middle point and max age middle point ===");
        responseDataset
                .filter((FilterFunction<Response>) response -> response.getSalaryMidPoint() != NULL_VALUE)
                .groupBy("country")
                .agg(avg(SALARY_MIDPOINT), max(AGE_MIDPOINT))
                .show();

        System.out.println("=== Group by salary bucket ===");

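        // Each salary midpoint is snapped to the nearest 20,000 bucket:
        // e.g. 47,500 -> Math.round(47500f / 20000) = 2 -> bucket 40,000.
        // map produces a single column named "value", renamed below before grouping.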
        responseDataset
                .map((MapFunction<Response, Integer>) response -> Math.round(response.getSalaryMidPoint() / 20000) * 20000, Encoders.INT())
                .withColumnRenamed("value", SALARY_MIDPOINT_BUCKET)
                .groupBy(SALARY_MIDPOINT_BUCKET)
                .count()
                .orderBy(SALARY_MIDPOINT_BUCKET)
                .show();
    }

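    // Empty CSV fields map to the NULL_VALUE sentinel, which lets the bean keep
    // primitive float fields instead of nullable wrappers.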
    private static float convertStringToFloat(String split) {
        return split.isEmpty() ? NULL_VALUE : Float.valueOf(split);
    }
}