#!/usr/bin/env python3
"""生成月度汇总Excel，含三张柱状图（冰激淋/饮料/合计）"""
import json
from openpyxl import Workbook
from openpyxl.chart import BarChart, LineChart, Reference
from openpyxl.chart.series import DataPoint
from openpyxl.chart.label import DataLabelList
from openpyxl.styles import Font, PatternFill, Alignment, Border, Side
from openpyxl.utils import get_column_letter
from copy import deepcopy

# Monthly figures per category. Keys read later in this script:
# 'ice_cream', 'beverage', 'combined' — each a list of dicts providing
# 'month', 'total', 'count' and 'avg'.
# JSON is UTF-8 by specification; state the encoding explicitly instead of
# relying on the platform locale, since the payload may contain non-ASCII text.
with open('/tmp/monthly_summary.json', encoding='utf-8') as f:
    data = json.load(f)

wb = Workbook()

# ── Color scheme ── (hex RGB strings, no leading '#')
ICE_COLOR = '4472C4'     # blue — ice-cream sales bars
BEV_COLOR = 'ED7D31'     # orange — beverage sales bars
COMB_COLOR = 'A5A5A5'    # gray — combined sales bars
LINE1_COLOR = 'FF0000'   # red for count (机器数 line)
LINE2_COLOR = '70AD47'   # green for avg (单机月均 line)

# Header and border styles shared by the sheets built below.
header_fill = PatternFill(start_color='2F5496', end_color='2F5496', fill_type='solid')
header_font = Font(bold=True, color='FFFFFF', size=11)
thin_border = Border(
    left=Side(style='thin'), right=Side(style='thin'),
    top=Side(style='thin'), bottom=Side(style='thin')
)

def create_chart_sheet(wb, sheet_name, series_data, category, bar_color,
                       period='2025.01 - 2026.05'):
    """Create a worksheet holding a monthly data table and a combo chart.

    The table has one row per month (月份 / 销售额 / 机器数 / 单机月均)
    followed by a bold 累计 row with the summed sales. Below the table a
    column chart plots 销售额. 机器数 and 单机月均 are overlaid on it as
    line series with their own value axes.

    Args:
        wb: openpyxl ``Workbook`` the new sheet is added to.
        sheet_name: Title of the new worksheet.
        series_data: Sequence of dicts, each providing ``'month'``,
            ``'total'``, ``'count'`` and ``'avg'``.
        category: Category label inserted into the chart title.
        bar_color: Hex RGB string (e.g. ``'4472C4'``) for the sales bars.
        period: Date-range text shown in parentheses in the chart title.

    Returns:
        The newly created worksheet. When ``series_data`` is empty only
        the header and total row are written and no chart is added.
    """
    ws = wb.create_sheet(title=sheet_name)

    # ── Write data table ──
    headers = ['月份', '销售额', '机器数', '单机月均']
    for c, h in enumerate(headers, 1):
        cell = ws.cell(row=1, column=c, value=h)
        cell.font = header_font
        cell.fill = header_fill
        cell.alignment = Alignment(horizontal='center')
        cell.border = thin_border

    # One row per month, starting directly below the header.
    for row, r in enumerate(series_data, start=2):
        ws.cell(row=row, column=1, value=r['month']).border = thin_border
        ws.cell(row=row, column=1).alignment = Alignment(horizontal='center')
        ws.cell(row=row, column=2, value=r['total']).border = thin_border
        ws.cell(row=row, column=2).number_format = '#,##0'
        ws.cell(row=row, column=3, value=r['count']).border = thin_border
        ws.cell(row=row, column=3).alignment = Alignment(horizontal='center')
        ws.cell(row=row, column=4, value=r['avg']).border = thin_border
        ws.cell(row=row, column=4).number_format = '#,##0'

    # Total row: only sales are summed. The count/avg cells stay empty but
    # get a border so the table grid is closed.
    data_rows = len(series_data)
    tr = data_rows + 2
    total_sales = sum(r['total'] for r in series_data)
    ws.cell(row=tr, column=1, value='累计').font = Font(bold=True)
    ws.cell(row=tr, column=1).border = thin_border
    ws.cell(row=tr, column=2, value=round(total_sales, 2)).font = Font(bold=True)
    ws.cell(row=tr, column=2).number_format = '#,##0'
    ws.cell(row=tr, column=2).border = thin_border
    for col in (3, 4):
        ws.cell(row=tr, column=col).border = thin_border

    # Column widths
    ws.column_dimensions['A'].width = 12
    ws.column_dimensions['B'].width = 14
    ws.column_dimensions['C'].width = 10
    ws.column_dimensions['D'].width = 14

    if data_rows == 0:
        # With no month rows the references below would span rows 2..1,
        # i.e. nothing to plot, so return the table without a chart.
        return ws

    last_row = data_rows + 1

    # ── Bar Chart (销售额) ──
    chart = BarChart()
    chart.type = 'col'
    chart.style = 10
    chart.title = f'{category}月度销售额趋势（{period}）'
    chart.y_axis.title = '销售额 (¥)'
    chart.x_axis.title = '月份'
    chart.width = 28
    chart.height = 16

    cats = Reference(ws, min_col=1, min_row=2, max_row=last_row)
    # min_row=1 includes the header cell so titles_from_data names the series.
    data_ref = Reference(ws, min_col=2, min_row=1, max_row=last_row)
    chart.add_data(data_ref, titles_from_data=True)
    chart.set_categories(cats)

    # Color the bars
    chart.series[0].graphicalProperties.solidFill = bar_color

    # Data labels showing each month's sales value
    chart.series[0].dLbls = DataLabelList()
    chart.series[0].dLbls.showVal = True
    chart.series[0].dLbls.numFmt = '#,##0'

    # ── Line Chart for 机器数 (secondary axis, right side) ──
    line1 = LineChart()
    count_ref = Reference(ws, min_col=3, min_row=1, max_row=last_row)
    line1.add_data(count_ref, titles_from_data=True)
    line1.y_axis.axId = 200
    line1.y_axis.title = '机器数'
    line1.y_axis.crosses = 'max'
    line1.series[0].graphicalProperties.solidFill = LINE1_COLOR
    line1.series[0].graphicalProperties.line.solidFill = LINE1_COLOR
    line1.series[0].marker.symbol = 'circle'
    line1.series[0].marker.size = 6

    chart.y_axis.crosses = 'min'
    chart += line1

    # ── Line Chart for 单机月均 (tertiary axis) ──
    # NOTE(review): Excel charts only have a primary and a secondary axis
    # group. This third value axis (axId 300) also crosses at 'max', so at
    # best it is drawn on top of the 机器数 axis. Confirm the rendered chart
    # is acceptable, or move 单机月均 onto an existing axis.
    line2 = LineChart()
    avg_ref = Reference(ws, min_col=4, min_row=1, max_row=last_row)
    line2.add_data(avg_ref, titles_from_data=True)
    line2.y_axis.axId = 300
    line2.y_axis.title = '单机月均 (¥)'
    line2.y_axis.crosses = 'max'
    line2.series[0].graphicalProperties.solidFill = LINE2_COLOR
    line2.series[0].graphicalProperties.line.solidFill = LINE2_COLOR
    line2.series[0].marker.symbol = 'diamond'
    line2.series[0].marker.size = 6

    chart += line2

    # Charts built with openpyxl 3.1+ show no axes in Excel unless `delete`
    # is explicitly False. Older versions are unaffected by setting it.
    for part in (chart, line1, line2):
        part.x_axis.delete = False
        part.y_axis.delete = False

    # Position chart two rows below the total row
    ws.add_chart(chart, f'A{tr + 3}')

    return ws

