/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.operator.scalar;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.VariableInstruction;
import io.prestosql.annotation.UsedByGeneratedCode;
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionKind;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.LongVariableConstraint;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlScalarFunction;
import io.prestosql.metadata.TypeVariableConstraint;
import io.prestosql.operator.aggregation.TypedSet;
import io.prestosql.operator.scalar.ChoicesScalarFunctionImplementation;
import io.prestosql.operator.scalar.ScalarFunctionImplementation;
import io.prestosql.spi.ErrorCodeSupplier;
import io.prestosql.spi.PageBuilder;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.StandardErrorCode;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.type.MapType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.spi.type.TypeSignatureParameter;
import io.prestosql.sql.gen.BytecodeUtils;
import io.prestosql.sql.gen.CallSiteBinder;
import io.prestosql.sql.gen.SqlTypeBytecodeExpression;
import io.prestosql.sql.gen.lambda.BinaryFunctionInterface;
import io.prestosql.type.BlockTypeOperators;
import io.prestosql.type.UnknownType;
import io.prestosql.util.CompilerUtils;
import io.prestosql.util.Reflection;
import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Scalar function {@code transform_keys(map(K1, V), function(K1, V, K2)) -> map(K2, V)}.
 *
 * <p>For every entry of the input map, the lambda is applied to the (key, value) pair. Its result
 * becomes the entry's new key, and the value is copied unchanged. A {@code null} transformed key, or
 * two entries that map to the same transformed key, fail with {@code INVALID_FUNCTION_ARGUMENT}.
 *
 * <p>A specialized implementation is generated as bytecode for each (K1, K2, V) binding; see
 * {@link #generateTransformKey}.
 */
public final class MapTransformKeysFunction
extends SqlScalarFunction {
    public static final String NAME = "transform_keys";
    private static final MethodHandle STATE_FACTORY = Reflection.methodHandle(MapTransformKeysFunction.class, "createState", MapType.class);
    private final BlockTypeOperators blockTypeOperators;

    /**
     * @param blockTypeOperators source of the equality and hash operators used to detect duplicate
     *        transformed keys; must not be null
     */
    public MapTransformKeysFunction(BlockTypeOperators blockTypeOperators) {
        super(new FunctionMetadata(
                new Signature(
                        NAME,
                        ImmutableList.of(Signature.typeVariable("K1"), Signature.typeVariable("K2"), Signature.typeVariable("V")),
                        ImmutableList.of(),
                        TypeSignature.mapType(new TypeSignature("K2"), new TypeSignature("V")),
                        ImmutableList.of(
                                TypeSignature.mapType(new TypeSignature("K1"), new TypeSignature("V")),
                                TypeSignature.functionType(new TypeSignature("K1"), new TypeSignature("V"), new TypeSignature("K2"))),
                        false),
                false,
                ImmutableList.of(new FunctionArgumentDefinition(false), new FunctionArgumentDefinition(false)),
                false,
                false,
                "Apply lambda to each entry of the map and transform the key",
                FunctionKind.SCALAR));
        this.blockTypeOperators = Objects.requireNonNull(blockTypeOperators, "blockTypeOperators is null");
    }

    /**
     * Binds the type variables and returns an implementation backed by generated bytecode. The
     * implementation takes a never-null map argument and a {@link BinaryFunctionInterface} lambda,
     * and receives per-call-site state created by {@link #createState}.
     */
    @Override
    protected ScalarFunctionImplementation specialize(FunctionBinding functionBinding) {
        Type keyType = functionBinding.getTypeVariable("K1");
        Type transformedKeyType = functionBinding.getTypeVariable("K2");
        Type valueType = functionBinding.getTypeVariable("V");
        MapType resultMapType = (MapType) functionBinding.getBoundSignature().getReturnType();
        return new ChoicesScalarFunctionImplementation(
                functionBinding,
                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                ImmutableList.of(
                        InvocationConvention.InvocationArgumentConvention.NEVER_NULL,
                        InvocationConvention.InvocationArgumentConvention.FUNCTION),
                ImmutableList.of(BinaryFunctionInterface.class),
                generateTransformKey(keyType, transformedKeyType, valueType, resultMapType),
                Optional.of(STATE_FACTORY.bindTo(resultMapType)));
    }

    /**
     * Creates the state handed to the generated {@code transform} method: a {@link PageBuilder} with a
     * single column of the result map type. The generated code reuses it across calls, resetting it
     * only when it is full.
     */
    @UsedByGeneratedCode
    public static Object createState(MapType mapType) {
        return new PageBuilder(ImmutableList.of(mapType));
    }

    /**
     * Generates a class with a static
     * {@code Block transform(Object state, ConnectorSession session, Block block, BinaryFunctionInterface function)}
     * method and returns a handle to it.
     *
     * <p>The generated method walks {@code block} two positions at a time, reading a key at the even
     * position and its value at the following position. It applies {@code function} to each pair,
     * writes the transformed key and the original value into a new map entry, and rejects null or
     * duplicate transformed keys.
     */
    private MethodHandle generateTransformKey(Type keyType, Type transformedKeyType, Type valueType, Type resultMapType) {
        CallSiteBinder binder = new CallSiteBinder();
        Class<?> keyJavaType = Primitives.wrap(keyType.getJavaType());
        Class<?> transformedKeyJavaType = Primitives.wrap(transformedKeyType.getJavaType());
        Class<?> valueJavaType = Primitives.wrap(valueType.getJavaType());

        ClassDefinition definition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName("MapTransformKey"),
                ParameterizedType.type(Object.class));
        definition.declareDefaultConstructor(Access.a(Access.PRIVATE));

        Parameter state = Parameter.arg("state", Object.class);
        Parameter session = Parameter.arg("session", ConnectorSession.class);
        Parameter block = Parameter.arg("block", Block.class);
        Parameter function = Parameter.arg("function", BinaryFunctionInterface.class);
        MethodDefinition method = definition.declareMethod(
                Access.a(Access.PUBLIC, Access.STATIC),
                "transform",
                ParameterizedType.type(Block.class),
                ImmutableList.of(state, session, block, function));

        BytecodeBlock body = method.getBody();
        Scope scope = method.getScope();
        Variable positionCount = scope.declareVariable(int.class, "positionCount");
        Variable position = scope.declareVariable(int.class, "position");
        Variable pageBuilder = scope.declareVariable(PageBuilder.class, "pageBuilder");
        Variable mapBlockBuilder = scope.declareVariable(BlockBuilder.class, "mapBlockBuilder");
        Variable blockBuilder = scope.declareVariable(BlockBuilder.class, "blockBuilder");
        Variable typedSet = scope.declareVariable(TypedSet.class, "typedSet");
        Variable keyElement = scope.declareVariable(keyJavaType, "keyElement");
        Variable transformedKeyElement = scope.declareVariable(transformedKeyJavaType, "transformedKeyElement");
        Variable valueElement = scope.declareVariable(valueJavaType, "valueElement");

        // Prologue: reuse the shared PageBuilder (reset only when full) and open a new map entry.
        body.append(positionCount.set(block.invoke("getPositionCount", int.class)));
        body.append(pageBuilder.set(state.cast(PageBuilder.class)));
        body.append(new IfStatement()
                .condition(pageBuilder.invoke("isFull", boolean.class))
                .ifTrue(pageBuilder.invoke("reset", void.class)));
        body.append(mapBlockBuilder.set(pageBuilder.invoke("getBlockBuilder", BlockBuilder.class, BytecodeExpressions.constantInt(0))));
        body.append(blockBuilder.set(mapBlockBuilder.invoke("beginBlockEntry", BlockBuilder.class)));
        // positionCount / 2 is the number of entries, used as the expected size of the duplicate-key set.
        body.append(typedSet.set(BytecodeExpressions.invokeStatic(
                TypedSet.class,
                "createEqualityTypedSet",
                TypedSet.class,
                SqlTypeBytecodeExpression.constantType(binder, transformedKeyType),
                BytecodeUtils.loadConstant(binder, blockTypeOperators.getEqualOperator(transformedKeyType), BlockTypeOperators.BlockPositionEqual.class),
                BytecodeUtils.loadConstant(binder, blockTypeOperators.getHashCodeOperator(transformedKeyType), BlockTypeOperators.BlockPositionHashCode.class),
                BytecodeExpressions.divide(positionCount, BytecodeExpressions.constantInt(2)),
                BytecodeExpressions.constantString(NAME))));

        BytecodeBlock throwNullKeyException = new BytecodeBlock()
                .append(BytecodeExpressions.newInstance(
                        PrestoException.class,
                        BytecodeExpressions.getStatic(StandardErrorCode.INVALID_FUNCTION_ARGUMENT.getDeclaringClass(), "INVALID_FUNCTION_ARGUMENT").cast(ErrorCodeSupplier.class),
                        BytecodeExpressions.constantString("map key cannot be null")))
                .throwObject();
        // The PageBuilder state is shared across invocations, so every failure path closes the open
        // map entry before throwing. Otherwise a later invocation (e.g. after the failure was
        // swallowed by TRY()) would start on a builder left in the middle of an entry.
        // NOTE(review): an exception thrown by the lambda itself still escapes with the entry open;
        // wrapping the loop in a try/catch that closes the entry would cover that case too.
        BytecodeBlock closeEntryAndThrowNullKeyException = new BytecodeBlock()
                .append(mapBlockBuilder.invoke("closeEntry", BlockBuilder.class).pop())
                .append(throwNullKeyException);

        SqlTypeBytecodeExpression keySqlType = SqlTypeBytecodeExpression.constantType(binder, keyType);
        BytecodeNode loadKeyElement;
        if (!keyType.equals(UnknownType.UNKNOWN)) {
            loadKeyElement = new BytecodeBlock()
                    .append(keyElement.set(keySqlType.getValue(block, position).cast(keyJavaType)));
        } else {
            // A key of type UNKNOWN can only be null, so any row reaching this point fails.
            // keyElement is still assigned so that it is initialized on every path that the
            // generated bytecode can see.
            loadKeyElement = new BytecodeBlock()
                    .append(mapBlockBuilder.invoke("closeEntry", BlockBuilder.class).pop())
                    .append(keyElement.set(BytecodeExpressions.constantNull(keyJavaType)))
                    .append(throwNullKeyException);
        }

        // Values live at position + 1; null values are passed to the lambda as Java null.
        SqlTypeBytecodeExpression valueSqlType = SqlTypeBytecodeExpression.constantType(binder, valueType);
        BytecodeNode loadValueElement;
        if (!valueType.equals(UnknownType.UNKNOWN)) {
            loadValueElement = new IfStatement()
                    .condition(block.invoke("isNull", boolean.class, BytecodeExpressions.add(position, BytecodeExpressions.constantInt(1))))
                    .ifTrue(valueElement.set(BytecodeExpressions.constantNull(valueJavaType)))
                    .ifFalse(valueElement.set(valueSqlType.getValue(block, BytecodeExpressions.add(position, BytecodeExpressions.constantInt(1))).cast(valueJavaType)));
        } else {
            loadValueElement = new BytecodeBlock()
                    .append(valueElement.set(BytecodeExpressions.constantNull(valueJavaType)));
        }

        SqlTypeBytecodeExpression transformedKeySqlType = SqlTypeBytecodeExpression.constantType(binder, transformedKeyType);
        BytecodeNode writeKeyElement;
        BytecodeNode throwDuplicatedKeyException;
        if (!transformedKeyType.equals(UnknownType.UNKNOWN)) {
            // Apply the lambda, reject a null result, then append (transformedKey, originalValue).
            writeKeyElement = new BytecodeBlock()
                    .append(transformedKeyElement.set(function.invoke("apply", Object.class, keyElement.cast(Object.class), valueElement.cast(Object.class)).cast(transformedKeyJavaType)))
                    .append(new IfStatement()
                            .condition(BytecodeExpressions.equal(transformedKeyElement, BytecodeExpressions.constantNull(transformedKeyJavaType)))
                            .ifTrue(closeEntryAndThrowNullKeyException)
                            .ifFalse(new BytecodeBlock()
                                    .append(transformedKeySqlType.writeValue(blockBuilder, transformedKeyElement.cast(transformedKeyType.getJavaType())))
                                    .append(valueSqlType.invoke("appendTo", void.class, block, BytecodeExpressions.add(position, BytecodeExpressions.constantInt(1)), blockBuilder))));

            // The message renders the offending key, read back from the entry builder at the same
            // position where it was just written.
            throwDuplicatedKeyException = new BytecodeBlock()
                    .append(mapBlockBuilder.invoke("closeEntry", BlockBuilder.class).pop())
                    .append(BytecodeExpressions.newInstance(
                            PrestoException.class,
                            BytecodeExpressions.getStatic(StandardErrorCode.INVALID_FUNCTION_ARGUMENT.getDeclaringClass(), "INVALID_FUNCTION_ARGUMENT").cast(ErrorCodeSupplier.class),
                            BytecodeExpressions.invokeStatic(
                                    String.class,
                                    "format",
                                    String.class,
                                    BytecodeExpressions.constantString("Duplicate keys (%s) are not allowed"),
                                    BytecodeExpressions.newArray(
                                            ParameterizedType.type(Object[].class),
                                            ImmutableList.of(transformedKeySqlType.invoke("getObjectValue", Object.class, session, blockBuilder.cast(Block.class), position))))))
                    .throwObject();
        } else {
            // A transformed key of type UNKNOWN can only be null, so any row reaching this point fails.
            // The duplicate check after it is unreachable.
            writeKeyElement = closeEntryAndThrowNullKeyException;
            throwDuplicatedKeyException = throwNullKeyException;
        }

        body.append(new ForLoop()
                .initialize(position.set(BytecodeExpressions.constantInt(0)))
                .condition(BytecodeExpressions.lessThan(position, positionCount))
                .update(VariableInstruction.incrementVariable(position, (byte) 2))
                .body(new BytecodeBlock()
                        .append(loadKeyElement)
                        .append(loadValueElement)
                        .append(writeKeyElement)
                        .append(new IfStatement()
                                .condition(typedSet.invoke("contains", boolean.class, blockBuilder.cast(Block.class), position))
                                .ifTrue(throwDuplicatedKeyException)
                                .ifFalse(typedSet.invoke("add", void.class, blockBuilder.cast(Block.class), position)))));

        // Epilogue: close the entry, declare it on the PageBuilder and return the just-built map.
        body.append(mapBlockBuilder.invoke("closeEntry", BlockBuilder.class).pop());
        body.append(pageBuilder.invoke("declarePosition", void.class));
        body.append(SqlTypeBytecodeExpression.constantType(binder, resultMapType)
                .invoke(
                        "getObject",
                        Object.class,
                        mapBlockBuilder.cast(Block.class),
                        BytecodeExpressions.subtract(mapBlockBuilder.invoke("getPositionCount", int.class), BytecodeExpressions.constantInt(1)))
                .ret());

        Class<?> generatedClass = CompilerUtils.defineClass(definition, Object.class, binder.getBindings(), MapTransformKeysFunction.class.getClassLoader());
        return Reflection.methodHandle(generatedClass, "transform", Object.class, ConnectorSession.class, Block.class, BinaryFunctionInterface.class);
    }
}

