/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.execution;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import io.airlift.log.Logger;
import io.prestosql.execution.SqlTask;
import io.prestosql.execution.SqlTaskManager;
import io.prestosql.execution.TaskManagementExecutor;
import io.prestosql.execution.TaskState;
import io.prestosql.memory.LocalMemoryManager;
import io.prestosql.memory.MemoryPool;
import io.prestosql.memory.MemoryPoolListener;
import io.prestosql.memory.QueryContext;
import io.prestosql.memory.TraversingQueryContextVisitor;
import io.prestosql.memory.VoidTraversingQueryContextVisitor;
import io.prestosql.operator.OperatorContext;
import io.prestosql.sql.analyzer.FeaturesConfig;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.inject.Inject;

/**
 * Requests revocation of revocable memory from operators when a memory pool runs low on free space.
 *
 * <p>A pool needs revoking when it has revocable bytes reserved and its free bytes are at or below
 * {@code maxBytes * (1 - memoryRevokingThreshold)}. Revoking is then requested from the operators of
 * RUNNING tasks whose query uses that pool, oldest task first. This continues until enough bytes have
 * been requested to bring free space back up to {@code maxBytes * (1 - memoryRevokingTarget)}. Bytes
 * whose revocation was already requested count towards that amount.
 *
 * <p>Checks are triggered by memory reservations on the pools (through a {@link MemoryPoolListener})
 * and by a periodic check every second. Concurrent triggers are coalesced through {@code checkPending},
 * and the revoking pass itself runs under {@code synchronized}.
 */
public class MemoryRevokingScheduler {
    private static final Logger log = Logger.get(MemoryRevokingScheduler.class);
    /** Orders tasks by ascending create time, i.e. the earliest created task first. */
    private static final Ordering<SqlTask> ORDER_BY_CREATE_TIME = Ordering.natural().onResultOf(task -> task.getTaskInfo().getStats().getCreateTime());
    private final List<MemoryPool> memoryPools;
    private final Supplier<? extends Collection<SqlTask>> currentTasksSupplier;
    private final ScheduledExecutorService taskManagementExecutor;
    /** Fraction of a pool that may be used before revoking is requested; within [0, 1]. */
    private final double memoryRevokingThreshold;
    /** Fraction of a pool that revoking tries to bring usage down to; within [0, threshold]. */
    private final double memoryRevokingTarget;
    private final MemoryPoolListener memoryPoolListener = MemoryPoolListener.onMemoryReserved(this::onMemoryReserved);
    @Nullable
    private ScheduledFuture<?> scheduledFuture;
    /** True while a revoking pass has been requested but has not started yet. */
    private final AtomicBoolean checkPending = new AtomicBoolean();

    @Inject
    public MemoryRevokingScheduler(LocalMemoryManager localMemoryManager, SqlTaskManager sqlTaskManager, TaskManagementExecutor taskManagementExecutor, FeaturesConfig config) {
        this(
                getMemoryPools(localMemoryManager),
                Objects.requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getAllTasks,
                Objects.requireNonNull(taskManagementExecutor, "taskManagementExecutor cannot be null").getExecutor(),
                Objects.requireNonNull(config, "config is null").getMemoryRevokingThreshold(),
                config.getMemoryRevokingTarget());
    }

    /**
     * Creates a scheduler.
     *
     * @param memoryPools pools to monitor; copied defensively
     * @param currentTasksSupplier supplies the tasks that are candidates for revoking
     * @param taskManagementExecutor executor used for both the periodic and the reservation-triggered checks
     * @param memoryRevokingThreshold used fraction of a pool at which revoking starts, in [0, 1]
     * @param memoryRevokingTarget used fraction of a pool that revoking aims for, in [0, threshold]
     * @throws IllegalArgumentException if a fraction is out of range or target exceeds threshold
     */
    @VisibleForTesting
    MemoryRevokingScheduler(List<MemoryPool> memoryPools, Supplier<? extends Collection<SqlTask>> currentTasksSupplier, ScheduledExecutorService taskManagementExecutor, double memoryRevokingThreshold, double memoryRevokingTarget) {
        this.memoryPools = ImmutableList.copyOf(Objects.requireNonNull(memoryPools, "memoryPools is null"));
        this.currentTasksSupplier = Objects.requireNonNull(currentTasksSupplier, "currentTasksSupplier is null");
        this.taskManagementExecutor = Objects.requireNonNull(taskManagementExecutor, "taskManagementExecutor is null");
        this.memoryRevokingThreshold = checkFraction(memoryRevokingThreshold, "memoryRevokingThreshold");
        this.memoryRevokingTarget = checkFraction(memoryRevokingTarget, "memoryRevokingTarget");
        Preconditions.checkArgument(
                memoryRevokingTarget <= memoryRevokingThreshold,
                "memoryRevokingTarget should be less than or equal memoryRevokingThreshold, but got %s and %s respectively",
                memoryRevokingTarget,
                memoryRevokingThreshold);
    }

