import matplotlib.pyplot as plt

# Kyte-Doolittle hydropathy index (Kyte & Doolittle, 1982) for the 20 standard
# amino acids, keyed by upper-case one-letter code. Positive values are
# hydrophobic, negative values hydrophilic.
hydropathy_scale = {
    'A': 1.8, 'C': 2.5, 'D': -3.5, 'E': -3.5, 'F': 2.8,
    'G': -0.4, 'H': -3.2, 'I': 4.5, 'K': -3.9, 'L': 3.8,
    'M': 1.9, 'N': -3.5, 'P': -1.6, 'Q': -3.5, 'R': -4.5,
    'S': -0.8, 'T': -0.7, 'V': 4.2, 'W': -0.9, 'Y': -1.3
}

def calculate_hydropathy(sequence, window_size=1, *, scale=None):
    """
    Calculate per-residue hydropathy values for an amino acid sequence.

    Whitespace (spaces, tabs, newlines) is ignored, so a sequence pasted in
    blocks is scored as one contiguous chain. Letters are matched
    case-insensitively. Any other character missing from the scale (e.g.
    'X', '*', '-') scores 0, so it still occupies a position.

    Args:
        sequence: Amino acid sequence in one-letter code.
        window_size: Width of the sliding-window average. Values <= 1
            disable smoothing. A window longer than the sequence yields an
            empty list of smoothed values.
        scale: Optional mapping from upper-case one-letter codes to
            hydropathy values. Defaults to ``hydropathy_scale``
            (Kyte-Doolittle).

    Returns:
        A tuple ``(values, average)``:
        - ``values`` is the list of hydropathy values. When
          ``window_size > 1``, element ``i`` (0-based) is the mean of
          residues ``i`` .. ``i + window_size - 1``.
        - ``average`` is the mean hydropathy of the unsmoothed sequence,
          or 0 for an empty sequence.
    """
    if scale is None:
        scale = hydropathy_scale

    # Strip whitespace so it is not scored as phantom zero-valued residues.
    residues = "".join(sequence.split()).upper()
    values = [scale.get(aa, 0) for aa in residues]

    if window_size > 1:
        averaged = [
            sum(values[start:start + window_size]) / window_size
            for start in range(len(values) - window_size + 1)
        ]
    else:
        averaged = values

    avg_hydropathy = sum(values) / len(values) if values else 0
    return averaged, avg_hydropathy

def plot_hydropathy(sequence, window_size=1):
    """
    Plot hydropathy values along an amino acid sequence.

    Each smoothed value is drawn at the centre residue of its window, so the
    x-axis matches residue numbering for any window size. Regions above zero
    are filled red and regions below zero blue. The title shows the average
    hydropathy of the unsmoothed sequence.

    Args:
        sequence: Amino acid sequence in one-letter code.
        window_size: Sliding-window width passed to calculate_hydropathy;
            values <= 1 disable smoothing.
    """
    hydropathy_values, avg_score = calculate_hydropathy(sequence, window_size)

    # Smoothed value i averages residues i+1 .. i+window_size (1-based), so
    # it belongs at the window midpoint, not at the window's first residue.
    # Even-width windows have a half-integer midpoint.
    first_center = 1 + (window_size - 1) / 2 if window_size > 1 else 1
    x_positions = [first_center + i for i in range(len(hydropathy_values))]

    plt.figure(figsize=(10, 5))
    plt.plot(x_positions, hydropathy_values, color='black', linewidth=1.5)

    # interpolate=True extends each fill to the exact zero crossing. Without
    # it, the segment between a positive and a negative sample is left
    # unfilled.
    # Fill above 0 in red
    plt.fill_between(
        x_positions, 0, hydropathy_values,
        where=[y >= 0 for y in hydropathy_values],
        interpolate=True,
        color='red', alpha=0.3
    )

    # Fill below 0 in blue
    plt.fill_between(
        x_positions, 0, hydropathy_values,
        where=[y < 0 for y in hydropathy_values],
        interpolate=True,
        color='blue', alpha=0.3
    )

    plt.axhline(0, color='gray', linestyle='--')
    plt.title(f"Hydropathy Plot (Avg Score = {avg_score:.2f})")
    plt.xlabel("Position in Sequence")
    plt.ylabel("Hydropathy Index")
    plt.grid(True)
    plt.show()

if __name__ == "__main__":
    # Interactive entry point: read a sequence and window size, then plot.
    seq = input("Enter amino acid sequence: ")

    # Whitespace in a pasted sequence is formatting, not residues.
    residue_count = sum(1 for ch in seq if not ch.isspace())
    if residue_count == 0:
        print("No sequence entered, nothing to plot.")
        raise SystemExit(1)

    # Keep the try body to the one call that can raise ValueError.
    try:
        window = int(input("Enter smoothing window size (1 for none): "))
    except ValueError:
        print("Invalid input for window size, setting to 1.")
        window = 1
    else:
        if window < 1:
            print("Window size must be at least 1, setting to 1.")
            window = 1

    # A window longer than the sequence yields no smoothed points, which
    # would otherwise produce a silently empty plot.
    if window > residue_count:
        print("Window size exceeds sequence length, setting to 1.")
        window = 1

    plot_hydropathy(seq, window_size=window)
