diff --git a/src/agents/sql_agent_graph.py b/src/agents/sql_agent_graph.py index 6e39284..8ce18b8 100644 --- a/src/agents/sql_agent_graph.py +++ b/src/agents/sql_agent_graph.py @@ -29,6 +29,7 @@ def resource_path(relative_path): # --- 프롬프트 로드 --- INTENT_CLASSIFIER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "intent_classifier.yaml"))) +DB_CLASSIFIER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "db_classifier.yaml"))) SQL_GENERATOR_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "sql_generator.yaml"))) RESPONSE_SYNTHESIZER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "response_synthesizer.yaml"))) @@ -58,6 +59,51 @@ def unsupported_question_node(state: SqlAgentState): state['final_response'] = "죄송합니다, 해당 질문에는 답변할 수 없습니다. 데이터베이스 관련 질문만 가능합니다." return state +def db_classifier_node(state: SqlAgentState): + print("--- 0.5. DB 분류 중 ---") + + # TODO: BE API 호출로 대체 필요 + available_dbs = [ + { + "connection_name": "local_mysql", + "database_name": "sakila", + "description": "DVD 대여점 비즈니스 모델을 다루는 샘플 데이터베이스로, 영화, 배우, 고객, 대여 기록 등의 정보를 포함합니다." + }, + { + "connection_name": "local_mysql", + "database_name": "ecom_prod", + "description": "온라인 쇼핑몰의 운영 데이터베이스로, 상품 카탈로그, 고객 주문, 재고 및 배송 정보를 관리합니다." + }, + { + "connection_name": "local_mysql", + "database_name": "hr_analytics", + "description": "회사의 인사 관리 데이터베이스로, 직원 정보, 급여, 부서, 성과 평가 기록을 포함합니다." + }, + { + "connection_name": "local_mysql", + "database_name": "web_logs", + "description": "웹사이트 트래픽 분석을 위한 로그 데이터베이스로, 사용자 방문 기록, 페이지 뷰, 에러 로그 등을 저장합니다." + } + ] + + db_options = "\n".join([f"- {db['database_name']}: {db['description']}" for db in available_dbs]) + + chain = DB_CLASSIFIER_PROMPT | llm_instance | StrOutputParser() + selected_db_name = chain.invoke({ + "db_options": db_options, + "question": state['question'] + }) + + state['selected_db'] = selected_db_name.strip() + + # 선택된 DB의 스키마 정보를 가져와서 상태에 업데이트합니다. + print(f'--- 선택된 DB: {selected_db_name} ---') + + # TODO: get_schema_for_db + state['db_schema'] = db_instance.get_schema_for_db(db_name=selected_db_name) + + return state + def sql_generator_node(state: SqlAgentState): print("--- 1. SQL 생성 중 ---") parser = PydanticOutputParser(pydantic_object=SqlQuery) @@ -163,7 +209,7 @@ def response_synthesizer_node(state: SqlAgentState): def route_after_intent_classification(state: SqlAgentState): if state['intent'] == "SQL": print("--- 의도: SQL 관련 질문 ---") - return "sql_generator" + return "db_classifier" print("--- 의도: SQL과 관련 없는 질문 ---") return "unsupported_question" @@ -192,6 +238,7 @@ def create_sql_agent_graph() -> StateGraph: graph = StateGraph(SqlAgentState) graph.add_node("intent_classifier", intent_classifier_node) + graph.add_node("db_classifier", db_classifier_node) graph.add_node("unsupported_question", unsupported_question_node) graph.add_node("sql_generator", sql_generator_node) graph.add_node("sql_validator", sql_validator_node) @@ -204,11 +251,13 @@ def create_sql_agent_graph() -> StateGraph: "intent_classifier", route_after_intent_classification, { - "sql_generator": "sql_generator", + "db_classifier": "db_classifier", "unsupported_question": "unsupported_question" } ) graph.add_edge("unsupported_question", END) + + graph.add_edge("db_classifier", "sql_generator") graph.add_edge("sql_generator", "sql_validator") diff --git a/src/prompts/v1/sql_agent/db_classifier.yaml b/src/prompts/v1/sql_agent/db_classifier.yaml new file mode 100644 index 0000000..b4e2270 --- /dev/null +++ b/src/prompts/v1/sql_agent/db_classifier.yaml @@ -0,0 +1,15 @@ +_type: "prompt" +input_variables: + - db_options + - question +template: | + Based on the user's question, which of the following databases is most likely to contain the answer? + Please respond with only the database name. + + Available databases: + {db_options} + + User Question: + {question} + + Selected Database: \ No newline at end of file