Python Code

Config 파일 관리 (yaml)

Kimhj 2023. 9. 19. 14:13
  • 코드 내 path 등을 관리하기 위해 config 파일을 사용함.
  • .ini 나 .yaml 등 여러 방법이 있지만, yaml을 자주 사용하고 있음.
  • config 파일 경로를 상대 경로로 적어 둘 때, 실행 위치(현재 작업 폴더)가 아니라 이 스크립트 파일이 있는 폴더를 기준으로 경로를 만들기 위해 아래 코드를 사용한다.
# Folder that contains this script (symlinks resolved). Relative config paths
# are built from it so they do not depend on the process's working directory.
current_path = osp.split(osp.realpath(__file__))[0]

 

  • config.yaml 예시
# Main PostgreSQL connection: PostgreSQLDatabase.connect() reads host, port,
# user, passwd and dbname from this section (empty values load as null).
# If an SSHTunnel section is present, host/port are the database address as
# reached from the SSH server (used as the tunnel's remote bind address).
Database:
  host: 
  port: 
  user: 
  passwd: 
  dbname: 
  # schema/datatype are not used for connecting; the commented-out example
  # query in the loader's __main__ builds "<schema>.<datatype>" from them.
  schema:  
  datatype: 

# Second database profile with the same keys as "Database".
# NOTE(review): the loader code in this file only reads the "Database"
# section, never this one — presumably selected by other code; confirm or remove.
Database_wave:
  host: 
  port: 
  user: 
  passwd: 
  dbname: 
  schema:  
  datatype: 

# Used only when connecting through an SSH tunnel.
# If this key is present at all, connect() opens an SSH tunnel to host:port
# with user/passwd and forwards a local port to Database.host:Database.port —
# even when the values below are empty, so delete the whole section to
# connect to the database directly.
# dbname/schema/datatype are not read for the tunnel.
SSHTunnel: 
  host: 
  port: 
  user: 
  passwd: 
  dbname: 
  schema:  
  datatype:

 

  • yaml 파일 로드 
# module directory 내 utils.py 생성

import os
import yaml, json, pickle
import re
import collections
import time
from shutil import copyfile
from configparser import ConfigParser

def save_pickle(fname, data):
    """Pickle ``data`` into the file ``fname``, creating or truncating it.

    Uses pickle's default protocol.
    """
    with open(fname, mode="wb") as out_file:
        pickle.dump(data, out_file)


def load_pickle(fname):
    """Return the object unpickled from the file ``fname``.

    Warning: unpickling can execute arbitrary code. Only load files that you
    produced yourself or otherwise trust.
    """
    with open(fname, mode="rb") as in_file:
        return pickle.load(in_file)


