-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrandom_forest.py
More file actions
65 lines (48 loc) · 1.83 KB
/
random_forest.py
File metadata and controls
65 lines (48 loc) · 1.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import RandomForestRegressor
import seaborn as sns
from query import query
import pandas as pd
import matplotlib.pyplot as plt
# Load the query result into a DataFrame and label its five columns.
df = pd.DataFrame(query())
df.columns = ['month', 'year', 'product', 'quantity', 'total']

# The 'product' column is not used below, so discard it.
df = df.drop('product', axis=1)

# Cast both calendar columns to integers in a single pass.
df = df.astype({'month': int, 'year': int})

# Build a 'date' column pinned to the first day of each (year, month) pair.
date_parts = df[['year', 'month']].copy()
date_parts['day'] = 1
df['date'] = pd.to_datetime(date_parts)

# Show the resulting column dtypes.
print(df.dtypes)
# Features and target: the model predicts 'quantity' from whatever columns
# remain once 'quantity', 'total' and 'date' are dropped.
X = df.drop(columns=['quantity', 'total', 'date'])
y = df['quantity']

# Split data into training and testing sets (30% of the rows are held out).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# Train the model. The forest is seeded with the same value as the split so
# that repeated runs on the same data yield the same model and the same MSE.
model = RandomForestRegressor(random_state=42)
model.fit(X_train, y_train)

# Predict on the held-out rows and report the error.
predictions = model.predict(X_test)
mse = mean_squared_error(y_test, predictions)
print(f"Mean Squared Error: {round(mse, 2)}")
# Visualization
plt.figure(figsize=(12, 8))

# Plot of actual sales
sns.lineplot(x=df['date'], y=df['quantity'], label='Real Sales', marker='o')

# Pair each prediction with the date of its test row in a separate frame.
# X_test is deliberately left untouched: adding plotting columns to it would
# change the feature matrix, so it could no longer be passed to the model.
# `predictions` is positional and follows the row order of X_test, which
# df.loc[X_test.index, ...] preserves.
predicted = df.loc[X_test.index, ['date']].assign(predicted_quantity=predictions)
sns.lineplot(
    x=predicted['date'],
    y=predicted['predicted_quantity'],
    label='Predicted Sales',
    linestyle='--',
    marker='o',
)

# One tick per month start.
# NOTE(review): the tick range is hard-coded to 2022-01 .. 2023-12; confirm it
# matches the period returned by the query.
dates = pd.date_range(start='2022-01', end='2023-12', freq='MS')
plt.xticks(dates, labels=[date.strftime('%Y-%m') for date in dates], rotation=45)

# Configure the plot
plt.xlabel('Date')
plt.ylabel('Sales')
plt.title('Real Sales vs Predictions')
plt.legend()
plt.grid(True)
plt.show()