Loading

Task 1: Next Product Recommendation

Next Product Recommendation

曾桉提交

Zengan

曾桉提交

In [22]:
!pip install aicrowd-cli pyarrow
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Requirement already satisfied: aicrowd-cli in /usr/local/lib/python3.10/dist-packages (0.1.15)
Requirement already satisfied: pyarrow in /usr/local/lib/python3.10/dist-packages (9.0.0)
Requirement already satisfied: click<8,>=7.1.2 in /usr/local/lib/python3.10/dist-packages (from aicrowd-cli) (7.1.2)
Requirement already satisfied: GitPython==3.1.18 in /usr/local/lib/python3.10/dist-packages (from aicrowd-cli) (3.1.18)
Requirement already satisfied: requests<3,>=2.25.1 in /usr/local/lib/python3.10/dist-packages (from aicrowd-cli) (2.27.1)
Requirement already satisfied: requests-toolbelt<1,>=0.9.1 in /usr/local/lib/python3.10/dist-packages (from aicrowd-cli) (0.10.1)
Requirement already satisfied: rich<11,>=10.0.0 in /usr/local/lib/python3.10/dist-packages (from aicrowd-cli) (10.16.2)
Requirement already satisfied: toml<1,>=0.10.2 in /usr/local/lib/python3.10/dist-packages (from aicrowd-cli) (0.10.2)
Requirement already satisfied: tqdm<5,>=4.56.0 in /usr/local/lib/python3.10/dist-packages (from aicrowd-cli) (4.65.0)
Requirement already satisfied: pyzmq==22.1.0 in /usr/local/lib/python3.10/dist-packages (from aicrowd-cli) (22.1.0)
Requirement already satisfied: python-slugify<6,>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from aicrowd-cli) (5.0.2)
Requirement already satisfied: semver<3,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from aicrowd-cli) (2.13.0)
Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython==3.1.18->aicrowd-cli) (4.0.10)
Requirement already satisfied: numpy>=1.16.6 in /usr/local/lib/python3.10/dist-packages (from pyarrow) (1.22.4)
Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify<6,>=5.0.0->aicrowd-cli) (1.3)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.25.1->aicrowd-cli) (1.26.15)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.25.1->aicrowd-cli) (2022.12.7)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.25.1->aicrowd-cli) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.25.1->aicrowd-cli) (3.4)
Requirement already satisfied: colorama<0.5.0,>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from rich<11,>=10.0.0->aicrowd-cli) (0.4.6)
Requirement already satisfied: commonmark<0.10.0,>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from rich<11,>=10.0.0->aicrowd-cli) (0.9.1)
Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.10/dist-packages (from rich<11,>=10.0.0->aicrowd-cli) (2.14.0)
Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython==3.1.18->aicrowd-cli) (5.0.0)
In [23]:
!aicrowd login
Please login here: https://api.aicrowd.com/auth/8OWHWMDzcd8R9YOEcvmMrdVM44M2_M02iU1CBENcAsA
/usr/bin/xdg-open: 869: www-browser: not found
/usr/bin/xdg-open: 869: links2: not found
/usr/bin/xdg-open: 869: elinks: not found
/usr/bin/xdg-open: 869: links: not found
/usr/bin/xdg-open: 869: lynx: not found
/usr/bin/xdg-open: 869: w3m: not found
xdg-open: no method available for opening 'https://api.aicrowd.com/auth/8OWHWMDzcd8R9YOEcvmMrdVM44M2_M02iU1CBENcAsA'
API Key valid
Gitlab access token valid
Saved details successfully!
In [24]:
!aicrowd dataset download --challenge task-1-next-product-recommendation
sessions_test_task1_phase1.csv: 100% 19.4M/19.4M [00:01<00:00, 10.6MB/s]
sessions_test_task2_phase1.csv: 100% 1.92M/1.92M [00:00<00:00, 10.4MB/s]
sessions_test_task3_phase1.csv: 100% 2.67M/2.67M [00:00<00:00, 13.5MB/s]
sessions_test_task1.csv: 100% 19.3M/19.3M [00:01<00:00, 14.2MB/s]
sessions_test_task2.csv: 100% 1.91M/1.91M [00:00<00:00, 11.1MB/s]
sessions_test_task3.csv: 100% 2.67M/2.67M [00:00<00:00, 13.3MB/s]
products_train.csv: 100% 589M/589M [01:03<00:00, 9.26MB/s]
sessions_train.csv: 100% 259M/259M [00:37<00:00, 6.84MB/s]
In [25]:
import os
import numpy as np
import pandas as pd
from functools import lru_cache
In [26]:
train_data_dir = '.'
test_data_dir = '.'
task = 'task1'
PREDS_PER_SESSION = 100
In [27]:
# Cache loading of data for multiple calls

