Coverage for src / ai_lls_lib / apikeys / managed_key_service.py: 87%

116 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-06 23:45 +0000

1"""Managed API key CRUD operations with DynamoDB.""" 

2 

3import logging 

4import os 

5from datetime import UTC, datetime, timedelta 

6from typing import TYPE_CHECKING, Any 

7 

8try: 

9 import boto3 

10 

11 HAS_BOTO3 = True 

12except ImportError: 

13 boto3 = None # type: ignore[assignment] 

14 HAS_BOTO3 = False 

15 

16if TYPE_CHECKING: 

17 from mypy_boto3_dynamodb.service_resource import Table 

18 

19from ai_lls_lib.key_management import ( 

20 compute_key_hash, 

21 generate_key_id, 

22 generate_managed_key, 

23 validate_expiration_days, 

24) 

25 

logger = logging.getLogger(__name__)

# Ceiling on keys whose status is not "revoked" that one user may hold;
# checked in ManagedApiKeyService.create_key before a new key is written.
MAX_ACTIVE_KEYS: int = 10
# Maximum key-label length, measured after surrounding whitespace is stripped.
MAX_LABEL_LENGTH: int = 64
# Days from revocation until the numeric `ttl` attribute set on a revoked key
# (presumably the table's configured TTL attribute — confirm in infra config).
REVOKE_TTL_DAYS: int = 30

31 

32 

class KeyNotFoundError(ValueError):
    """No managed API key exists for the requested user_id / key_id pair."""


class RevokedKeyError(ValueError):
    """The targeted managed API key is revoked and may no longer be modified."""


class LimitExceededError(ValueError):
    """The user already holds the maximum number of active managed API keys."""

43 

44 