def load_json(fname, encoding="utf-8"):
    """Parse the JSON file ``fname`` and return the resulting object.

    Args:
        fname: path of the JSON file.
        encoding: text encoding used to decode the file. Defaults to UTF-8,
            the encoding JSON files are expected to use.

    Raises:
        FileNotFoundError: if ``fname`` does not exist.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    # An explicit encoding is required: without it open() uses the locale's
    # encoding (e.g. cp949 on Korean Windows) and non-ASCII UTF-8 text fails
    # to decode or turns into mojibake.
    with open(fname, "r", encoding=encoding) as f:
        return json.load(f)


def save_json(fname, data):
    """Serialize ``data`` as JSON into ``fname``, creating or overwriting it.

    Uses json.dump's defaults: compact output, and non-ASCII characters are
    escaped so the file content is pure ASCII.
    """
    with open(fname, mode="w") as out_file:
        json.dump(data, out_file)


def load_yaml(fname, encoding="utf-8"):
    """Load the YAML file ``fname`` with ``yaml.safe_load`` and return the data.

    Args:
        fname: path of the YAML file (e.g. ``config/db_config.yaml``).
        encoding: text encoding used to decode the file. Defaults to UTF-8.

    Returns:
        The parsed document (typically a dict); ``None`` for an empty file.

    Raises:
        FileNotFoundError: if ``fname`` does not exist.
    """
    # Explicit encoding so Korean comments/values in a UTF-8 config decode the
    # same way on every OS instead of using the locale's encoding (e.g. cp949).
    with open(fname, "r", encoding=encoding) as fp:
        # safe_load only builds plain Python types and never constructs
        # arbitrary objects from YAML tags.
        return yaml.safe_load(fp)


def write_yaml(fname, obj):
    """Dump ``obj`` as YAML into ``fname``, creating or overwriting it.

    Uses ``yaml.dump`` with its default options. Note: ``yaml.dump`` (unlike
    ``safe_dump``) may emit Python-specific tags for non-plain objects, and
    such files will not load back with ``safe_load``.
    """
    with open(fname, mode="w") as out_file:
        yaml.dump(obj, out_file)


def load_ini(fname, encoding="utf-8"):
    """Parse an INI file and return the populated ``ConfigParser``.

    Args:
        fname: path of the INI file, or an iterable of paths; later files
            override earlier ones.
        encoding: text encoding used to read the file(s).

    Returns:
        A ``ConfigParser``; access values like ``cfg["Section"]["key"]``.

    Raises:
        FileNotFoundError: if none of the given files could be read.
    """
    parser = ConfigParser()
    # read() returns the list of files it managed to parse, not the
    # configuration, and silently skips files that cannot be opened.
    read_ok = parser.read(fname, encoding=encoding)
    if not read_ok:
        raise FileNotFoundError(f"could not read INI file(s): {fname!r}")
    return parser

def get_time_str(format="%Y%m%d-%Hh%Mm%Ss", utc_offset_hours=9):
    """Return the current time at a fixed UTC offset, formatted with ``format``.

    Args:
        format: ``time.strftime`` format string. Avoid ``%Z``/``%z``, which
            describe UTC rather than the applied offset.
        utc_offset_hours: offset from UTC in hours; the default 9 is UTC+9
            (e.g. Korea Standard Time).

    Returns:
        The formatted timestamp, e.g. ``"20230919-14h13m00s"`` by default.
    """
    # Shift UTC (gmtime), not local time: adding the offset to localtime() was
    # only right on UTC hosts and produced +18h on a host already set to KST.
    return time.strftime(format, time.gmtime(time.time() + utc_offset_hours * 3600))

 

  • PostgreSQL에 접속해서 사용 (위의 utils.py와 같은 폴더에 두고 import 함)
import os.path as osp
import sys
import psycopg2 as pg
from sshtunnel import SSHTunnelForwarder
import pandas as pd
from typing import Dict

# Folder containing this file (symlinks resolved); used to import the sibling
# utils.py below and, in __main__, to locate the config file.
current_path = osp.dirname(osp.realpath(__file__))
# Prepend rather than append: with append, any installed package that is also
# named "utils" would win over the local utils.py when this module is imported
# from another directory. Skip if the folder is already on the path.
if current_path not in sys.path:
    sys.path.insert(0, current_path)
import utils as cutils

class PostgreSQLDatabase:
    """Small helper around a psycopg2 connection configured from a YAML file.

    The config must contain a ``Database`` section (host, port, user, passwd,
    dbname). If it also contains an ``SSHTunnel`` key, connections go through
    an SSH tunnel to that host (host, port, user, passwd), forwarding a local
    port to ``Database.host:Database.port``.

    The connection is opened lazily on first use. Call :meth:`close`, or use
    the instance in a ``with`` statement, to close the connection and stop
    the SSH tunnel.
    """

    def __init__(self, config_path: str):
        self.db_config = self.load_db_config(config_path)
        self.conn = None
        # SSHTunnelForwarder started by connect(); kept so close() can stop it.
        self.ssh_tunnel = None

    def __enter__(self) -> "PostgreSQLDatabase":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    @staticmethod
    def load_db_config(filename="config/db_config.yaml") -> Dict:
        """Load the YAML config file ``filename`` into a dict."""
        return cutils.load_yaml(filename)

    def connect(self):
        """Open a new connection, through the SSH tunnel if configured, and return it.

        Call :meth:`close` before connecting again; otherwise the previous
        connection and tunnel are left open.
        """
        db = self.db_config["Database"]
        if "SSHTunnel" in self.db_config:
            ssh = self.db_config["SSHTunnel"]
            self.ssh_tunnel = SSHTunnelForwarder(
                (ssh["host"], ssh["port"]),
                ssh_username=ssh["user"],
                ssh_password=ssh["passwd"],
                remote_bind_address=(db["host"], db["port"]),
            )
            self.ssh_tunnel.start()
            try:
                self.conn = pg.connect(
                    host="localhost",
                    user=db["user"],
                    password=db["passwd"],
                    port=self.ssh_tunnel.local_bind_port,
                    database=db["dbname"],
                )
            except Exception:
                # Don't leave the tunnel running if the database login fails.
                self.ssh_tunnel.stop()
                self.ssh_tunnel = None
                raise
        else:
            self.conn = pg.connect(
                host=db["host"],
                user=db["user"],
                password=db["passwd"],
                port=db["port"],
                database=db["dbname"],
            )
        return self.conn

    def close(self):
        """Close the connection and stop the SSH tunnel, if open. Safe to call repeatedly."""
        try:
            if self.conn is not None:
                self.conn.close()
        finally:
            self.conn = None
            if self.ssh_tunnel is not None:
                tunnel, self.ssh_tunnel = self.ssh_tunnel, None
                tunnel.stop()

    def _ensure_connected(self):
        """Connect if there is no connection yet or the previous one was closed."""
        if self.conn is None or self.conn.closed:
            self.close()  # stop a leftover tunnel before opening a new one
            self.connect()

    def execute_sql(self, query: str, params=None):
        """Execute ``query`` and return its result rows.

        Args:
            query: SQL text. Pass values through ``params`` using ``%s`` /
                ``%(name)s`` placeholders instead of formatting them into the
                string. When ``params`` is None, ``%`` in ``query`` is not
                interpreted.
            params: optional sequence or mapping for ``cursor.execute``.

        Returns:
            ``cursor.fetchall()`` (a list of tuples) for statements that return
            rows; ``[]`` for statements that don't (INSERT/UPDATE/DDL), where
            ``fetchall`` would raise.

        The statement is committed on success and rolled back on error, so a
        failed query does not leave the connection in an aborted transaction.
        """
        self._ensure_connected()
        # psycopg2: `with conn` commits on success / rolls back on exception
        # (it does not close the connection); `with cursor` closes the cursor.
        with self.conn, self.conn.cursor() as cur:
            cur.execute(query, params)
            return cur.fetchall() if cur.description is not None else []

    def create_df_from_query(self, query: str, params=None) -> pd.DataFrame:
        """Run a SELECT ``query`` (optionally with ``params``) and return a DataFrame.

        pandas officially supports only SQLAlchemy and sqlite3 connections.
        Reading through this raw psycopg2 connection goes via its DB-API
        cursor; pandas 2.x may emit a UserWarning about it.
        """
        self._ensure_connected()
        # End the read transaction so the session is not left idle in a transaction.
        with self.conn:
            return pd.read_sql_query(query, self.conn, params=params)

    def create_table(self, df: pd.DataFrame, table_name: str, schema_name: str):
        """(Re)create ``schema_name.table_name`` from ``df`` and insert its rows.

        Behaves like ``df.to_sql(..., if_exists="replace", index=False)``: an
        existing table is dropped first and the index is not stored. Column
        types come from :meth:`_pg_type`; NaN/None/NaT are stored as NULL.
        Everything runs in one transaction and is rolled back on error.

        ``DataFrame.to_sql`` is not used because it only supports SQLAlchemy
        connectables and sqlite3 connections. Given a raw psycopg2 connection,
        pandas falls back to its SQLite code path, which fails on PostgreSQL.
        """
        from psycopg2 import sql  # only needed here

        self._ensure_connected()
        if schema_name:
            table = sql.Identifier(schema_name, table_name)
        else:
            table = sql.Identifier(table_name)
        col_idents = [sql.Identifier(str(col)) for col in df.columns]
        col_defs = sql.SQL(", ").join(
            sql.SQL("{} {}").format(ident, sql.SQL(self._pg_type(dtype)))
            for ident, dtype in zip(col_idents, df.dtypes)
        )
        insert = sql.SQL("INSERT INTO {} ({}) VALUES ({})").format(
            table,
            sql.SQL(", ").join(col_idents),
            sql.SQL(", ").join([sql.Placeholder()] * len(col_idents)),
        )
        # itertuples yields Python scalars, which psycopg2 can adapt directly.
        rows = [
            tuple(self._to_db_value(value) for value in record)
            for record in df.itertuples(index=False, name=None)
        ]
        with self.conn, self.conn.cursor() as cur:
            cur.execute(sql.SQL("DROP TABLE IF EXISTS {}").format(table))
            cur.execute(sql.SQL("CREATE TABLE {} ({})").format(table, col_defs))
            if col_idents and rows:
                # One INSERT per row; adequate for moderate-sized frames.
                cur.executemany(insert, rows)

    @staticmethod
    def _pg_type(dtype) -> str:
        """Map a pandas dtype to a PostgreSQL column type (TEXT as fallback)."""
        types = pd.api.types
        if types.is_bool_dtype(dtype):
            return "BOOLEAN"
        if types.is_integer_dtype(dtype):
            return "BIGINT"
        if types.is_float_dtype(dtype):
            return "DOUBLE PRECISION"
        if isinstance(dtype, pd.DatetimeTZDtype):
            return "TIMESTAMPTZ"
        if types.is_datetime64_any_dtype(dtype):
            return "TIMESTAMP"
        return "TEXT"

    @staticmethod
    def _to_db_value(value):
        """Return None for scalar missing values (NaN, None, NaT, pd.NA), else ``value``."""
        # is_scalar guard: pd.isna on a list/array returns an array, not a bool.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

if __name__ == "__main__":
    # Build the path with osp.join: plain concatenation (current_path +
    # "config/...") dropped the separator and produced "<dir>config/db_config.yaml".
    db = PostgreSQLDatabase(osp.join(current_path, "config", "db_config.yaml"))
    # query = f"select * from {db.db_config['Database']['schema']}.{db.db_config['Database']['datatype']} limit 100"
    # res = db.execute_sql(query)
    # df = db.create_df_from_query(query)