# ── Create 3 chart sheets ──
# Each entry: (sheet title, key in the JSON payload, category label used in
# the chart title, bar colour). Sheets are created in this order.
_CHART_SHEET_SPECS = (
    ('🍦冰激淋月度趋势', 'ice_cream', '冰激淋', ICE_COLOR),
    ('🥤饮料月度趋势', 'beverage', '饮料', BEV_COLOR),
    ('📊合计月度趋势', 'combined', '全品类（冰激淋+饮料）', COMB_COLOR),
)
for _title, _key, _label, _color in _CHART_SHEET_SPECS:
    create_chart_sheet(wb, _title, data[_key], _label, _color)

# ── Summary sheet ──
# One row per month comparing the three categories side by side. It is
# inserted at index 0 so it becomes the first tab of the workbook.
ws_summary = wb.create_sheet(title='汇总对比', index=0)

# Headers
sum_headers = ['月份', '冰激淋销售额', '冰激淋机器数', '冰单机月均',
               '饮料销售额', '饮料机器数', '饮单机月均',
               '合计销售额', '合计机器数', '合计单机月均']
for c, h in enumerate(sum_headers, 1):
    cell = ws_summary.cell(row=1, column=c, value=h)
    cell.font = Font(bold=True, color='FFFFFF', size=9)
    cell.fill = header_fill
    cell.alignment = Alignment(horizontal='center', wrap_text=True)
    cell.border = thin_border

# 1-based columns holding machine counts (plain integers). Every other
# numeric column is a money amount shown with a thousands separator.
SUMMARY_COUNT_COLUMNS = (3, 6, 9)

# The three category lists are merged row by row, so they must cover the same
# months in the same order. Fail loudly rather than write misaligned rows.
ice_rows = data['ice_cream']
bev_rows = data['beverage']
comb_rows = data['combined']
if not len(ice_rows) == len(bev_rows) == len(comb_rows):
    raise ValueError(
        'ice_cream/beverage/combined have different lengths: '
        f'{len(ice_rows)}/{len(bev_rows)}/{len(comb_rows)}'
    )

for row, (ice, bev, comb) in enumerate(zip(ice_rows, bev_rows, comb_rows), start=2):
    if not ice['month'] == bev['month'] == comb['month']:
        raise ValueError(
            f'month mismatch in summary row {row}: '
            f"{ice['month']!r} / {bev['month']!r} / {comb['month']!r}"
        )

    vals = [
        comb['month'],
        ice['total'], ice['count'], ice['avg'],
        bev['total'], bev['count'], bev['avg'],
        comb['total'], comb['count'], comb['avg'],
    ]
    for c, v in enumerate(vals, 1):
        cell = ws_summary.cell(row=row, column=c, value=v)
        cell.border = thin_border
        if c > 1:
            cell.number_format = '0' if c in SUMMARY_COUNT_COLUMNS else '#,##0'
        cell.alignment = Alignment(horizontal='center')

# Column widths for summary (column A included)
for col in range(1, len(sum_headers) + 1):
    ws_summary.column_dimensions[get_column_letter(col)].width = 12

# Remove the default empty sheet that Workbook() creates.
if 'Sheet' in wb.sheetnames:
    del wb['Sheet']

# Write the workbook to the Windows desktop (path as seen from WSL).
filepath = '/mnt/c/Users/kingw/Desktop/三系统月度销售汇总_2025年1月至2026年5月.xlsx'
wb.save(filepath)
print(f'✅ 已保存: {filepath}')
# Plain string: there is nothing to interpolate here.
print('包含4个工作表: 汇总对比 + 3张图表')