@lru_cache(maxsize=1)
def read_product_data():
    return pd.read_csv(os.path.join(train_data_dir, 'products_train.csv'))

@lru_cache(maxsize=1)
def read_train_data():
    return pd.read_csv(os.path.join(train_data_dir, 'sessions_train.csv'))

@lru_cache(maxsize=3)
def read_test_data(task):
    return pd.read_csv(os.path.join(test_data_dir, f'sessions_test_{task}.csv'))
In [28]:
def read_locale_data(locale, task):
    products = read_product_data().query(f'locale == "{locale}"')
    sess_train = read_train_data().query(f'locale == "{locale}"')
    sess_test = read_test_data(task).query(f'locale == "{locale}"')
    return products, sess_train, sess_test

def show_locale_info(locale, task):
    products, sess_train, sess_test = read_locale_data(locale, task)

    train_l = sess_train['prev_items'].apply(lambda sess: len(sess))
    test_l = sess_test['prev_items'].apply(lambda sess: len(sess))

    print(f"Locale: {locale} \n"
          f"Number of products: {products['id'].nunique()} \n"
          f"Number of train sessions: {len(sess_train)} \n"
          f"Train session lengths - "
          f"Mean: {train_l.mean():.2f} | Median {train_l.median():.2f} | "
          f"Min: {train_l.min():.2f} | Max {train_l.max():.2f} \n"
          f"Number of test sessions: {len(sess_test)}"
        )
    if len(sess_test) > 0:
        print(
             f"Test session lengths - "
            f"Mean: {test_l.mean():.2f} | Median {test_l.median():.2f} | "
            f"Min: {test_l.min():.2f} | Max {test_l.max():.2f} \n"
        )
    print("======================================================================== \n")
In [29]:
products = read_product_data()
locale_names = products['locale'].unique()
for locale in locale_names:
    show_locale_info(locale, task)
Locale: DE 
Number of products: 518327 
Number of train sessions: 1111416 
Train session lengths - Mean: 57.89 | Median 40.00 | Min: 27.00 | Max 2060.00 
Number of test sessions: 104568
Test session lengths - Mean: 56.91 | Median 40.00 | Min: 27.00 | Max 1043.00 

======================================================================== 

Locale: JP 
Number of products: 395009 
Number of train sessions: 979119 
Train session lengths - Mean: 59.61 | Median 40.00 | Min: 27.00 | Max 6257.00 
Number of test sessions: 96467
Test session lengths - Mean: 59.84 | Median 40.00 | Min: 27.00 | Max 1466.00 

======================================================================== 

Locale: UK 
Number of products: 500180 
Number of train sessions: 1182181 
Train session lengths - Mean: 54.85 | Median 40.00 | Min: 27.00 | Max 2654.00 
Number of test sessions: 115937
Test session lengths - Mean: 53.25 | Median 40.00 | Min: 27.00 | Max 753.00 

======================================================================== 

Locale: ES 
Number of products: 42503 
Number of train sessions: 89047 
Train session lengths - Mean: 48.82 | Median 40.00 | Min: 27.00 | Max 792.00 
Number of test sessions: 0
======================================================================== 

Locale: FR 
Number of products: 44577 
Number of train sessions: 117561 
Train session lengths - Mean: 47.25 | Median 40.00 | Min: 27.00 | Max 687.00 
Number of test sessions: 0
======================================================================== 

