|
| 1 | +package com.sparkTutorial.sparkSql; |
| 2 | + |
| 3 | +import org.apache.log4j.Level; |
| 4 | +import org.apache.log4j.Logger; |
| 5 | +import org.apache.spark.sql.Column; |
| 6 | +import org.apache.spark.sql.Dataset; |
| 7 | +import org.apache.spark.sql.Row; |
| 8 | +import org.apache.spark.sql.SparkSession; |
| 9 | + |
| 10 | +import static org.apache.spark.sql.functions.avg; |
| 11 | +import static org.apache.spark.sql.functions.max; |
| 12 | + |
| 13 | +public class StackOverFlowSurvey { |
| 14 | + |
| 15 | + private static final String AGE_MIDPOINT = "age_midpoint"; |
| 16 | + private static final String SALARY_MIDPOINT = "salary_midpoint"; |
| 17 | + |
| 18 | + public static void main(String[] args) throws Exception { |
| 19 | + |
| 20 | + Logger.getLogger("org").setLevel(Level.ERROR); |
| 21 | + SparkSession session = SparkSession.builder().appName("StackOverFlowSurvey").master("local[1]").getOrCreate(); |
| 22 | + |
| 23 | + Dataset<Row> responses = session.read().option("header","true").csv("in/2016-stack-overflow-survey-responses.csv"); |
| 24 | + |
| 25 | + System.out.println("=== Print out schema ==="); |
| 26 | + responses.printSchema(); |
| 27 | + |
| 28 | + System.out.println("=== Creates a temporary view called response ==="); |
| 29 | + responses.createOrReplaceTempView("response"); |
| 30 | + |
| 31 | + System.out.println("=== Print 20 records of responses table ==="); |
| 32 | + responses.show(20); |
| 33 | + |
| 34 | + System.out.println("=== Print the so_region and self_identification columns of gender table ==="); |
| 35 | + responses.select(new Column("so_region"), new Column("self_identification")).show(); |
| 36 | + |
| 37 | + System.out.println("=== Print records where the response is from Afghanistan ==="); |
| 38 | + responses.filter(new Column("country").equalTo("Afghanistan")).show(); |
| 39 | + |
| 40 | + System.out.println("=== Print the count of occupations ==="); |
| 41 | + responses.groupBy(new Column("occupation")).count().show(); |
| 42 | + |
| 43 | + |
| 44 | + System.out.println("=== Cast the salary mid point and age mid point to integer ==="); |
| 45 | + Dataset<Row> castedResponse = responses.withColumn(SALARY_MIDPOINT, new Column(SALARY_MIDPOINT).cast("integer")) |
| 46 | + .withColumn(AGE_MIDPOINT, new Column(AGE_MIDPOINT).cast("integer")); |
| 47 | + |
| 48 | + System.out.println("=== Print out casted schema ==="); |
| 49 | + castedResponse.printSchema(); |
| 50 | + |
| 51 | + System.out.println("=== Print records with average mid age less than 20 ==="); |
| 52 | + castedResponse.filter(new Column(AGE_MIDPOINT).$less(20)).show(); |
| 53 | + |
| 54 | + System.out.println("=== Print the result with salary middle point in descending order ==="); |
| 55 | + castedResponse.orderBy(new Column(SALARY_MIDPOINT ).desc()).show(); |
| 56 | + |
| 57 | + System.out.println("=== Group by country and aggregate by average salary middle point and max age middle point ==="); |
| 58 | + castedResponse.groupBy("country").agg(avg(SALARY_MIDPOINT), max(AGE_MIDPOINT)).show(); |
| 59 | + |
| 60 | + } |
| 61 | +} |
0 commit comments