    /** Validates that {@code value} lies in the closed range [0, 1] and returns it unchanged. */
    private static double checkFraction(double value, String valueName) {
        Objects.requireNonNull(valueName, "valueName is null");
        Preconditions.checkArgument(0.0 <= value && value <= 1.0, "%s should be within [0, 1] range, got %s", valueName, value);
        return value;
    }

    /** Returns the general pool followed by the reserved pool, if the memory manager has one. */
    private static List<MemoryPool> getMemoryPools(LocalMemoryManager localMemoryManager) {
        Objects.requireNonNull(localMemoryManager, "localMemoryManager cannot be null");
        ImmutableList.Builder<MemoryPool> builder = ImmutableList.builder();
        builder.add(localMemoryManager.getGeneralPool());
        localMemoryManager.getReservedPool().ifPresent(builder::add);
        return builder.build();
    }

    /** Starts the periodic check and begins listening for reservations on all monitored pools. */
    @PostConstruct
    public void start() {
        registerPeriodicCheck();
        registerPoolListeners();
    }

    /** Schedules {@link #requestMemoryRevokingIfNeeded()} every second, logging (not propagating) failures. */
    private void registerPeriodicCheck() {
        scheduledFuture = taskManagementExecutor.scheduleWithFixedDelay(() -> {
            try {
                requestMemoryRevokingIfNeeded();
            }
            catch (Throwable e) {
                // Swallowed on purpose: an exception would cancel all future runs of a fixed-delay task.
                log.error(e, "Error requesting system memory revoking");
            }
        }, 1L, 1L, TimeUnit.SECONDS);
    }

    /** Cancels the periodic check and unregisters the pool listeners. */
    @PreDestroy
    public void stop() {
        if (scheduledFuture != null) {
            scheduledFuture.cancel(true);
            scheduledFuture = null;
        }
        memoryPools.forEach(memoryPool -> memoryPool.removeListener(memoryPoolListener));
    }

    @VisibleForTesting
    void registerPoolListeners() {
        memoryPools.forEach(memoryPool -> memoryPool.addListener(memoryPoolListener));
    }

    /**
     * Reservation callback. If {@code memoryPool} needs revoking and no pass is pending yet, submits
     * one to the executor. Failures are logged rather than propagated to the reserving caller.
     */
    private void onMemoryReserved(MemoryPool memoryPool) {
        try {
            if (!memoryRevokingNeeded(memoryPool)) {
                return;
            }
            if (checkPending.compareAndSet(false, true)) {
                log.debug("Scheduling check for %s", memoryPool);
                try {
                    scheduleRevoking();
                }
                catch (RuntimeException | Error e) {
                    // Nothing was scheduled, so nobody will clear the flag. Without this reset, every
                    // later reservation and periodic check would be skipped forever.
                    checkPending.set(false);
                    throw e;
                }
            }
        }
        catch (Throwable e) {
            log.error(e, "Error when acting on memory pool reservation");
        }
    }

    /**
     * Runs a revoking pass on the calling thread unless one is already pending. A pending pass
     * scheduled by {@link #onMemoryReserved} covers this request.
     */
    @VisibleForTesting
    void requestMemoryRevokingIfNeeded() {
        if (checkPending.compareAndSet(false, true)) {
            runMemoryRevoking();
        }
    }

    /** Submits a revoking pass to the task management executor; errors inside the pass are logged. */
    private void scheduleRevoking() {
        taskManagementExecutor.execute(() -> {
            try {
                runMemoryRevoking();
            }
            catch (Throwable e) {
                log.error(e, "Error requesting memory revoking");
            }
        });
    }