Locale: IT 
Number of products: 50461 
Number of train sessions: 126925 
Train session lengths - Mean: 48.80 | Median 40.00 | Min: 27.00 | Max 621.00 
Number of test sessions: 0
======================================================================== 

In [30]:
products.sample(5)
Out[30]:
id locale title price brand color size model material author desc
545634 B004L21FCU JP 古河薬品工業(KYK) ロングライフクーラントエルコン補充液 赤 400ml[HTRC3] 145.00 古河薬品工業 NaN NaN 30-401 NaN NaN 30-401
1250290 B096442DF1 UK Dog Training Lead Leash Extra Long Line, 5m Ny... 5.99 5RIDGE Black 5 m (Pack of 1) NaN Nylon NaN 【 WEIVEL STYLE STAINLESS STEEL BUCKLE】 The dog...
475964 B01CR0GWLQ DE Miele 10231860 Duftflakon Aqua für ein frische... 29.90 Miele Aqua NaN 10231860 NaN NaN Aqua: pure Reinheit, frischer Duft und porenti...
189766 B09T5GN8CN DE MoKo 13,3-14" Laptop Hülle, Tragetasche Kompat... 30.99 MoKo Indigo 13.3-14 inch NaN NaN NaN 🔥HOCHWERTIGES MATERIAL - Die Polyesterfaserobe...
1192967 B0883S5F33 UK NETGEAR Orbi Tri-band 4G Router with SIM Slot ... 357.14 NETGEAR Blanc Routeur WiFi Mesh 4G 1.2 Gbit/s LBR20-100EUS NaN NaN ADVANCED CYBER THREAT PROTECTION: NETGEAR Armo...
In [31]:
train_sessions = read_train_data()
train_sessions.sample(5)
Out[31]:
prev_items next_item locale
3334638 ['B07C53WWSC' 'B0912LLRLW'] B085RN9GRX ES
3596014 ['B0B3VQN4MC' 'B0B3VSTQWG' 'B0798LTQVX'] B07N9MVJKB IT
2334968 ['B00SLB6WXE' 'B06XP21DK6'] B097H1YZ7H UK
2483820 ['B07ZZG8Z3D' 'B07ZZG8Z3D'] B09BZCG3SP UK
1263970 ['B0BG844PGF' 'B0BG844PGF' 'B09TZVSKW5' 'B08X4... B0B38KYDSD JP
In [32]:
test_sessions = read_test_data(task)
test_sessions.sample(5)
Out[32]:
prev_items locale
116589 ['B0BDJ1S7JN' 'B0B49YQM4L'] JP
70399 ['B09V1JDWL4' 'B0B9H8HS7K' 'B0B9H8HS7K'] DE
16533 ['B07D166BLR' 'B0090JSX8I'] DE
269175 ['B00VANZFGU' 'B09534TD8K' 'B07W7DZSD2' 'B0B93... UK
269480 ['B09RGWYG7G' 'B00WHXISDE' 'B00DE6FZCK' 'B00DE... UK
In [33]:
def random_predicitons(locale, sess_test_locale):
    random_state = np.random.RandomState(42)
    products = read_product_data().query(f'locale == "{locale}"')
    predictions = []
    for _ in range(len(sess_test_locale)):
        predictions.append(
            list(products['id'].sample(PREDS_PER_SESSION, replace=True, random_state=random_state))
        ) 
    sess_test_locale['next_item_prediction'] = predictions
    sess_test_locale.drop('prev_items', inplace=True, axis=1)
    return sess_test_locale
In [34]:
test_sessions = read_test_data(task)
predictions = []
test_locale_names = test_sessions['locale'].unique()
for locale in test_locale_names:
    sess_test_locale = test_sessions.query(f'locale == "{locale}"').copy()
    predictions.append(
        random_predicitons(locale, sess_test_locale)
    )
predictions = pd.concat(predictions).reset_index(drop=True)
predictions.sample(5)
Out[34]:
locale next_item_prediction
141077 JP [B0BJ1BGGHS, B000THNH0O, B08TR4D3PD, B07N2RQ4V...
301565 UK [B09HWQNVW4, B0B5RPH4M7, B0BF7CQ1YQ, B07F1MHNH...
300333 UK [B095P5KW8Y, B09LC5RM69, B096WVY9FR, B07MDXBTL...
252551 UK [B07ZSFY7ZS, B072L2GWDK, B09RN4F4Y7, B07KZYPS7...
219475 UK [B08Q2WRGZK, B09G2R7BHV, B07L7C9T1M, B09G2KHQP...
In [35]:
def check_predictions(predictions, check_products=False):
    """
    These tests need to pass as they will also be applied on the evaluator
    """
    test_locale_names = test_sessions['locale'].unique()
    for locale in test_locale_names:
        sess_test = test_sessions.query(f'locale == "{locale}"')
        preds_locale =  predictions[predictions['locale'] == sess_test['locale'].iloc[0]]
        assert sorted(preds_locale.index.values) == sorted(sess_test.index.values), f"Session ids of {locale} doesn't match"

        if check_products:
            # This check is not done on the evaluator
            # but you can run it to verify there is no mixing of products between locales
            # Since the ground truth next item will always belong to the same locale
            # Warning - This can be slow to run
            products = read_product_data().query(f'locale == "{locale}"')
            predicted_products = np.unique( np.array(list(preds_locale["next_item_prediction"].values)) )
            assert np.all( np.isin(predicted_products, products['id']) ), f"Invalid products in {locale} predictions"
In [36]:
check_predictions(predictions)
# Its important that the parquet file you submit is saved with pyarrow backend
predictions.to_parquet(f'submission_{task}.parquet', engine='pyarrow')
In [ ]:
# 读取产品数据的前100行
products = pd.read_csv('products_train.csv', nrows=100)
In [38]:
# 针对每个会话进行预测
predictions = []
for index, row in test_sessions.iterrows():
    session = row['prev_items']
    locale = row['locale']
    locale_products = products[products['locale'] == locale]
    
    if not locale_products.empty:
        predicted_products = predict_next_products(session, locale_products)
        predictions.append(predicted_products)
    else:
        predictions.append([])  # If no matching products, append an empty list

# 创建预测结果DataFrame
output_data = pd.DataFrame({'next_item': predictions})

# 将预测结果存储为parquet文件
output_data.to_parquet('predictions.parquet', index=False)
In [42]:
import pandas as pd

# 从Parquet文件读取数据
predictions = pd.read_parquet('predictions.parquet')

# 打印数据前几行
print(predictions.head(100))
                                            next_item
0   [B0B8VXQFRL, B08JVGH7RN, B08R62WZ1Y, B06XKPB3G...
1   [B07CKPD4RP, B000S1KWPE, B09QMCVW6P, B08BPH1WM...
2   [B08JQL5F31, B002RAWQ3K, B07NWCJZWV, B0881P6YF...
3   [B09JCC7FVK, B076G629HF, B07Y25MGQL, B09TLCZ14...
4   [B07KTKFYYS, B09D89LNLV, B08JQL5F31, B0B9MPKYJ...
..                                                ...
95  [B01NBLDOJN, B093GH3HK6, B07K15DDKQ, B094DGRV7...
96  [B0BHWPR939, B07NWCJZWV, B08VGQG3VP, B08DKHWNV...
97  [B099RB4X99, B00VW4SYQ0, B08B85T2J4, B06XKPB3G...
98  [B08JQL5F31, B005ZSSN10, B07QCQK6CV, B08H21WG6...
99  [B094V3614V, B08PW1VS3R, B0BD7145W7, B07XPFNBH...

[100 rows x 1 columns]
In [17]:
!aicrowd submission create -c task-1-next-product-recommendation -f "submission_task1.parquet"
Submission Error: You haven't registered for this challenge
Please go to the challenge homepage and register

Comments

You must login before you can post a comment.

Execute