import numpy as np

# Problem setup: a grid_sizeX x grid_sizeY board (rows x columns),
# the discount factor and the convergence tolerance for value iteration.
grid_sizeX, grid_sizeY = 3, 4
gamma = 0.1      # discount factor
epsilon = 1e-4   # tolerance used by the value-iteration stopping test

# Every cell carries a small living penalty except the goal, which pays +1.
rewards = np.full((grid_sizeX, grid_sizeY), -0.04)
goal_state = (0, grid_sizeY - 1)  # first row, last column
rewards[goal_state] = 1

# Utilities start at zero everywhere (same shape and float dtype as rewards).
U = np.zeros_like(rewards)

# Moves as (row delta, column delta), in the order: up, down, left, right.
actions = [(-1, 0), (1, 0), (0, -1), (0, 1)]

def is_valid(state):
    """Return True when ``state`` = (row, col) lies inside the grid.

    Used to decide whether a move lands on the board. The bounds are the
    module-level ``grid_sizeX`` (first index) and ``grid_sizeY`` (second index).
    """
    row, col = state
    inside_rows = 0 <= row < grid_sizeX
    return inside_rows and 0 <= col < grid_sizeY

def get_probabilities(main_action_index):
    """Return the outcome distribution for choosing ``actions[main_action_index]``.

    The list is aligned element-by-element with ``actions``. The intended
    move gets probability 0.7, and every other move (the opposite one
    included) gets 0.1. Note that the entries only add up to 1 when
    ``actions`` has exactly four entries.
    """
    probs = [0.1 for _ in actions]
    probs[main_action_index] = 0.7  # the move the agent actually chose
    return probs

def value_iteration(U, rewards, gamma, epsilon, actions):
    """Run in-place (Gauss-Seidel style) value iteration on the grid world.

    Each sweep visits every non-goal cell in row-major order. It replaces the
    cell's utility with the best expected one-step return over all actions.
    An action's outcome distribution comes from ``get_probabilities``, and a
    move that would leave the grid keeps the agent in its current cell. The
    goal cell is treated as terminal and is never updated.

    Sweeps stop once the largest per-sweep change ``delta`` satisfies
    ``delta < epsilon * (1 - gamma) / gamma``. This is the usual
    value-iteration stopping test, which (for synchronous updates) bounds
    the utility error by ``epsilon``.

    Parameters
    ----------
    U : numpy.ndarray
        Initial utilities indexed ``U[row, col]``; updated **in place**.
    rewards : numpy.ndarray
        Reward collected on entering each cell, indexed like ``U``.
    gamma : float
        Discount factor in ``[0, 1]``. ``gamma == 0`` needs exactly one
        sweep. ``gamma == 1`` is replaced by ``1 - 1e-5`` for every sweep,
        so the stopping test is well defined and the loop terminates.
    epsilon : float
        Tolerance fed into the stopping test.
    actions : list of (int, int)
        Move offsets; must be aligned with ``get_probabilities``.

    Returns
    -------
    tuple of (numpy.ndarray, int)
        The updated utilities (the same object as ``U``) and the number of
        sweeps performed, including the final, converging one.

    Raises
    ------
    ValueError
        If ``gamma`` is outside ``[0, 1]``: the stopping threshold would be
        non-positive and the loop could never end.
    """
    if not 0 <= gamma <= 1:
        raise ValueError(f"gamma must be in [0, 1], got {gamma!r}")
    if gamma == 1:
        # The error bound divides by (1 - gamma) -> 0, and an undiscounted
        # problem need not converge. Use a barely-discounted one instead, and
        # use it consistently for every sweep.
        gamma = 1 - 1e-5
    if gamma == 0:
        # Backups no longer depend on U, so the first sweep is already exact;
        # this also avoids dividing by zero below.
        threshold = np.inf
    else:
        threshold = epsilon * (1 - gamma) / gamma

    sweeps = 0
    while True:
        sweeps += 1
        delta = 0
        for i in range(grid_sizeX):
            for j in range(grid_sizeY):
                if (i, j) == goal_state:
                    continue  # terminal state: its utility is never updated
                old_u = U[i, j]
                action_values = []
                for main_action_index in range(len(actions)):
                    probabilities = get_probabilities(main_action_index)
                    new_u = 0
                    for (di, dj), p in zip(actions, probabilities):
                        next_state = (i + di, j + dj)
                        if is_valid(next_state):
                            new_u += p * (rewards[next_state] + gamma * U[next_state])
                        else:
                            # Bumping into the edge leaves the agent where it is.
                            new_u += p * (rewards[i, j] + gamma * U[i, j])
                    action_values.append(new_u)
                # Greedy Bellman backup, written into U immediately, so cells
                # visited later in this sweep already see the new value.
                U[i, j] = max(action_values)
                delta = max(delta, abs(old_u - U[i, j]))
        if delta < threshold:
            break
    return U, sweeps

# Solve the grid world with the module-level settings; U is rebound to the
# converged utilities and the sweep count is kept for reporting.
U, iterations = value_iteration(
    U, rewards, gamma, epsilon, actions
)

# Report: the utility table first, then how many iterations were counted.
print(U)
print("Iterations = ", iterations)
