0
respostas

Função para plotar múltiplos gráficos com Matplotlib

Olá! Estou querendo criar uma função que plote gráficos de acordo com um input dado pelo usuário. O usuário deve escolher a quantidade de curvas que deseja plotar na ordem em que deseja plotar e o nome da coluna do DataFrame que contém cada uma dessas curvas. É importante que os gráficos sejam dispostos um ao lado do outro, com a mesma escala vertical.

Desenvolvi o código abaixo, mas ele está retornando isso aqui:

Erro

O resultado esperado, é algo como abaixo: Resultado esperado

As funções criadas até o momento foram essas:

def cria_axe(well, curva, escala_horizontal, ordem=1, prof='MD', cor='black', nome_curva='log', log=False):
    ''' Função que cria axes em uma figura pré-existente do Matplotlib,
    focada no carregamento de well logs.
    ---
    parâmetros:
    - well: dataset com os dados
    - curva: mnemônico da curva (como escrito no dataset)
    - escala_horizontal: escala do log, como tupla
    - ordem: ordem do log (para mais de uma curva)
    - prof: mnemônico da profundidade (como escrito no dataset)
    - cor: cor da curva (padrão matplotlib)
    - nome: nome da curva (título para a curva)
    - log=False: se True, coloca a curva em escala logarítmica
    '''
    # plt.subplot2grid((n° de linhas, n° de colunas), (posição x da figura, posição y da figura))
    ax = plt.subplot2grid((1, ordem), (0, ordem-1))

    # Definindo os axis (curvas)
    ax.plot(curva, prof, data=well, color=cor) #definindo variáveis e cor
    ax.set_xlabel(nome_curva)   # Título da Curva
    ax.set_xlim(escala_horizontal)   # Limites do eixo x
    ax.invert_yaxis()   # Invertendo eixo Y
    ax.grid()    # Inserindo Grid

    # Aplicando escala logarítmica, caso True
    if log==True:
        ax.semilogx()

    return ax

def plot_curvas(well):
    ''' Função que plota várias curvas
    ---
    Parâmetros:
    well: dataset
    '''

    # Selecionando curvas
    print('Selecione as curvas que deseja plotar:')
    for x, item in enumerate(well.keys()):
        print(f'[{x}]\t', item)
    resposta = input('Coloque os números, separados por vírgulas (ex.: 1, 2, 4): ').replace(',', '').split()
    n_curvas = len(resposta)
    curvas = []
    for item in resposta:
        curvas.append(int(item))


    # Criando figura
    fig = plt.figure(figsize = (5, 25))

    # Criando Axes    

    for ordem, item in enumerate(curvas):
        fig = plt.subplots(1, (len(curvas) -1))
        cria_axe(well, well.keys()[item], (0, 150), ordem=ordem+1)


plot_curvas(teste)

Obrigado desde já!