Coverage for src / ai_lls_lib / payment / webhook_processor.py: 90%

200 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-06 23:45 +0000

"""Stripe webhook event processing."""

import base64
import json
import logging
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any

# stripe is optional at import time; WebhookProcessor.verify_and_parse raises
# ImportError when it is missing.
try:
    import stripe

    HAS_STRIPE = True
except ImportError:
    stripe = None  # type: ignore[assignment]
    HAS_STRIPE = False

# botocore is optional as well. When it is missing, bind ClientError to a
# private placeholder so that ``except ClientError`` clauses and isinstance
# checks in this module never fail with NameError. The placeholder is never
# raised by anything, so such handlers simply never match.
try:
    from botocore.exceptions import ClientError

    HAS_BOTOCORE = True
except ImportError:

    class ClientError(Exception):  # type: ignore[no-redef]
        """Placeholder used only when botocore is not installed."""

    HAS_BOTOCORE = False

from .credit_manager import CreditManager

logger = logging.getLogger(__name__)

# Credits granted when a subscription is created.
INITIAL_SUBSCRIPTION_CREDITS = 1000
# Credits granted for each newly observed billing period of an active subscription.
RENEWAL_CREDITS = 1000



class WebhookProcessor:
    """Process Stripe webhook events and apply their credit/subscription effects.

    Events arrive either as raw signed webhook payloads (``verify_and_parse``
    followed by ``process_event``) or via EventBridge
    (``process_eventbridge_event``). Each handler returns a small dict
    describing what it did.

    The same event may be delivered more than once, so handlers that grant
    credits deduplicate: one-off purchases go through
    ``CreditManager.idempotent_credit_grant``, and subscription grants claim
    the billing period with a conditional update on ``last_credited_period``
    before adding credits.
    """

    def __init__(self, webhook_secret: str, credit_manager: CreditManager):
        """Store the webhook signing secret and the credit manager.

        Args:
            webhook_secret: Secret used to verify webhook signatures.
            credit_manager: Manager used to read/update user credits and
                subscription state.
        """
        self.webhook_secret = webhook_secret
        self.credit_manager = credit_manager

    def verify_and_parse(self, payload: str, signature: str) -> dict[str, Any]:
        """Verify a webhook signature and return the parsed event.

        Args:
            payload: Raw request body, exactly as received.
            signature: Signature value sent with the request.

        Returns:
            The Stripe event as a plain dict.

        Raises:
            ImportError: If the ``stripe`` package is not installed.
            ValueError: If the payload cannot be parsed.
            stripe.error.SignatureVerificationError: If the signature is invalid.
        """
        if not HAS_STRIPE or not stripe:
            raise ImportError("stripe package not installed")

        try:
            event = stripe.Webhook.construct_event(payload, signature, self.webhook_secret)
            return dict(event)
        except ValueError as e:
            logger.error("Invalid webhook payload: %s", e)
            raise
        except stripe.error.SignatureVerificationError as e:
            logger.error("Invalid webhook signature: %s", e)
            raise

    def process_event(self, event: dict[str, Any]) -> dict[str, Any]:
        """Dispatch a verified event to the handler for its ``type``.

        Args:
            event: Event dict; the handler receives ``event["data"]["object"]``.

        Returns:
            The handler's response dict, or a message dict for unhandled types.
        """
        event_type = event.get("type")
        event_data = (event.get("data") or {}).get("object") or {}

        logger.info("Processing webhook event: %s", event_type)

        # Built per call so instance-level overrides of handlers are honoured.
        handlers = {
            "payment_intent.succeeded": self._handle_payment_intent_succeeded,
            "checkout.session.completed": self._handle_checkout_completed,
            "customer.subscription.created": self._handle_subscription_created,
            "customer.subscription.updated": self._handle_subscription_updated,
            "customer.subscription.deleted": self._handle_subscription_deleted,
            "invoice.payment_succeeded": self._handle_invoice_paid,
            "invoice.payment_failed": self._handle_invoice_failed,
            "charge.dispute.created": self._handle_dispute_created,
        }
        handler = handlers.get(event_type)
        if handler is None:
            logger.info("Unhandled event type: %s", event_type)
            return {"message": f"Event {event_type} received but not processed"}
        return handler(event_data)

    def process_eventbridge_event(self, event: dict[str, Any]) -> dict[str, Any]:
        """Process a Stripe event delivered via EventBridge.

        EventBridge events are pre-verified by AWS, so no signature
        verification is done here; the Stripe event is read from
        ``event["detail"]``. Without a usable detail, falls back to parsing
        ``event["body"]`` (base64-decoded first when ``isBase64Encoded`` is
        set) for direct webhook calls.

        Raises:
            json.JSONDecodeError: If the fallback body is not valid JSON.
        """
        detail = event.get("detail") or {}
        # isinstance guard: a non-dict detail (e.g. a string) must not be
        # matched by a substring ``"type" in detail`` test.
        if isinstance(detail, dict) and "type" in detail:
            return self.process_event(detail)

        # Fallback: direct webhook call with body
        body = event.get("body") or ""
        if body and event.get("isBase64Encoded"):
            body = base64.b64decode(body).decode("utf-8")

        if body:
            return self.process_event(json.loads(body))

        logger.warning("EventBridge event has no detail or body")
        return {"message": "No event data found"}

    def _handle_checkout_completed(self, session: dict[str, Any]) -> dict[str, Any]:
        """Grant purchased credits for a completed one-time checkout session.

        Only sessions with ``mode == "payment"`` and a positive ``credits``
        metadata value grant credits. The grant goes through
        ``idempotent_credit_grant`` so a redelivered event does not add the
        credits twice. The key is the session's PaymentIntent id when present
        (the same key ``_handle_payment_intent_succeeded`` uses), otherwise
        the session id. That should also keep a matching
        ``payment_intent.succeeded`` event from granting again, assuming
        ``idempotent_credit_grant`` dedupes per key.
        """
        metadata = session.get("metadata") or {}
        user_id = metadata.get("user_id")

        if not user_id:
            logger.error("No user_id in checkout session metadata")
            return {"error": "Missing user_id"}

        if session.get("mode") != "payment":
            return {"message": "Checkout processed"}

        credits = self._parse_credits(metadata)
        if credits <= 0:
            return {"message": "Checkout processed"}

        payment_intent = session.get("payment_intent")
        if isinstance(payment_intent, dict):  # expanded object instead of an id
            payment_intent = payment_intent.get("id")
        grant_key = payment_intent or session.get("id")

        if not grant_key:
            # No stable identifier to deduplicate on; grant directly.
            new_balance = self.credit_manager.add_credits(user_id, credits)
        else:
            granted = self.credit_manager.idempotent_credit_grant(user_id, credits, grant_key)
            if not granted:
                logger.info("Payment %s already processed for %s", grant_key, user_id)
                return {"message": "Payment already processed, credits not added"}
            new_balance = self.credit_manager.get_balance(user_id)

        logger.info("Added %s credits to user %s, new balance: %s", credits, user_id, new_balance)
        return {"credits_added": credits, "new_balance": new_balance}

    def _handle_subscription_created(self, subscription: dict[str, Any]) -> dict[str, Any]:
        """Record a new subscription and grant the initial credits.

        When the billing period start and the table are available, the grant
        first claims that period via a conditional update of
        ``last_credited_period``. So a redelivered event, or a
        ``customer.subscription.updated`` event for the same period handled
        first, does not credit the period twice. Otherwise credits are
        granted without deduplication.

        Raises:
            Errors from the credit manager or table, other than a failed
            condition check, propagate to the caller.
        """
        metadata = subscription.get("metadata") or {}
        user_id = metadata.get("user_id")
        customer_id = subscription.get("customer")
        subscription_id = subscription.get("id")
        status = str(subscription.get("status", "unknown"))

        if not user_id:
            logger.warning("No user_id in subscription metadata")
            return {"subscription_id": subscription_id, "status": status}

        self.credit_manager.set_subscription_state(
            user_id=user_id,
            status=status,
            stripe_customer_id=customer_id,
            stripe_subscription_id=subscription_id,
        )

        period_start = self._get_current_period_start(subscription)
        if period_start and self.credit_manager.table:
            new_balance = self._grant_credits_for_period(
                user_id, period_start, INITIAL_SUBSCRIPTION_CREDITS, label="initial"
            )
            if new_balance is None:
                credits_added = 0
                new_balance = self.credit_manager.get_balance(user_id)
            else:
                credits_added = INITIAL_SUBSCRIPTION_CREDITS
        else:
            new_balance = self.credit_manager.add_credits(user_id, INITIAL_SUBSCRIPTION_CREDITS)
            credits_added = INITIAL_SUBSCRIPTION_CREDITS

        logger.info(
            "Created subscription %s for user %s, granted %s credits, balance: %s",
            subscription_id,
            user_id,
            credits_added,
            new_balance,
        )

        return {
            "subscription_id": subscription_id,
            "status": status,
            "credits_added": credits_added,
            "new_balance": new_balance,
        }

    def _handle_subscription_updated(self, subscription: dict[str, Any]) -> dict[str, Any]:
        """Store the updated subscription status and grant renewal credits.

        For an ``active`` subscription whose billing period start is newer
        than ``last_credited_period``, grants ``RENEWAL_CREDITS`` once per
        period (see ``_try_renewal_credit_grant``).
        """
        metadata = subscription.get("metadata") or {}
        user_id = metadata.get("user_id")
        subscription_id = subscription.get("id")
        status = str(subscription.get("status", "unknown"))

        if not user_id:
            return {"subscription_id": subscription_id, "status": status}

        # Always update subscription status
        self.credit_manager.set_subscription_state(
            user_id=user_id, status=status, stripe_subscription_id=subscription_id
        )

        credits_added = 0
        period_start = self._get_current_period_start(subscription)
        if status == "active" and period_start and self.credit_manager.table:
            credits_added = self._try_renewal_credit_grant(user_id, period_start)

        if credits_added:
            logger.info(
                "Updated subscription %s status to %s, granted %s renewal credits",
                subscription_id,
                status,
                credits_added,
            )
        else:
            logger.info("Updated subscription %s status to %s", subscription_id, status)

        result: dict[str, Any] = {"subscription_id": subscription_id, "status": status}
        if credits_added:
            result["credits_added"] = credits_added
        return result

    def _try_renewal_credit_grant(self, user_id: str, current_period_start: int) -> int:
        """Grant ``RENEWAL_CREDITS`` for a billing period not yet credited.

        Returns:
            The number of credits granted: ``RENEWAL_CREDITS``, or 0 if there
            is no table, the period was already credited, or a non-ClientError
            error occurred (logged).

        Raises:
            botocore ClientError: Table errors other than a failed condition check.
        """
        if not self.credit_manager.table:
            return 0

        try:
            new_balance = self._grant_credits_for_period(
                user_id, current_period_start, RENEWAL_CREDITS, label="renewal"
            )
        except Exception as e:
            # HAS_BOTOCORE guard: without botocore, ClientError may be unbound.
            if HAS_BOTOCORE and isinstance(e, ClientError):
                logger.error("Error granting renewal credits for %s: %s", user_id, e)
                raise
            logger.error("Error in renewal credit grant for %s: %s", user_id, e)
            return 0

        return 0 if new_balance is None else RENEWAL_CREDITS

    def _grant_credits_for_period(
        self, user_id: str, period_start: int, credits: int, *, label: str
    ) -> Any:
        """Claim ``period_start`` for ``user_id`` and, if newly claimed, add ``credits``.

        The claim reads ``last_credited_period`` and then updates it with a
        condition on the value just read (or on its absence). Repeated or
        concurrent deliveries for the same period therefore add credits at
        most once. Requires ``self.credit_manager.table`` to be set.

        Args:
            user_id: User whose item is updated.
            period_start: Billing period start timestamp being credited.
            credits: Number of credits to add when the claim succeeds.
            label: Word used in log messages (e.g. "renewal").

        Returns:
            The value returned by ``add_credits`` (the new balance) when
            credits were added, or ``None`` if the period was already credited.

        Raises:
            Any error from the table or ``add_credits`` other than a failed
            condition check.
        """
        table = self.credit_manager.table
        period = Decimal(str(period_start))

        item = table.get_item(Key={"user_id": user_id}).get("Item") or {}
        last_period = item.get("last_credited_period")

        if last_period is not None and period <= Decimal(str(last_period)):
            logger.info("Period %s already credited for %s", period_start, user_id)
            return None

        if last_period is None:
            condition = "attribute_not_exists(last_credited_period)"
            values: dict[str, Any] = {":new_period": period}
        else:
            condition = "last_credited_period = :prev_period"
            values = {":new_period": period, ":prev_period": Decimal(str(last_period))}

        try:
            table.update_item(
                Key={"user_id": user_id},
                UpdateExpression="SET last_credited_period = :new_period",
                ConditionExpression=condition,
                ExpressionAttributeValues=values,
            )
        except Exception as e:
            if not self._is_conditional_check_failure(e):
                raise
            logger.info(
                "%s credits already granted for period %s for %s (race condition)",
                label.capitalize(),
                period_start,
                user_id,
            )
            return None

        # NOTE(review): the period claim and the credit grant are separate
        # writes. If add_credits fails after the claim succeeded, a retry sees
        # the period as credited and the credits are not granted.
        new_balance = self.credit_manager.add_credits(user_id, credits)
        logger.info(
            "Granted %s %s credits to %s for period %s", credits, label, user_id, period_start
        )
        return new_balance

    def _handle_subscription_deleted(self, subscription: dict[str, Any]) -> dict[str, Any]:
        """Mark the user's subscription cancelled and record when it happened."""
        metadata = subscription.get("metadata") or {}
        user_id = metadata.get("user_id")
        subscription_id = subscription.get("id")

        if user_id:
            # NOTE(review): Stripe's own status value is "canceled"; this
            # stores "cancelled". Confirm consumers expect this spelling.
            self.credit_manager.set_subscription_state(
                user_id=user_id, status="cancelled", stripe_subscription_id=subscription_id
            )

            # Best effort: a failure to record the timestamp is only logged.
            if self.credit_manager.table:
                try:
                    self.credit_manager.table.update_item(
                        Key={"user_id": user_id},
                        UpdateExpression="SET subscription_cancelled_at = :now",
                        ExpressionAttributeValues={
                            ":now": datetime.now(UTC).isoformat(),
                        },
                    )
                except Exception as e:
                    logger.error("Error setting cancelled_at for %s: %s", user_id, e)

            logger.info("Cancelled subscription %s for user %s", subscription_id, user_id)

        return {"subscription_id": subscription_id, "status": "cancelled"}

    def _handle_invoice_paid(self, invoice: dict[str, Any]) -> dict[str, Any]:
        """Log a paid invoice and return the amount paid (``amount_paid`` / 100).

        NOTE(review): dividing by 100 assumes a two-decimal currency.
        """
        customer_id = invoice.get("customer")
        amount = (invoice.get("amount_paid") or 0) / 100.0
        logger.info("Invoice paid: $%s from customer %s", amount, customer_id)
        return {"amount_paid": amount}

    def _handle_invoice_failed(self, invoice: dict[str, Any]) -> dict[str, Any]:
        """Log a failed invoice payment; no account state is changed here."""
        customer_id = invoice.get("customer")
        logger.warning("Invoice payment failed for customer %s", customer_id)
        return {"status": "payment_failed"}

    def _handle_payment_intent_succeeded(self, payment_intent: dict[str, Any]) -> dict[str, Any]:
        """Grant purchased credits for a succeeded PaymentIntent, idempotently.

        Verification charges (``metadata["type"] == "verification"``) grant
        nothing. Otherwise ``metadata["credits"]`` credits are granted through
        ``idempotent_credit_grant`` keyed by the PaymentIntent id.
        """
        metadata = payment_intent.get("metadata") or {}
        user_id = metadata.get("user_id")

        if not user_id:
            logger.error("No user_id in payment_intent metadata")
            return {"error": "Missing user_id"}

        # Check if this is a verification charge ($1)
        if metadata.get("type") == "verification":
            logger.info("Verification charge completed for user %s", user_id)
            return {"type": "verification", "status": "completed"}

        # Credits come from metadata set during payment creation.
        credits = self._parse_credits(metadata)
        payment_intent_id = payment_intent.get("id", "")

        if credits > 0 and payment_intent_id:
            granted = self.credit_manager.idempotent_credit_grant(
                user_id, credits, payment_intent_id
            )
            if not granted:
                logger.info("Payment %s already processed for %s", payment_intent_id, user_id)
                return {"message": "Payment already processed, credits not added"}
            new_balance = self.credit_manager.get_balance(user_id)
            logger.info(
                "Added %s credits to user %s, new balance: %s", credits, user_id, new_balance
            )
            return {"credits_added": credits, "new_balance": new_balance}

        return {"message": "Payment processed"}

    def _handle_dispute_created(self, dispute: dict[str, Any]) -> dict[str, Any]:
        """Log a newly created charge dispute.

        Only logs and reports the dispute; account state is not changed here.
        """
        charge_id = dispute.get("charge")

        if not charge_id:
            logger.error("No charge_id in dispute")
            return {"error": "Missing charge_id"}

        amount = (dispute.get("amount") or 0) / 100.0
        reason = dispute.get("reason", "unknown")

        logger.warning("Dispute created for charge %s: $%s, reason: %s", charge_id, amount, reason)

        return {"dispute_id": dispute.get("id"), "status": "created", "amount": amount}

    @staticmethod
    def _parse_credits(metadata: dict[str, Any]) -> int:
        """Return ``metadata["credits"]`` as an int, or 0 if missing or invalid.

        Metadata values arrive as strings. An unparsable value is logged
        rather than raised: re-delivering the same event cannot fix it.
        """
        raw = metadata.get("credits", 0)
        try:
            return int(raw)
        except (TypeError, ValueError):
            logger.error("Invalid credits value in metadata: %r", raw)
            return 0

    @staticmethod
    def _get_current_period_start(subscription: dict[str, Any]) -> int | None:
        """Return the subscription's current billing period start, if present.

        Reads the top-level ``current_period_start`` first. If that is absent,
        reads the first subscription item that has one: newer Stripe API
        versions report the billing period on subscription items instead.
        """
        period_start = subscription.get("current_period_start")
        if period_start:
            return period_start
        items = subscription.get("items") or {}
        for entry in items.get("data") or []:
            item_start = entry.get("current_period_start")
            if item_start:
                return item_start
        return None