class ManagedApiKeyService:
    """Manages user API keys with CRUD operations in DynamoDB.

    DynamoDB table schema:
    - Hash key: user_id (S)
    - Range key: key_id (S)

    Items written by this service also carry key_hash, key_last4, label,
    status ("active" or "revoked"), created_at, expires_at, last_used_at and
    usage_count. Updates add updated_at. Revocation adds revoked_at plus a
    numeric ``ttl`` (epoch seconds) for automatic cleanup. Plaintext keys are
    never stored, only their hash and last four characters.
    """

    table: "Table | None"

    def __init__(self, table_name: str | None = None):
        """Initialize with DynamoDB table.

        Args:
            table_name: Table to use. Falls back to the MANAGED_API_KEYS_TABLE
                environment variable when omitted or empty.

        Raises:
            RuntimeError: If boto3 is not installed.
            KeyError: If no table name is given and MANAGED_API_KEYS_TABLE is unset.
        """
        if not HAS_BOTO3 or not boto3:
            raise RuntimeError("boto3 is required for ManagedApiKeyService")

        self.dynamodb = boto3.resource("dynamodb")
        self.table_name = table_name if table_name else os.environ["MANAGED_API_KEYS_TABLE"]

        try:
            self.table = self.dynamodb.Table(self.table_name)
        except Exception as e:
            # Deliberately best-effort: the service stays constructible and every
            # operation raises RuntimeError through _require_table() instead.
            logger.error("Failed to connect to DynamoDB table %s: %s", self.table_name, e)
            self.table = None

    def _require_table(self) -> "Table":
        """Return the DynamoDB table, raising RuntimeError if it is unavailable."""
        if not self.table:
            raise RuntimeError(f"DynamoDB table {self.table_name} not accessible")
        return self.table

    @staticmethod
    def _validate_label(label: str) -> str:
        """Strip *label* and enforce 1..MAX_LABEL_LENGTH characters; return it stripped."""
        cleaned = label.strip()
        if not cleaned or len(cleaned) > MAX_LABEL_LENGTH:
            raise ValueError(f"Label must be 1-{MAX_LABEL_LENGTH} characters, got {len(cleaned)}")
        return cleaned

    @staticmethod
    def _validate_expiration(expires_in_days: int) -> None:
        """Raise ValueError unless validate_expiration_days accepts *expires_in_days*."""
        if not validate_expiration_days(expires_in_days):
            raise ValueError(f"Expiration must be 1-730 days, got {expires_in_days}")

    def _get_key(self, user_id: str, key_id: str) -> dict[str, Any]:
        """Fetch a key item, raising if not found or revoked.

        Raises:
            RuntimeError: If the table is not accessible.
            KeyNotFoundError: If no item exists for (user_id, key_id).
            RevokedKeyError: If the item's status is "revoked".
        """
        table = self._require_table()
        response = table.get_item(Key={"user_id": user_id, "key_id": key_id})
        item = response.get("Item")
        if not item:
            raise KeyNotFoundError(f"Key {key_id} not found for user {user_id}")
        if item.get("status") == "revoked":
            raise RevokedKeyError(f"Key {key_id} is revoked")
        return item

    def list_keys(self, user_id: str) -> list[dict[str, Any]]:
        """List all API keys for a user, sorted by created_at descending.

        Every result page of the Query is followed, so users with many
        (including recently revoked) keys are listed completely. Each returned
        dict holds display fields only; key_hash is never included.

        Raises:
            RuntimeError: If the table is not accessible.
        """
        table = self._require_table()

        query_kwargs: dict[str, Any] = {
            "KeyConditionExpression": "user_id = :uid",
            "ExpressionAttributeValues": {":uid": user_id},
        }
        items: list[dict[str, Any]] = []
        while True:
            response = table.query(**query_kwargs)
            items.extend(response.get("Items", []))
            # A Query returns at most 1 MB per call; LastEvaluatedKey marks more pages.
            last_key = response.get("LastEvaluatedKey")
            if not last_key:
                break
            query_kwargs["ExclusiveStartKey"] = last_key

        result = [
            {
                "key_id": item["key_id"],
                "key_last4": item.get("key_last4", ""),
                "label": item.get("label", ""),
                "status": item.get("status", "active"),
                "created_at": item.get("created_at", ""),
                "expires_at": item.get("expires_at"),
                "last_used_at": item.get("last_used_at"),
            }
            for item in items
        ]
        result.sort(key=lambda x: str(x.get("created_at", "")), reverse=True)
        return result

    def create_key(self, user_id: str, label: str, expires_in_days: int = 365) -> dict[str, Any]:
        """Create a new managed API key.

        Args:
            user_id: Owner of the new key.
            label: Human-readable label; surrounding whitespace is stripped.
            expires_in_days: Lifetime in days, checked by validate_expiration_days.

        Returns:
            key_id, the plaintext api_key (the only time it is returned), the
            stored label and the ISO-8601 expires_at.

        Raises:
            RuntimeError: If the table is not accessible.
            ValueError: If the label or expiration is invalid.
            LimitExceededError: If the user already has MAX_ACTIVE_KEYS non-revoked keys.
        """
        table = self._require_table()
        label = self._validate_label(label)
        self._validate_expiration(expires_in_days)

        # NOTE(review): this count and the put_item below are separate calls, so
        # concurrent create_key calls for one user could exceed MAX_ACTIVE_KEYS.
        existing = self.list_keys(user_id)
        active_count = sum(1 for k in existing if k["status"] != "revoked")
        if active_count >= MAX_ACTIVE_KEYS:
            raise LimitExceededError(f"Maximum of {MAX_ACTIVE_KEYS} active keys reached")

        key_id = generate_key_id()
        plaintext_key = generate_managed_key()
        key_hash = compute_key_hash(plaintext_key)
        # One clock reading so expires_at is exactly expires_in_days after created_at.
        now_dt = datetime.now(UTC)
        now = now_dt.isoformat()
        expires_at = (now_dt + timedelta(days=expires_in_days)).isoformat()

        table.put_item(
            Item={
                "user_id": user_id,
                "key_id": key_id,
                "key_hash": key_hash,
                "key_last4": plaintext_key[-4:],
                "label": label,
                "status": "active",
                "created_at": now,
                "expires_at": expires_at,
                "last_used_at": None,
                "usage_count": 0,
            }
        )

        logger.info("Created managed key %s for user %s", key_id, user_id)
        return {
            "key_id": key_id,
            "api_key": plaintext_key,
            "label": label,
            "expires_at": expires_at,
        }

    def update_key(
        self,
        user_id: str,
        key_id: str,
        label: str | None = None,
        expires_in_days: int | None = None,
    ) -> dict[str, Any]:
        """Update key label and/or expiration.

        A new expires_at is measured from now, not from the original creation.

        Raises:
            ValueError: If neither field is given or a given value is invalid.
            RuntimeError: If the table is not accessible.
            KeyNotFoundError: If the key does not exist.
            RevokedKeyError: If the key is revoked.
        """
        if label is None and expires_in_days is None:
            raise ValueError("At least one of label or expires_in_days must be provided")

        # Validate cheap inputs before issuing any DynamoDB request.
        if label is not None:
            label = self._validate_label(label)
        if expires_in_days is not None:
            self._validate_expiration(expires_in_days)

        # Raises KeyNotFoundError or RevokedKeyError.
        self._get_key(user_id, key_id)
        table = self._require_table()

        now = datetime.now(UTC)
        assignments = ["updated_at = :now"]
        expr_values: dict[str, Any] = {":now": now.isoformat()}
        if label is not None:
            assignments.append("label = :label")
            expr_values[":label"] = label
        if expires_in_days is not None:
            assignments.append("expires_at = :expires_at")
            expr_values[":expires_at"] = (now + timedelta(days=expires_in_days)).isoformat()

        # NOTE(review): update_item is not conditioned on status, so a revoke
        # landing between _get_key and this call is not detected.
        table.update_item(
            Key={"user_id": user_id, "key_id": key_id},
            UpdateExpression="SET " + ", ".join(assignments),
            ExpressionAttributeValues=expr_values,
        )

        logger.info("Updated managed key %s for user %s", key_id, user_id)
        return {"message": "Key updated"}

    def rotate_key(self, user_id: str, key_id: str) -> dict[str, Any]:
        """Generate a new key value while keeping the same key_id.

        Returns:
            key_id, the new plaintext api_key (the only time it is returned),
            and the key's existing label and expires_at (rotation changes neither).

        Raises:
            RuntimeError: If the table is not accessible.
            KeyNotFoundError: If the key does not exist.
            RevokedKeyError: If the key is revoked.
        """
        # Raises KeyNotFoundError or RevokedKeyError.
        item = self._get_key(user_id, key_id)
        table = self._require_table()

        plaintext_key = generate_managed_key()
        key_hash = compute_key_hash(plaintext_key)
        now = datetime.now(UTC).isoformat()

        # NOTE(review): like update_key, this write is not conditioned on status.
        table.update_item(
            Key={"user_id": user_id, "key_id": key_id},
            UpdateExpression="SET key_hash = :hash, key_last4 = :last4, updated_at = :now",
            ExpressionAttributeValues={
                ":hash": key_hash,
                ":last4": plaintext_key[-4:],
                ":now": now,
            },
        )

        logger.info("Rotated managed key %s for user %s", key_id, user_id)
        return {
            "key_id": key_id,
            "api_key": plaintext_key,
            "label": item.get("label", ""),
            "expires_at": item.get("expires_at", ""),
        }

    def revoke_key(self, user_id: str, key_id: str) -> None:
        """Mark a key as revoked with TTL for automatic cleanup.

        Revoking an already-revoked key is allowed and refreshes revoked_at and ttl.

        Raises:
            RuntimeError: If the table is not accessible.
            KeyNotFoundError: If the key does not exist.
        """
        table = self._require_table()

        # Check existence directly: _get_key would reject already-revoked keys.
        response = table.get_item(Key={"user_id": user_id, "key_id": key_id})
        if not response.get("Item"):
            raise KeyNotFoundError(f"Key {key_id} not found for user {user_id}")

        now = datetime.now(UTC)
        ttl = int((now + timedelta(days=REVOKE_TTL_DAYS)).timestamp())

        table.update_item(
            Key={"user_id": user_id, "key_id": key_id},
            UpdateExpression="SET #s = :revoked, revoked_at = :now, #ttl = :ttl",
            # "status" and "ttl" are aliased through expression attribute names.
            ExpressionAttributeNames={"#s": "status", "#ttl": "ttl"},
            ExpressionAttributeValues={
                ":revoked": "revoked",
                ":now": now.isoformat(),
                ":ttl": ttl,
            },
        )

        logger.info("Revoked managed key %s for user %s", key_id, user_id)