    /**
     * Consumes the pending flag and requests revoking for each pool that needs it. The task list is
     * fetched lazily, at most once per pass, and only if some pool needs revoking.
     */
    private synchronized void runMemoryRevoking() {
        if (!checkPending.getAndSet(false)) {
            return;
        }
        Collection<SqlTask> sqlTasks = null;
        for (MemoryPool memoryPool : memoryPools) {
            if (!memoryRevokingNeeded(memoryPool)) {
                continue;
            }
            if (sqlTasks == null) {
                sqlTasks = Objects.requireNonNull(currentTasksSupplier.get());
            }
            requestMemoryRevoking(memoryPool, sqlTasks);
        }
    }

    /**
     * Computes how many bytes must be revoked to reach the target free space in {@code memoryPool}.
     * Subtracts bytes already being revoked, then requests the remainder from running tasks, if any.
     */
    private void requestMemoryRevoking(MemoryPool memoryPool, Collection<SqlTask> sqlTasks) {
        long remainingBytesToRevoke = (long) (-memoryPool.getFreeBytes() + memoryPool.getMaxBytes() * (1.0 - memoryRevokingTarget));
        remainingBytesToRevoke -= getMemoryAlreadyBeingRevoked(sqlTasks, memoryPool);
        if (remainingBytesToRevoke <= 0) {
            // In-flight revocations already cover the target; skip sorting and traversing every task.
            return;
        }
        requestRevoking(memoryPool, sqlTasks, remainingBytesToRevoke);
    }

    /** True if the pool holds revocable bytes and its free bytes are at or below the threshold level. */
    private boolean memoryRevokingNeeded(MemoryPool memoryPool) {
        return memoryPool.getReservedRevocableBytes() > 0L
                && memoryPool.getFreeBytes() <= memoryPool.getMaxBytes() * (1.0 - memoryRevokingThreshold);
    }

    /** True if {@code task} is RUNNING and its query context currently uses {@code memoryPool}. */
    private static boolean isRunningInPool(SqlTask task, MemoryPool memoryPool) {
        return task.getTaskStatus().getState() == TaskState.RUNNING
                && task.getQueryContext().getMemoryPool() == memoryPool;
    }

    /**
     * Sums the revocable bytes of operators that already have a revoking request outstanding. Only
     * running tasks in {@code memoryPool} are counted.
     */
    private long getMemoryAlreadyBeingRevoked(Collection<SqlTask> sqlTasks, MemoryPool memoryPool) {
        TraversingQueryContextVisitor<Void, Long> alreadyRevokingBytes = new TraversingQueryContextVisitor<Void, Long>() {
            @Override
            public Long visitOperatorContext(OperatorContext operatorContext, Void context) {
                return operatorContext.isMemoryRevokingRequested() ? operatorContext.getReservedRevocableBytes() : 0L;
            }

            @Override
            public Long mergeResults(List<Long> childrenResults) {
                return childrenResults.stream().mapToLong(Long::longValue).sum();
            }
        };
        return sqlTasks.stream()
                .filter(task -> isRunningInPool(task, memoryPool))
                .mapToLong(task -> task.getQueryContext().accept(alreadyRevokingBytes, null))
                .sum();
    }

    /**
     * Visits running tasks in {@code memoryPool}, oldest first, and asks their operators to revoke
     * memory until at least {@code remainingBytesToRevoke} bytes have been requested.
     */
    private void requestRevoking(final MemoryPool memoryPool, Collection<SqlTask> sqlTasks, long remainingBytesToRevoke) {
        AtomicLong remainingBytesToRevokeAtomic = new AtomicLong(remainingBytesToRevoke);
        VoidTraversingQueryContextVisitor<AtomicLong> revoker = new VoidTraversingQueryContextVisitor<AtomicLong>() {
            @Override
            public Void visitQueryContext(QueryContext queryContext, AtomicLong remaining) {
                if (remaining.get() <= 0L) {
                    // Target reached: operators would not be asked anyway, so skip the traversal.
                    return null;
                }
                return super.visitQueryContext(queryContext, remaining);
            }

            @Override
            public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remaining) {
                if (remaining.get() > 0L) {
                    long revokedBytes = operatorContext.requestMemoryRevoking();
                    if (revokedBytes > 0L) {
                        remaining.addAndGet(-revokedBytes);
                        log.debug("memoryPool=%s: requested revoking %s; remaining %s", memoryPool.getId(), revokedBytes, remaining.get());
                    }
                }
                return null;
            }
        };
        sqlTasks.stream()
                .filter(task -> isRunningInPool(task, memoryPool))
                .sorted(ORDER_BY_CREATE_TIME)
                .forEach(task -> task.getQueryContext().accept(revoker, remainingBytesToRevokeAtomic));
    }
}

