-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
125 lines (78 loc) · 4.09 KB
/
main.py
File metadata and controls
125 lines (78 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import pandas as pd #pandas库,用来处理抓取的数据
import matplotlib.pyplot as plt #matplotlib库,用来绘制股价走势图
import os #os库,用来文件处理方面
from stock_crawler import StockCrawler
#从stock_crawler.py中导入封装好的SimpleStockCrawler类,核心功能
import matplotlib
matplotlib.use('Agg')
# 强制matplotlib使用Agg后端,避免Tcl/Tk依赖,彻底解决init.tcl报错
# 设置中文显示
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
class StockVisualizer: #核心类
#方法
def __init__(self, csv_path): #初始化方法,接受CSV文件路径作为参数,加__XX__设置方法为特殊方法,构造函数,在创建实例时自动调用
self.csv_path = csv_path #CSV文件路径
self.df = None #用于存储加载后的数据,初始为None,后续会被load_data方法加载数据并赋值
self.load_data() #保存数据后,立即加载数据以供后续分析和可视化使用
def load_data(self): #加载数据的方法
# 检查文件是否存在,如果存在则加载数据到DataFrame中,并预处理
if os.path.exists(self.csv_path): #检查CSV文件是否存在
self.df = pd.read_csv(self.csv_path) #read
self.df['日期'] = pd.to_datetime(self.df['日期'])
self.df.sort_values('日期', inplace=True) #按日期排序
print(f"😎成功加载数据: {self.csv_path}")
else:
print(f"😭文件不存在: {self.csv_path}")
def draw(self, stock_code):
#绘图的方法,根据股价
if self.df is None:# df不存在就直接返回,不然报错
return
#绘图的核心代码
plt.figure(figsize=(12, 6))
plt.plot(self.df['日期'], self.df['收盘价'], label='收盘价', color='blue')
plt.title(f'股票代码 {stock_code} 股价走势')
plt.xlabel('日期')
plt.ylabel('价格 (RMB)')
plt.legend() #显示图例
plt.grid(True) #显示网格
# 自动调整日期格式,好吧没懂这串代码的工作原理,高度打包好的预制代码,能用就直接用了
plt.gcf().autofmt_xdate()
# 新增:保存到根目录的stock_plot文件夹
plot_dir = 'stock_plot'
if not os.path.exists(plot_dir):
os.makedirs(plot_dir)
save_path = os.path.join(plot_dir, f'stock_{stock_code}_plot.png')
plt.savefig(save_path)
print(f"股价走势图已保存至: {save_path}")
# plt.show() # 移除,避免后端报错
def analyze_risedoswntrends(self):
#分析上升下降时间段的方法,核心是计算每日价格变化,并标记上涨和下跌的时间段
if self.df is None:
return
# 计算每日价格变化
self.df['price_diff'] = self.df['收盘价'].diff()
# 标记上涨和下跌
rising_periods = self.df[self.df['price_diff'] > 0] #上升段
falling_periods = self.df[self.df['price_diff'] < 0] #下降段
return rising_periods, falling_periods
def main():
stock_code = '000001' # 默认抓取000001
# 抓取数据
crawler = StockCrawler()
df = crawler.fetch_data(stock_code)
if df is not None: #如果成功抓取到数据
crawler.save_data(df, stock_code) #保存数据
# 可视化分析
csv_file = os.path.join("stock_data", f"stock_{stock_code}.csv")
#根据股票代码构建CSV文件路径,方便后续操作
visualizer = StockVisualizer(csv_file)
#创建实例,传入CSV文件路径,进行数据加载和预处理
visualizer.draw(stock_code)
#调用方法绘制股价走势图,传入股票代码用于图表标题和文件命名
visualizer.analyze_risedoswntrends()
#分析上升下降时间段
else:
print("抓取数据失败😭,可能是网络问题或股票代码有问题,建议检查日志")
if __name__ == "__main__":
main()