{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6840265c",
   "metadata": {},
   "source": [
    "# Chunking Service dynamic dataset example\n",
    "\n",
    "**Requirements:** the DataRobot Python SDK, a DataRobot API token and endpoint, a named DataRobot credential that can read the dataset, and the `SAMPLE_DATA_TO_START_PROJECT` feature flag enabled.\n",
    "\n",
    "The sampling helpers read the module-level `DR_API_TOKEN` and `DR_ENDPOINT` variables, which the **Authentication** cell loads from environment variables of the same names.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6df829b4",
   "metadata": {},
   "source": [
    "## Authentication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "371ca563",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Read the DataRobot API credentials from the environment; an empty string means \"not set\".\n",
    "DR_API_TOKEN = os.environ.get(\"DR_API_TOKEN\", \"\")\n",
    "DR_ENDPOINT = os.environ.get(\"DR_ENDPOINT\", \"\")\n",
    "\n",
    "# Fail fast: the next cell builds the DataRobot client from both values.\n",
    "if not (DR_API_TOKEN and DR_ENDPOINT):\n",
    "    raise ValueError(\"Set DR_API_TOKEN and DR_ENDPOINT in the environment or edit this cell.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3cd2021",
   "metadata": {},
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6639d02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import time\n",
    "from typing import Dict, List, Optional\n",
    "\n",
    "import datarobot as dr\n",
    "import pandas as pd\n",
    "from datarobot import UseCase\n",
    "from datarobot.enums import ChunkingPartitionMethod, ChunkingStrategy\n",
    "from datarobot.models.chunking_service_v2 import ChunkDefinition, DatasetDefinition\n",
    "from datarobot.models.project import Project\n",
    "from datarobot.utils.waiters import wait_for_async_resolution\n",
    "\n",
    "# Target sample size: 500 MiB expressed in bytes (the helpers report it in MB).\n",
    "SAMPLE_SIZE = 500 * 1024 * 1024\n",
    "\n",
    "# Create the DataRobot client; binding it to `_` keeps the notebook from echoing its repr.\n",
    "_ = dr.Client(endpoint=DR_ENDPOINT, token=DR_API_TOKEN)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "478aa754",
   "metadata": {},
   "source": [
    "## Configure pipeline functions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "632c3730",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_credential_id(credential_name: str) -> str:\n",
    "    \"\"\"Return the ID of the first stored DataRobot credential named `credential_name`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no credential with that name exists.\n",
    "    \"\"\"\n",
    "    match = next(\n",
    "        (cred for cred in dr.Credential.list() if cred.name == credential_name),\n",
    "        None,\n",
    "    )\n",
    "    if match is None:\n",
    "        raise ValueError(f\"No credential found with name: {credential_name}\")\n",
    "    print(f\"Credential found: {credential_name}\")\n",
    "    return str(match.credential_id)\n",
    "\n",
    "\n",
    "def add_project_to_use_case(use_case_id: str, project_id: str) -> None:\n",
    "    \"\"\"Fetch a project and a Use Case by ID, then add the project to the Use Case.\"\"\"\n",
    "    target_project = dr.Project.get(project_id)\n",
    "    UseCase.get(use_case_id=use_case_id).add(entity=target_project)\n",
    "\n",
    "\n",
    "def create_dataset_definition(\n",
    "    dataset_id: str,\n",
    "    dataset_version_id: Optional[str],\n",
    "    name: str,\n",
    "    credential_name: str,\n",
    ") -> DatasetDefinition:\n",
    "    \"\"\"Create a dataset definition, run its analysis, and return the re-fetched definition.\n",
    "\n",
    "    Args:\n",
    "        dataset_id: Catalog dataset to wrap.\n",
    "        dataset_version_id: Dataset version passed to `DatasetDefinition.create` (may be None).\n",
    "        name: Name for the new dataset definition.\n",
    "        credential_name: Name of the stored credential, resolved via `get_credential_id`.\n",
    "\n",
    "    Returns:\n",
    "        The definition fetched again after `DatasetDefinition.analyze`, so its\n",
    "        `dataset_info` (row count, per-row size estimate) is available when set.\n",
    "    \"\"\"\n",
    "    cred_id = get_credential_id(credential_name)\n",
    "    definition = DatasetDefinition.create(\n",
    "        dataset_id, dataset_version_id, name=name, credentials_id=cred_id\n",
    "    )\n",
    "    print(f\"Created dataset definition: {definition.id}\")\n",
    "\n",
    "    # Analyze, then reload: the object returned by `create` predates the analysis.\n",
    "    DatasetDefinition.analyze(definition.id)\n",
    "    definition = DatasetDefinition.get(definition.id)\n",
    "\n",
    "    info = definition.dataset_info\n",
    "    if info is not None:\n",
    "        print(\n",
    "            f\"Dataset info - Total rows: {info.total_rows}, \"\n",
    "            f\"Estimated size per row: {info.estimated_size_per_row}\"\n",
    "        )\n",
    "    return definition\n",
    "\n",
    "\n",
    "def _request_sample(\n",
    "    dataset_definition: DatasetDefinition,\n",
    "    sampling_arguments: Dict[str, object],\n",
    "    label: str,\n",
    ") -> str:\n",
    "    \"\"\"Request an `efficient-rowbased-sample`, wait for it, and confirm it is listed as a SAMPLE.\n",
    "\n",
    "    Shared by `create_sample_row_based` and `create_sample_percentage_based`, which differ\n",
    "    only in `sampling_arguments` and in `label` (used in the not-found error message).\n",
    "\n",
    "    Returns:\n",
    "        The catalog version ID of the new sample.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the dataset lists no SAMPLE versions, or the new one is not among them.\n",
    "    \"\"\"\n",
    "    client = dr.Client(token=DR_API_TOKEN, endpoint=DR_ENDPOINT)\n",
    "    dataset_id = dataset_definition.dataset_props.dataset_id\n",
    "    payload = {\n",
    "        \"samplingStrategy\": {\n",
    "            \"directive\": \"efficient-rowbased-sample\",\n",
    "            \"arguments\": sampling_arguments,\n",
    "        }\n",
    "    }\n",
    "    response = client.post(f\"{DR_ENDPOINT}/datasets/{dataset_id}/samples\", json=payload)\n",
    "    sample_version_id = response.json()[\"catalogVersionId\"]\n",
    "    wait_for_async_resolution(client, response.headers[\"Location\"])\n",
    "\n",
    "    # Walk every page of SAMPLE versions so a new version beyond the first page is still\n",
    "    # found. NOTE(review): assumes the standard paginated `data` / `next` response fields.\n",
    "    sample_version_ids: List[str] = []\n",
    "    url: Optional[str] = f\"{DR_ENDPOINT}/datasets/{dataset_id}/versions/?category=SAMPLE\"\n",
    "    while url:\n",
    "        page = client.get(url).json()\n",
    "        sample_version_ids.extend(\n",
    "            version[\"versionId\"]\n",
    "            for version in page.get(\"data\") or []\n",
    "            if \"SAMPLE\" in (version.get(\"categories\") or [])\n",
    "        )\n",
    "        url = page.get(\"next\")\n",
    "\n",
    "    if not sample_version_ids:\n",
    "        raise ValueError(f\"No sample version found for dataset {dataset_id}\")\n",
    "    # Explicit raise rather than `assert`, which Python strips when run with -O.\n",
    "    if sample_version_id not in sample_version_ids:\n",
    "        raise ValueError(f\"{label} sample version {sample_version_id} not found in samples\")\n",
    "    return str(sample_version_id)\n",
    "\n",
    "\n",
    "def create_sample_row_based(dataset_definition: DatasetDefinition) -> str:\n",
    "    \"\"\"Create a sample of about `SAMPLE_SIZE` bytes, sized by row count, and return its version ID.\n",
    "\n",
    "    The row count is ceil(SAMPLE_SIZE / estimated_size_per_row), capped at the dataset's\n",
    "    `total_rows` when that is known.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the definition has no `dataset_info` (not analyzed), has no positive\n",
    "            per-row size estimate, or the created sample cannot be found afterwards.\n",
    "    \"\"\"\n",
    "    dataset_info = dataset_definition.dataset_info\n",
    "    if dataset_info is None:\n",
    "        raise ValueError(\"Dataset definition has no dataset_info; ensure it has been analyzed\")\n",
    "    estimated_size_per_row = dataset_info.estimated_size_per_row\n",
    "    if not estimated_size_per_row or estimated_size_per_row <= 0:\n",
    "        raise ValueError(\n",
    "            \"Dataset definition has no positive estimated_size_per_row; cannot size a row-based sample\"\n",
    "        )\n",
    "    sample_size_rows = math.ceil(SAMPLE_SIZE / estimated_size_per_row)\n",
    "    # A sample cannot contain more rows than the dataset itself.\n",
    "    if dataset_info.total_rows:\n",
    "        sample_size_rows = min(sample_size_rows, dataset_info.total_rows)\n",
    "    print(f\"Row count for {SAMPLE_SIZE / (1024**2):.1f}MB: {sample_size_rows}\")\n",
    "    return _request_sample(\n",
    "        dataset_definition,\n",
    "        {\"size\": sample_size_rows, \"samplingMethod\": \"rows\"},\n",
    "        \"Row based\",\n",
    "    )\n",
    "\n",
    "\n",
    "def create_sample_percentage_based(dataset_definition: DatasetDefinition) -> str:\n",
    "    \"\"\"Create a sample of about `SAMPLE_SIZE` bytes, sized as a percentage, and return its version ID.\n",
    "\n",
    "    The percentage is SAMPLE_SIZE / source_size * 100, clamped to at most 100.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the definition has no `dataset_info` (not analyzed), has no positive\n",
    "            `source_size`, or the created sample cannot be found afterwards.\n",
    "    \"\"\"\n",
    "    dataset_info = dataset_definition.dataset_info\n",
    "    if dataset_info is None:\n",
    "        raise ValueError(\"Dataset definition has no dataset_info; ensure it has been analyzed\")\n",
    "    source_size = dataset_info.source_size\n",
    "    if not source_size or source_size <= 0:\n",
    "        raise ValueError(\n",
    "            \"Dataset definition has no positive source_size; cannot size a percentage sample\"\n",
    "        )\n",
    "    # Datasets smaller than SAMPLE_SIZE would otherwise request more than 100%.\n",
    "    sample_size_percentage = min(100.0, (SAMPLE_SIZE / source_size) * 100)\n",
    "    print(f\"Percentage of dataset size for {SAMPLE_SIZE / (1024**2):.1f}MB: {sample_size_percentage:.2f}%\")\n",
    "    return _request_sample(\n",
    "        dataset_definition,\n",
    "        {\"size\": sample_size_percentage, \"samplingMethod\": \"percent\"},\n",
    "        \"Percentage\",\n",
    "    )\n",
    "\n",
    "\n",
    "def create_project(\n",
    "    dataset_id: str,\n",
    "    target_column: str,\n",
    "    chunk_definition: ChunkDefinition,\n",
    "    use_case_id: Optional[str] = None,\n",
    "    datetime_partition_column: Optional[str] = None,\n",
    "    project_name: str = \"Dynamic Dataset Project with Incremental Learning OTV\",\n",
    "    max_wait: int = 6000,\n",
    ") -> Project:\n",
    "    \"\"\"Create a project from the dataset's sample and start incremental-learning Autopilot.\n",
    "\n",
    "    Args:\n",
    "        dataset_id: Catalog dataset; the project is created with `use_sample_from_dataset=True`.\n",
    "        target_column: Modeling target.\n",
    "        chunk_definition: Chunk definition whose ID is passed to `AdvancedOptions`.\n",
    "        use_case_id: If set, the new project is added to this Use Case.\n",
    "        datetime_partition_column: If set, an optimized non-time-series datetime\n",
    "            partitioning is generated on this column and used; otherwise defaults apply.\n",
    "        project_name: Name of the new project. The default (kept for backward\n",
    "            compatibility) says \"OTV\" even when no datetime column is given.\n",
    "        max_wait: Passed as `max_wait` to both `create_from_dataset` and `analyze_and_model`.\n",
    "\n",
    "    Returns:\n",
    "        The project after `analyze_and_model` has started Quick Autopilot.\n",
    "    \"\"\"\n",
    "    project: Project = dr.Project.create_from_dataset(\n",
    "        dataset_id,\n",
    "        project_name=project_name,\n",
    "        use_sample_from_dataset=True,\n",
    "        max_wait=max_wait,\n",
    "    )\n",
    "    project_partitioning_method = None\n",
    "    if datetime_partition_column is not None:\n",
    "        print(f\"Setting up datetime partitioning for project {project.id}\")\n",
    "        spec = dr.DatetimePartitioningSpecification(\n",
    "            datetime_partition_column=datetime_partition_column,\n",
    "            use_time_series=False,\n",
    "        )\n",
    "        full_part = dr.DatetimePartitioning.generate_optimized(project.id, spec, target_column)\n",
    "        # Reference the generated partitioning by ID when starting modeling below.\n",
    "        project_partitioning_method = dr.helpers.partitioning_methods.DatetimePartitioningId(\n",
    "            full_part.datetime_partitioning_id, project.id\n",
    "        )\n",
    "        print(f\"Datetime partitioning set for project {project.id}\")\n",
    "    else:\n",
    "        print(f\"No datetime partitioning; proceeding with defaults for project {project.id}\")\n",
    "    if use_case_id:\n",
    "        add_project_to_use_case(use_case_id=use_case_id, project_id=project.id)\n",
    "    advanced_options = dr.helpers.AdvancedOptions(\n",
    "        incremental_learning_only_mode=True,\n",
    "        incremental_learning_on_best_model=True,\n",
    "        chunk_definition_id=chunk_definition.id,\n",
    "        incremental_learning_early_stopping_rounds=0,\n",
    "    )\n",
    "    project.analyze_and_model(\n",
    "        target=target_column,\n",
    "        mode=dr.enums.AUTOPILOT_MODE.QUICK,\n",
    "        partitioning_method=project_partitioning_method,\n",
    "        advanced_options=advanced_options,\n",
    "        worker_count=-1,\n",
    "        max_wait=max_wait,\n",
    "    )\n",
    "    return project\n",
    "\n",
    "\n",
    "def create_chunk_definition(\n",
    "    dataset_definition: DatasetDefinition,\n",
    "    sorting_columns: List[str],\n",
    "    target_column: Optional[str] = None,\n",
    "    datetime_partition_column: Optional[str] = None,\n",
    "    target_class: Optional[str] = None,\n",
    "    chunking_partition_method: Optional[ChunkingPartitionMethod] = ChunkingPartitionMethod.RANDOM,\n",
    ") -> ChunkDefinition:\n",
    "    \"\"\"Create a row-based chunk definition for `dataset_definition`, analyze it, and return it.\n",
    "\n",
    "    Only the partition arguments relevant to `chunking_partition_method` are sent:\n",
    "    RANDOM sends none, STRATIFIED sends `target_column` and `target_class`, and any\n",
    "    other method (e.g. DATE) sends `datetime_partition_column`. `sorting_columns` is\n",
    "    always passed as `order_by_columns`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if STRATIFIED is requested without `target_column`, or DATE is\n",
    "            requested without `datetime_partition_column`.\n",
    "    \"\"\"\n",
    "    # Fail before creating anything server-side when a required argument is missing.\n",
    "    if chunking_partition_method == ChunkingPartitionMethod.STRATIFIED and not target_column:\n",
    "        raise ValueError(\"target_column is required for STRATIFIED chunking\")\n",
    "    if chunking_partition_method == ChunkingPartitionMethod.DATE and not datetime_partition_column:\n",
    "        raise ValueError(\"datetime_partition_column is required for DATE chunking\")\n",
    "\n",
    "    partition_args: Dict[str, Optional[str]] = dict.fromkeys(\n",
    "        (\"target_column\", \"target_class\", \"datetime_partition_column\")\n",
    "    )\n",
    "    if chunking_partition_method == ChunkingPartitionMethod.STRATIFIED:\n",
    "        partition_args[\"target_column\"] = target_column\n",
    "        partition_args[\"target_class\"] = target_class\n",
    "    elif chunking_partition_method != ChunkingPartitionMethod.RANDOM:\n",
    "        # DATE (and any other non-RANDOM value) partitions on the datetime column.\n",
    "        partition_args[\"datetime_partition_column\"] = datetime_partition_column\n",
    "\n",
    "    chunk_definition = ChunkDefinition.create(\n",
    "        dataset_definition.id,\n",
    "        partition_method=chunking_partition_method,\n",
    "        chunking_strategy_type=ChunkingStrategy.ROWS,\n",
    "        order_by_columns=sorting_columns,\n",
    "        **partition_args,\n",
    "    )\n",
    "    print(f\"Created chunk definition: {chunk_definition.id}\")\n",
    "    # Reload after analysis so `chunk_definition_stats` reflects it when available.\n",
    "    ChunkDefinition.analyze(dataset_definition.id, chunk_definition.id)\n",
    "    chunk_definition = ChunkDefinition.get(dataset_definition.id, chunk_definition.id)\n",
    "    stats = chunk_definition.chunk_definition_stats\n",
    "    if stats is not None:\n",
    "        print(f\"Chunk definition stats - Number of chunks: {stats.total_number_of_chunks}\")\n",
    "    return chunk_definition\n",
    "\n",
    "\n",
    "def run_dynamic_dataset_pipeline(\n",
    "    dataset_id: str,\n",
    "    target_column: str,\n",
    "    sorting_columns: List[str],\n",
    "    credential_name: str,\n",
    "    target_class: Optional[str] = None,\n",
    "    chunking_partition_method: Optional[ChunkingPartitionMethod] = ChunkingPartitionMethod.RANDOM,\n",
    "    dataset_version_id: Optional[str] = None,\n",
    "    use_case_id: Optional[str] = None,\n",
    "    datetime_partition_column: Optional[str] = None,\n",
    ") -> Project:\n",
    "    \"\"\"Run the whole flow and return the started project.\n",
    "\n",
    "    Steps: create and analyze a dataset definition, create a row-based sample,\n",
    "    create and analyze a chunk definition, then create the project and start\n",
    "    incremental-learning Autopilot. Progress and total wall time are printed.\n",
    "    \"\"\"\n",
    "    t0 = time.time()\n",
    "    print(\"Starting dynamic dataset pipeline...\")\n",
    "\n",
    "    definition = create_dataset_definition(\n",
    "        dataset_id,\n",
    "        dataset_version_id,\n",
    "        name=\"Dynamic Dataset\",\n",
    "        credential_name=credential_name,\n",
    "    )\n",
    "\n",
    "    print(\"Creating samples...\")\n",
    "    # The returned sample version ID is not used here: create_project builds the\n",
    "    # project with use_sample_from_dataset=True rather than from this ID.\n",
    "    create_sample_row_based(definition)\n",
    "\n",
    "    print(\"Creating chunk definition...\")\n",
    "    chunks = create_chunk_definition(\n",
    "        definition,\n",
    "        sorting_columns,\n",
    "        target_column=target_column,\n",
    "        datetime_partition_column=datetime_partition_column,\n",
    "        target_class=target_class,\n",
    "        chunking_partition_method=chunking_partition_method,\n",
    "    )\n",
    "\n",
    "    print(\"Creating and starting project...\")\n",
    "    project = create_project(\n",
    "        dataset_id,\n",
    "        target_column,\n",
    "        chunks,\n",
    "        use_case_id=use_case_id,\n",
    "        datetime_partition_column=datetime_partition_column,\n",
    "    )\n",
    "    elapsed = time.time() - t0\n",
    "    print(f\"Time taken: {elapsed:.1f}s — Project ID: {project.id}\")\n",
    "    return project\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b27f6978",
   "metadata": {},
   "source": [
    "## Configure and run\n",
    "\n",
    "Edit the variables at the top of the next cell, then run it. The `STRATIFIED` chunking method requires `TARGET_CLASS`; the `DATE` method requires `DATETIME_PARTITION_COLUMN`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "468d52f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- edit these ---\n",
    "DATASET_ID = \"your-dataset-id\"\n",
    "TARGET_COLUMN = \"your_target\"\n",
    "SORTING_COLUMNS = [\"TIME\", \"SERIES\"]  # order_by columns for chunking\n",
    "CREDENTIAL_NAME = \"your-credential-name\"  # must exist in DataRobot\n",
    "DATASET_VERSION_ID = None  # optional: dataset version used for the dataset definition\n",
    "USE_CASE_ID = None  # optional: Use Case to add the new project to\n",
    "TARGET_CLASS = None  # required for STRATIFIED\n",
    "DATETIME_PARTITION_COLUMN = None  # required for DATE; when set, the project also uses datetime partitioning\n",
    "CHUNKING_PARTITION_METHOD = \"RANDOM\"  # RANDOM | STRATIFIED | DATE\n",
    "\n",
    "# Validate the configuration before anything is created in DataRobot.\n",
    "try:\n",
    "    method = ChunkingPartitionMethod[CHUNKING_PARTITION_METHOD.upper()]\n",
    "except KeyError:\n",
    "    valid_methods = \" | \".join(ChunkingPartitionMethod.__members__)\n",
    "    raise ValueError(\n",
    "        f\"Unknown CHUNKING_PARTITION_METHOD {CHUNKING_PARTITION_METHOD!r}; expected one of: {valid_methods}\"\n",
    "    ) from None\n",
    "if method == ChunkingPartitionMethod.DATE and not DATETIME_PARTITION_COLUMN:\n",
    "    raise ValueError(\"DATETIME_PARTITION_COLUMN required for DATE\")\n",
    "if method == ChunkingPartitionMethod.STRATIFIED and not TARGET_CLASS:\n",
    "    raise ValueError(\"TARGET_CLASS required for STRATIFIED\")\n",
    "\n",
    "project = run_dynamic_dataset_pipeline(\n",
    "    dataset_id=DATASET_ID,\n",
    "    target_column=TARGET_COLUMN,\n",
    "    sorting_columns=SORTING_COLUMNS,\n",
    "    credential_name=CREDENTIAL_NAME,\n",
    "    target_class=TARGET_CLASS,\n",
    "    chunking_partition_method=method,\n",
    "    dataset_version_id=DATASET_VERSION_ID,\n",
    "    use_case_id=USE_CASE_ID,\n",
    "    datetime_partition_column=DATETIME_PARTITION_COLUMN,\n",
    ")\n",
    "print(f\"Project ID: {project.id}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d66216a",
   "metadata": {},
   "source": [
    "## Optional: percentage-based sample\n",
    "\n",
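    "In a custom flow, after `create_dataset_definition`, you can call `create_sample_percentage_based(dataset_definition)` instead of `create_sample_row_based(dataset_definition)` to size the sample as a percentage of the source data rather than as a row count.\n"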
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
